Check that version in snapshot matches the one in chunks

This commit is contained in:
Torsten Grote 2021-03-04 16:58:55 -03:00 committed by Chirayu Desai
parent be9a84d704
commit 61fe823a04
6 changed files with 32 additions and 14 deletions

View file

@ -18,11 +18,12 @@ internal abstract class AbstractChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class)
protected suspend fun getAndDecryptChunk(
version: Int,
chunkId: String,
streamReader: suspend (InputStream) -> Unit,
) {
storagePlugin.getChunkInputStream(chunkId).use { inputStream ->
inputStream.readVersion()
inputStream.readVersion(version)
val ad = streamCrypto.getAssociatedDataForChunk(chunkId)
streamCrypto.newDecryptingStream(streamKey, inputStream, ad).use { decryptedStream ->
streamReader(decryptedStream)

View file

@ -25,6 +25,7 @@ internal class MultiChunkRestore(
) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) {
suspend fun restore(
version: Int,
chunkMap: Map<String, RestorableChunk>,
files: Collection<RestorableFile>,
observer: RestoreObserver?,
@ -33,7 +34,7 @@ internal class MultiChunkRestore(
files.forEach { file ->
try {
restoreFile(file, observer, "L") { outputStream ->
writeChunks(file, chunkMap, outputStream)
writeChunks(version, file, chunkMap, outputStream)
}
restoredFiles++
} catch (e: Exception) {
@ -47,6 +48,7 @@ internal class MultiChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class)
private suspend fun writeChunks(
version: Int,
file: RestorableFile,
chunkMap: Map<String, RestorableChunk>,
outputStream: OutputStream,
@ -59,8 +61,8 @@ internal class MultiChunkRestore(
bytes += decryptedStream.copyTo(outputStream)
}
val isCached = isCached(chunkId)
if (isCached || otherFiles.size > 1) getAndCacheChunk(chunkId, chunkWriter)
else getAndDecryptChunk(chunkId, chunkWriter)
if (isCached || otherFiles.size > 1) getAndCacheChunk(version, chunkId, chunkWriter)
else getAndDecryptChunk(version, chunkId, chunkWriter)
otherFiles.remove(file)
if (isCached && otherFiles.isEmpty()) removeCachedChunk(chunkId)
@ -74,13 +76,14 @@ internal class MultiChunkRestore(
@Throws(IOException::class, GeneralSecurityException::class)
private suspend fun getAndCacheChunk(
version: Int,
chunkId: String,
streamReader: suspend (InputStream) -> Unit,
) {
val file = getChunkCacheFile(chunkId)
if (!file.isFile) {
FileOutputStream(file).use { outputStream ->
getAndDecryptChunk(chunkId) { decryptedStream ->
getAndDecryptChunk(version, chunkId) { decryptedStream ->
decryptedStream.copyTo(outputStream)
}
}

View file

@ -73,7 +73,7 @@ internal class Restore(
Log.e(TAG, "Decrypting and parsing $numSnapshots snapshots took $time")
}
@Throws(IOException::class)
@Throws(IOException::class, GeneralSecurityException::class)
suspend fun restoreBackupSnapshot(timestamp: Long, observer: RestoreObserver?) {
val snapshot = snapshotRetriever.getSnapshot(streamKey, timestamp)
restoreBackupSnapshot(snapshot, observer)
@ -88,17 +88,19 @@ internal class Restore(
observer?.onRestoreStart(filesTotal, totalSize)
val split = FileSplitter.splitSnapshot(snapshot)
val version = snapshot.version
var restoredFiles = 0
val smallFilesDuration = measure {
restoredFiles += zipChunkRestore.restore(split.zipChunks, observer)
restoredFiles += zipChunkRestore.restore(version, split.zipChunks, observer)
}
Log.e(TAG, "Restoring ${split.zipChunks.size} zip chunks took $smallFilesDuration.")
val singleChunkDuration = measure {
restoredFiles += singleChunkRestore.restore(split.singleChunks, observer)
restoredFiles += singleChunkRestore.restore(version, split.singleChunks, observer)
}
Log.e(TAG, "Restoring ${split.singleChunks.size} single chunks took $singleChunkDuration.")
val multiChunkDuration = measure {
restoredFiles += multiChunkRestore.restore(
version,
split.multiChunkMap,
split.multiChunkFiles,
observer
@ -113,10 +115,13 @@ internal class Restore(
}
@Throws(IOException::class)
internal fun InputStream.readVersion() {
@Throws(IOException::class, GeneralSecurityException::class)
internal fun InputStream.readVersion(expectedVersion: Int? = null) {
val version = read()
if (version == -1) throw IOException()
if (expectedVersion != null && version != expectedVersion) {
throw GeneralSecurityException("Expected version $expectedVersion, not $version")
}
if (version > Backup.VERSION) {
// TODO maybe throw a different exception here and tell the user?
throw IOException()

View file

@ -36,6 +36,7 @@ public abstract class RestoreService : Service() {
startForeground(NOTIFICATION_ID_RESTORE, n.getRestoreNotification())
GlobalScope.launch {
// TODO offer a way to try again if failed, or do an automatic retry here
storageBackup.restoreBackupSnapshot(timestamp, restoreObserver)
stopSelf(startId)
}

View file

@ -15,13 +15,17 @@ internal class SingleChunkRestore(
streamKey: ByteArray
) : AbstractChunkRestore(storagePlugin, fileRestore, streamCrypto, streamKey) {
suspend fun restore(chunks: Collection<RestorableChunk>, observer: RestoreObserver?): Int {
suspend fun restore(
version: Int,
chunks: Collection<RestorableChunk>,
observer: RestoreObserver?
): Int {
var restoredFiles = 0
chunks.forEach { chunk ->
check(chunk.files.size == 1)
val file = chunk.files[0]
try {
getAndDecryptChunk(chunk.chunkId) { decryptedStream ->
getAndDecryptChunk(version, chunk.chunkId) { decryptedStream ->
restoreFile(file, observer, "M") { outputStream ->
decryptedStream.copyTo(outputStream)
}

View file

@ -22,11 +22,15 @@ internal class ZipChunkRestore(
/**
* Assumes that files in [zipChunks] are sorted by zipIndex with no duplicate indices.
*/
suspend fun restore(zipChunks: Collection<RestorableChunk>, observer: RestoreObserver?): Int {
suspend fun restore(
version: Int,
zipChunks: Collection<RestorableChunk>,
observer: RestoreObserver?
): Int {
var restoredFiles = 0
zipChunks.forEach { zipChunk ->
try {
getAndDecryptChunk(zipChunk.chunkId) { decryptedStream ->
getAndDecryptChunk(version, zipChunk.chunkId) { decryptedStream ->
restoredFiles += restoreZipChunk(zipChunk, decryptedStream, observer)
}
} catch (e: Exception) {