diff --git a/app/src/main/java/com/stevesoltys/seedvault/repo/Checker.kt b/app/src/main/java/com/stevesoltys/seedvault/repo/Checker.kt index 7fda147a..76a0a806 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/repo/Checker.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/repo/Checker.kt @@ -6,6 +6,7 @@ package com.stevesoltys.seedvault.repo import androidx.annotation.WorkerThread +import com.google.protobuf.ByteString import com.stevesoltys.seedvault.backend.BackendManager import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.proto.Snapshot @@ -71,10 +72,12 @@ internal class Checker( this.handleSize = handles.size // remember number of snapshot handles we had // get total disk space used by snapshots - val sizeMap = mutableMapOf() + val sizeMap = mutableMapOf() // uses blob.id as key snapshots.forEach { snapshot -> // add sizes to a map first, so we don't double count - snapshot.blobsMap.forEach { (chunkId, blob) -> sizeMap[chunkId] = blob.length } + snapshot.blobsMap.forEach { (_, blob) -> + sizeMap[blob.id] = blob.length + } } return sizeMap.values.sumOf { it.toLong() } } @@ -96,13 +99,13 @@ internal class Checker( "Got $handleSize handles, but ${snapshots.size} snapshots." } val blobSample = getBlobSample(snapshots, percent) - val sampleSize = blobSample.values.sumOf { it.length.toLong() } + val sampleSize = blobSample.sumOf { it.blob.length.toLong() } log.info { "Blob sample has ${blobSample.size} blobs worth $sampleSize bytes." } // check blobs concurrently val semaphore = Semaphore(concurrencyLimit) val size = AtomicLong() - val badChunks = ConcurrentSkipListSet() + val badChunks = ConcurrentSkipListSet() val lastNotification = AtomicLong() val startTime = System.currentTimeMillis() coroutineScope { @@ -116,7 +119,7 @@ internal class Checker( } catch (e: Exception) { log.error(e) { "Error loading chunk $chunkId: " } // TODO we could try differentiating transient backend issues - badChunks.add(chunkId) + badChunks.add(ChunkIdBlobPair(chunkId, blob)) } } // keep track of how much we checked and for how long @@ -154,25 +157,30 @@ internal class Checker( checkerResult = null } - private fun getBlobSample(snapshots: List, percent: Int): Map { - // split up blobs for app data and for APKs - val appBlobs = mutableMapOf() - val apkBlobs = mutableMapOf() + private fun getBlobSample( + snapshots: List, + percent: Int, + ): List { + // split up blobs for app data and for APKs (use blob.id as key to prevent double counting) + val appBlobs = mutableMapOf() + val apkBlobs = mutableMapOf() snapshots.forEach { snapshot -> val appChunkIds = snapshot.appsMap.flatMap { it.value.chunkIdsList.hexFromProto() } val apkChunkIds = snapshot.appsMap.flatMap { it.value.apk.splitsList.flatMap { split -> split.chunkIdsList.hexFromProto() } } appChunkIds.forEach { chunkId -> - appBlobs[chunkId] = snapshot.blobsMap[chunkId] ?: error("No Blob for chunkId") + val blob = snapshot.blobsMap[chunkId] ?: error("No Blob for chunkId") + appBlobs[blob.id] = ChunkIdBlobPair(chunkId, blob) } apkChunkIds.forEach { chunkId -> - apkBlobs[chunkId] = snapshot.blobsMap[chunkId] ?: error("No Blob for chunkId") + val blob = snapshot.blobsMap[chunkId] ?: error("No Blob for chunkId") + apkBlobs[blob.id] = ChunkIdBlobPair(chunkId, blob) } } // calculate sizes - val appSize = appBlobs.values.sumOf { it.length.toLong() } - val apkSize = apkBlobs.values.sumOf { it.length.toLong() } + val appSize = appBlobs.values.sumOf { it.blob.length.toLong() } + val apkSize = apkBlobs.values.sumOf { it.blob.length.toLong() } // let's assume it is unlikely that app data and APKs have blobs in common val totalSize = appSize + apkSize log.info { "Got ${appBlobs.size + apkBlobs.size} blobs worth $totalSize bytes to check." } @@ -182,23 +190,21 @@ internal class Checker( val appTargetSize = min((targetSize * 0.75).roundToLong(), appSize) // 75% of targetSize log.info { "Sampling $targetSize bytes of which $appTargetSize bytes for apps." } - val blobSample = mutableMapOf() + val blobSample = mutableListOf() var currentSize = 0L // check apps first until we reach their target size - val appIterator = appBlobs.keys.shuffled().iterator() // random app blob iterator + val appIterator = appBlobs.values.shuffled().iterator() // random app blob iterator while (currentSize < appTargetSize && appIterator.hasNext()) { - val randomChunkId = appIterator.next() - val blob = appBlobs[randomChunkId] ?: error("No blob") - blobSample[randomChunkId] = blob - currentSize += blob.length + val pair = appIterator.next() + blobSample.add(pair) + currentSize += pair.blob.length } // now check APKs until we reach total targetSize - val apkIterator = apkBlobs.keys.shuffled().iterator() // random APK blob iterator + val apkIterator = apkBlobs.values.shuffled().iterator() // random APK blob iterator while (currentSize < targetSize && apkIterator.hasNext()) { - val randomChunkId = apkIterator.next() - val blob = apkBlobs[randomChunkId] ?: error("No blob") - blobSample[randomChunkId] = blob - currentSize += blob.length + val pair = apkIterator.next() + blobSample.add(pair) + currentSize += pair.blob.length } return blobSample } @@ -216,3 +222,9 @@ internal class Checker( if (readChunkId != chunkId) throw GeneralSecurityException("ChunkId doesn't match") } } + +data class ChunkIdBlobPair(val chunkId: String, val blob: Blob) : Comparable { + override fun compareTo(other: ChunkIdBlobPair): Int { + return chunkId.compareTo(other.chunkId) + } +} diff --git a/app/src/main/java/com/stevesoltys/seedvault/repo/CheckerResult.kt b/app/src/main/java/com/stevesoltys/seedvault/repo/CheckerResult.kt index a2d5ecdf..2be2066d 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/repo/CheckerResult.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/repo/CheckerResult.kt @@ -24,7 +24,7 @@ sealed class CheckerResult { /** * The list of chunkIDs that had errors. */ - val errorChunkIds: Set, + val errorChunkIdBlobPairs: Set, ) : CheckerResult() { val goodSnapshots: List val badSnapshots: List @@ -32,9 +32,23 @@ sealed class CheckerResult { init { val good = mutableListOf() val bad = mutableListOf() + val errorChunkIds = errorChunkIdBlobPairs.map { it.chunkId }.toSet() snapshots.forEach { snapshot -> - val isGood = snapshot.blobsMap.keys.intersect(errorChunkIds).isEmpty() - if (isGood) good.add(snapshot) else bad.add(snapshot) + val badChunkIds = snapshot.blobsMap.keys.intersect(errorChunkIds) + if (badChunkIds.isEmpty()) { + // snapshot doesn't contain chunks with erroneous blobs + good.add(snapshot) + } else { + // snapshot may contain chunks with erroneous blobs, check deeper + val isBad = badChunkIds.any { chunkId -> + val blob = snapshot.blobsMap[chunkId] ?: error("No blob for chunkId") + // is this chunkId/blob pair in errorChunkIdBlobPairs? + errorChunkIdBlobPairs.any { pair -> + pair.chunkId == chunkId && pair.blob == blob + } + } + if (isBad) bad.add(snapshot) else good.add(snapshot) + } } goodSnapshots = good badSnapshots = bad diff --git a/app/src/test/java/com/stevesoltys/seedvault/repo/CheckerTest.kt b/app/src/test/java/com/stevesoltys/seedvault/repo/CheckerTest.kt index 75241f40..cbfca6b1 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/repo/CheckerTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/repo/CheckerTest.kt @@ -82,6 +82,36 @@ internal class CheckerTest : TransportTest() { assertEquals(expectedSize, checker.getBackupSize()) } + @Test + fun `getBackupSize returns size without under-counting blobs with same chunkId`() = + runBlocking { + val apk = apk.copy { + splits.clear() + splits.add(baseSplit.copy { + this.chunkIds.clear() + chunkIds.add(ByteString.fromHex(chunkId1)) + }) + } + val snapshot = snapshot.copy { + apps[packageName] = app.copy { this.apk = apk } + blobs.clear() + } + val snapshotMap = mapOf( + snapshotHandle1 to snapshot.copy { + token = 1 + blobs[chunkId1] = blob1 + }, + snapshotHandle2 to snapshot.copy { + token = 2 + blobs[chunkId1] = blob2 + }, + ) + val expectedSize = blob1.length.toLong() + blob2.length.toLong() + expectLoadingSnapshots(snapshotMap) + + assertEquals(expectedSize, checker.getBackupSize()) + } + @Test fun `check works even with no backup data`() = runBlocking { expectLoadingSnapshots(emptyMap()) @@ -138,7 +168,8 @@ internal class CheckerTest : TransportTest() { assertEquals(snapshotMap.values.toSet(), result.badSnapshots.toSet()) assertEquals(emptyList(), result.goodSnapshots) assertEquals(snapshotMap.size, result.existingSnapshots) - assertEquals(setOf(chunkId1, chunkId2), result.errorChunkIds) + val errorPairs = setOf(ChunkIdBlobPair(chunkId1, blob1), ChunkIdBlobPair(chunkId2, blob2)) + assertEquals(errorPairs, result.errorChunkIdBlobPairs) } @Test @@ -189,7 +220,8 @@ internal class CheckerTest : TransportTest() { assertEquals(listOf(snapshotMap[snapshotHandle1]), result.goodSnapshots) assertEquals(listOf(snapshotMap[snapshotHandle2]), result.badSnapshots) assertEquals(snapshotMap.size, result.existingSnapshots) - assertEquals(setOf(chunkId2), result.errorChunkIds) + val errorPairs = setOf(ChunkIdBlobPair(chunkId2, blob2)) + assertEquals(errorPairs, result.errorChunkIdBlobPairs) } @Test @@ -247,8 +279,8 @@ internal class CheckerTest : TransportTest() { fun `check prefers app data over APKs`() = runBlocking { val appDataBlob = blob { id = ByteString.copyFrom(Random.nextBytes(32)) - length = Random.nextInt(0, Int.MAX_VALUE) - uncompressedLength = Random.nextInt(0, Int.MAX_VALUE) + length = Random.nextInt(1, Int.MAX_VALUE) + uncompressedLength = Random.nextInt(1, Int.MAX_VALUE) } val appDataBlobHandle1 = AppBackupFileType.Blob(repoId, appDataBlob.id.hexFromProto()) val appDataChunkId = Random.nextBytes(32).toHexString() @@ -266,6 +298,7 @@ internal class CheckerTest : TransportTest() { // only loading app data, not other blobs coEvery { loader.loadFile(appDataBlobHandle1, null) } throws SecurityException() + println("appDataBlob.length = $appDataBlob.length") every { nm.onCheckFinishedWithError(appDataBlob.length.toLong(), any()) } just Runs assertNull(checker.checkerResult) @@ -275,7 +308,8 @@ internal class CheckerTest : TransportTest() { assertEquals(snapshotMap.values.toSet(), result.snapshots.toSet()) assertEquals(snapshotMap.values.toSet(), result.badSnapshots.toSet()) assertEquals(snapshotMap.size, result.existingSnapshots) - assertEquals(setOf(appDataChunkId), result.errorChunkIds) + val errorPairs = setOf(ChunkIdBlobPair(appDataChunkId, appDataBlob)) + assertEquals(errorPairs, result.errorChunkIdBlobPairs) coVerify(exactly = 0) { loader.loadFile(blobHandle1, null) @@ -283,6 +317,55 @@ internal class CheckerTest : TransportTest() { } } + @Test + fun `check doesn't skip broken blobs that have a fix with same chunkID`() = runBlocking { + // get "real" data for blob2 + val messageDigest = MessageDigest.getInstance("SHA-256") + val data1 = getRandomByteArray() // broken blob + val data2 = getRandomByteArray() // data2 matches chunkId + val chunkId = messageDigest.digest(data2).toHexString() + val apk = apk.copy { + splits.clear() + splits.add(baseSplit.copy { + this.chunkIds.clear() + chunkIds.add(ByteString.fromHex(chunkId)) + }) + } + val snapshot = snapshot.copy { + apps[packageName] = app.copy { this.apk = apk } + blobs.clear() + } + val snapshotMap = mapOf( + snapshotHandle1 to snapshot.copy { + token = 1 + blobs[chunkId] = blob1 // snapshot1 has broken blob for chunkId + }, + snapshotHandle2 to snapshot.copy { + token = 2 + blobs[chunkId] = blob2 // snapshot2 has fixed blob for chunkId + }, + ) + + expectLoadingSnapshots(snapshotMap) + every { backendManager.requiresNetwork } returns Random.nextBoolean() + + coEvery { loader.loadFile(blobHandle1, null) } returns ByteArrayInputStream(data1) + coEvery { loader.loadFile(blobHandle2, null) } returns ByteArrayInputStream(data2) + + every { nm.onCheckFinishedWithError(any(), any()) } just Runs + + assertNull(checker.checkerResult) + checker.check(100) + assertInstanceOf(CheckerResult.Error::class.java, checker.checkerResult) + val result = checker.checkerResult as CheckerResult.Error + assertEquals(snapshotMap.values.toSet(), result.snapshots.toSet()) + assertEquals(setOf(snapshotMap[snapshotHandle2]), result.goodSnapshots.toSet()) + assertEquals(setOf(snapshotMap[snapshotHandle1]), result.badSnapshots.toSet()) + assertEquals(snapshotMap.size, result.existingSnapshots) + val errorPairs = setOf(ChunkIdBlobPair(chunkId, blob1)) + assertEquals(errorPairs, result.errorChunkIdBlobPairs) + } + private suspend fun expectLoadingSnapshots( snapshots: Map, ) {