Check that version in snapshot matches the one in chunks

This commit is contained in:
Torsten Grote 2021-03-04 16:58:55 -03:00 committed by Chirayu Desai
parent be9a84d704
commit 61fe823a04
6 changed files with 32 additions and 14 deletions

View file

@ -18,11 +18,12 @@ internal abstract class AbstractChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)
protected suspend fun getAndDecryptChunk( protected suspend fun getAndDecryptChunk(
version: Int,
chunkId: String, chunkId: String,
streamReader: suspend (InputStream) -> Unit, streamReader: suspend (InputStream) -> Unit,
) { ) {
storagePlugin.getChunkInputStream(chunkId).use { inputStream -> storagePlugin.getChunkInputStream(chunkId).use { inputStream ->
inputStream.readVersion() inputStream.readVersion(version)
val ad = streamCrypto.getAssociatedDataForChunk(chunkId) val ad = streamCrypto.getAssociatedDataForChunk(chunkId)
streamCrypto.newDecryptingStream(streamKey, inputStream, ad).use { decryptedStream -> streamCrypto.newDecryptingStream(streamKey, inputStream, ad).use { decryptedStream ->
streamReader(decryptedStream) streamReader(decryptedStream)

View file

@ -25,6 +25,7 @@ internal class MultiChunkRestore(
) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) { ) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) {
suspend fun restore( suspend fun restore(
version: Int,
chunkMap: Map<String, RestorableChunk>, chunkMap: Map<String, RestorableChunk>,
files: Collection<RestorableFile>, files: Collection<RestorableFile>,
observer: RestoreObserver?, observer: RestoreObserver?,
@ -33,7 +34,7 @@ internal class MultiChunkRestore(
files.forEach { file -> files.forEach { file ->
try { try {
restoreFile(file, observer, "L") { outputStream -> restoreFile(file, observer, "L") { outputStream ->
writeChunks(file, chunkMap, outputStream) writeChunks(version, file, chunkMap, outputStream)
} }
restoredFiles++ restoredFiles++
} catch (e: Exception) { } catch (e: Exception) {
@ -47,6 +48,7 @@ internal class MultiChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)
private suspend fun writeChunks( private suspend fun writeChunks(
version: Int,
file: RestorableFile, file: RestorableFile,
chunkMap: Map<String, RestorableChunk>, chunkMap: Map<String, RestorableChunk>,
outputStream: OutputStream, outputStream: OutputStream,
@ -59,8 +61,8 @@ internal class MultiChunkRestore(
bytes += decryptedStream.copyTo(outputStream) bytes += decryptedStream.copyTo(outputStream)
} }
val isCached = isCached(chunkId) val isCached = isCached(chunkId)
if (isCached || otherFiles.size > 1) getAndCacheChunk(chunkId, chunkWriter) if (isCached || otherFiles.size > 1) getAndCacheChunk(version, chunkId, chunkWriter)
else getAndDecryptChunk(chunkId, chunkWriter) else getAndDecryptChunk(version, chunkId, chunkWriter)
otherFiles.remove(file) otherFiles.remove(file)
if (isCached && otherFiles.isEmpty()) removeCachedChunk(chunkId) if (isCached && otherFiles.isEmpty()) removeCachedChunk(chunkId)
@ -74,13 +76,14 @@ internal class MultiChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)
private suspend fun getAndCacheChunk( private suspend fun getAndCacheChunk(
version: Int,
chunkId: String, chunkId: String,
streamReader: suspend (InputStream) -> Unit, streamReader: suspend (InputStream) -> Unit,
) { ) {
val file = getChunkCacheFile(chunkId) val file = getChunkCacheFile(chunkId)
if (!file.isFile) { if (!file.isFile) {
FileOutputStream(file).use { outputStream -> FileOutputStream(file).use { outputStream ->
getAndDecryptChunk(chunkId) { decryptedStream -> getAndDecryptChunk(version, chunkId) { decryptedStream ->
decryptedStream.copyTo(outputStream) decryptedStream.copyTo(outputStream)
} }
} }

View file

@ -73,7 +73,7 @@ internal class Restore(
Log.e(TAG, "Decrypting and parsing $numSnapshots snapshots took $time") Log.e(TAG, "Decrypting and parsing $numSnapshots snapshots took $time")
} }
@Throws(IOException::class) @Throws(IOException::class, GeneralSecurityException::class)
suspend fun restoreBackupSnapshot(timestamp: Long, observer: RestoreObserver?) { suspend fun restoreBackupSnapshot(timestamp: Long, observer: RestoreObserver?) {
val snapshot = snapshotRetriever.getSnapshot(streamKey, timestamp) val snapshot = snapshotRetriever.getSnapshot(streamKey, timestamp)
restoreBackupSnapshot(snapshot, observer) restoreBackupSnapshot(snapshot, observer)
@ -88,17 +88,19 @@ internal class Restore(
observer?.onRestoreStart(filesTotal, totalSize) observer?.onRestoreStart(filesTotal, totalSize)
val split = FileSplitter.splitSnapshot(snapshot) val split = FileSplitter.splitSnapshot(snapshot)
val version = snapshot.version
var restoredFiles = 0 var restoredFiles = 0
val smallFilesDuration = measure { val smallFilesDuration = measure {
restoredFiles += zipChunkRestore.restore(split.zipChunks, observer) restoredFiles += zipChunkRestore.restore(version, split.zipChunks, observer)
} }
Log.e(TAG, "Restoring ${split.zipChunks.size} zip chunks took $smallFilesDuration.") Log.e(TAG, "Restoring ${split.zipChunks.size} zip chunks took $smallFilesDuration.")
val singleChunkDuration = measure { val singleChunkDuration = measure {
restoredFiles += singleChunkRestore.restore(split.singleChunks, observer) restoredFiles += singleChunkRestore.restore(version, split.singleChunks, observer)
} }
Log.e(TAG, "Restoring ${split.singleChunks.size} single chunks took $singleChunkDuration.") Log.e(TAG, "Restoring ${split.singleChunks.size} single chunks took $singleChunkDuration.")
val multiChunkDuration = measure { val multiChunkDuration = measure {
restoredFiles += multiChunkRestore.restore( restoredFiles += multiChunkRestore.restore(
version,
split.multiChunkMap, split.multiChunkMap,
split.multiChunkFiles, split.multiChunkFiles,
observer observer
@ -113,10 +115,13 @@ internal class Restore(
} }
@Throws(IOException::class) @Throws(IOException::class, GeneralSecurityException::class)
internal fun InputStream.readVersion() { internal fun InputStream.readVersion(expectedVersion: Int? = null) {
val version = read() val version = read()
if (version == -1) throw IOException() if (version == -1) throw IOException()
if (expectedVersion != null && version != expectedVersion) {
throw GeneralSecurityException("Expected version $expectedVersion, not $version")
}
if (version > Backup.VERSION) { if (version > Backup.VERSION) {
// TODO maybe throw a different exception here and tell the user? // TODO maybe throw a different exception here and tell the user?
throw IOException() throw IOException()

View file

@ -36,6 +36,7 @@ public abstract class RestoreService : Service() {
startForeground(NOTIFICATION_ID_RESTORE, n.getRestoreNotification()) startForeground(NOTIFICATION_ID_RESTORE, n.getRestoreNotification())
GlobalScope.launch { GlobalScope.launch {
// TODO offer a way to try again if failed, or do an automatic retry here
storageBackup.restoreBackupSnapshot(timestamp, restoreObserver) storageBackup.restoreBackupSnapshot(timestamp, restoreObserver)
stopSelf(startId) stopSelf(startId)
} }

View file

@ -15,13 +15,17 @@ internal class SingleChunkRestore(
streamKey: ByteArray streamKey: ByteArray
) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) { ) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) {
suspend fun restore(chunks: Collection<RestorableChunk>, observer: RestoreObserver?): Int { suspend fun restore(
version: Int,
chunks: Collection<RestorableChunk>,
observer: RestoreObserver?
): Int {
var restoredFiles = 0 var restoredFiles = 0
chunks.forEach { chunk -> chunks.forEach { chunk ->
check(chunk.files.size == 1) check(chunk.files.size == 1)
val file = chunk.files[0] val file = chunk.files[0]
try { try {
getAndDecryptChunk(chunk.chunkId) { decryptedStream -> getAndDecryptChunk(version, chunk.chunkId) { decryptedStream ->
restoreFile(file, observer, "M") { outputStream -> restoreFile(file, observer, "M") { outputStream ->
decryptedStream.copyTo(outputStream) decryptedStream.copyTo(outputStream)
} }

View file

@ -22,11 +22,15 @@ internal class ZipChunkRestore(
/** /**
* Assumes that files in [zipChunks] are sorted by zipIndex with no duplicate indices. * Assumes that files in [zipChunks] are sorted by zipIndex with no duplicate indices.
*/ */
suspend fun restore(zipChunks: Collection<RestorableChunk>, observer: RestoreObserver?): Int { suspend fun restore(
version: Int,
zipChunks: Collection<RestorableChunk>,
observer: RestoreObserver?
): Int {
var restoredFiles = 0 var restoredFiles = 0
zipChunks.forEach { zipChunk -> zipChunks.forEach { zipChunk ->
try { try {
getAndDecryptChunk(zipChunk.chunkId) { decryptedStream -> getAndDecryptChunk(version, zipChunk.chunkId) { decryptedStream ->
restoredFiles += restoreZipChunk(zipChunk, decryptedStream, observer) restoredFiles += restoreZipChunk(zipChunk, decryptedStream, observer)
} }
} catch (e: Exception) { } catch (e: Exception) {