From a0f3c6b45fa6f4e79003a2984c32a8834cdc007f Mon Sep 17 00:00:00 2001 From: Torsten Grote Date: Wed, 22 Sep 2021 14:41:49 +0200 Subject: [PATCH] K/V restore using single file --- .../transport/backup/BackupCoordinator.kt | 3 + .../seedvault/transport/backup/KVDbManager.kt | 40 ++- .../seedvault/transport/restore/KVRestore.kt | 101 ++++-- .../transport/restore/RestoreCoordinator.kt | 10 +- .../transport/restore/RestoreModule.kt | 2 +- .../transport/CoordinatorIntegrationTest.kt | 52 ++- .../transport/backup/TestKvDbManager.kt | 24 +- .../transport/restore/KVRestoreTest.kt | 310 +++++++++++++----- .../restore/RestoreCoordinatorTest.kt | 10 +- .../restore/RestoreV0IntegrationTest.kt | 11 +- 10 files changed, 405 insertions(+), 158 deletions(-) diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt index 31a7436b..62c21b6a 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupCoordinator.kt @@ -272,6 +272,7 @@ internal class BackupCoordinator( val salt = metadataManager.salt val result = kv.performBackup(packageInfo, data, flags, token, salt) if (result == TRANSPORT_OK && packageName == MAGIC_PACKAGE_MANAGER) { + // TODO move to finish backup of @pm@ so we can upload the DB before // hook in here to back up APKs of apps that are otherwise not allowed for backup backUpApksOfNotBackedUpPackages() } @@ -392,7 +393,9 @@ internal class BackupCoordinator( } // getCurrentPackage() not-null because we have state onPackageBackedUp(kv.getCurrentPackage()!!, BackupType.KV) + val isPmBackup = kv.getCurrentPackage()!!.packageName == MAGIC_PACKAGE_MANAGER kv.finishBackup() + // TODO move @pm@ backup hook here } full.hasState() -> { check(!kv.hasState()) { diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt index 7689220c..f6d03e72 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVDbManager.kt @@ -8,37 +8,59 @@ import android.database.sqlite.SQLiteOpenHelper import android.provider.BaseColumns import java.io.File import java.io.FileInputStream +import java.io.FileOutputStream import java.io.InputStream +import java.io.OutputStream interface KvDbManager { - fun getDb(packageName: String): KVDb + fun getDb(packageName: String, isRestore: Boolean = false): KVDb + + /** + * Use only for backup. + */ fun getDbInputStream(packageName: String): InputStream + + /** + * Use only for restore. + */ + fun getDbOutputStream(packageName: String): OutputStream + + /** + * Use only for backup. + */ fun existsDb(packageName: String): Boolean - fun deleteDb(packageName: String): Boolean + fun deleteDb(packageName: String, isRestore: Boolean = false): Boolean } class KvDbManagerImpl(private val context: Context) : KvDbManager { - override fun getDb(packageName: String): KVDb { - return KVDbImpl(context, getFileName(packageName)) + override fun getDb(packageName: String, isRestore: Boolean): KVDb { + return KVDbImpl(context, getFileName(packageName, isRestore)) } - private fun getFileName(packageName: String) = "kv_$packageName.db" + private fun getFileName(packageName: String, isRestore: Boolean): String { + val prefix = if (isRestore) "restore_" else "" + return "${prefix}kv_$packageName.db" + } - private fun getDbFile(packageName: String): File { - return context.getDatabasePath(getFileName(packageName)) + private fun getDbFile(packageName: String, isRestore: Boolean = false): File { + return context.getDatabasePath(getFileName(packageName, isRestore)) } override fun getDbInputStream(packageName: String): InputStream { return FileInputStream(getDbFile(packageName)) } + override fun getDbOutputStream(packageName: String): OutputStream { + return FileOutputStream(getDbFile(packageName, true)) + } + override fun existsDb(packageName: String): Boolean { return getDbFile(packageName).isFile } - override fun deleteDb(packageName: String): Boolean { - return getDbFile(packageName).delete() + override fun deleteDb(packageName: String, isRestore: Boolean): Boolean { + return getDbFile(packageName, isRestore).delete() } } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt index 5ee74ed8..7af153fa 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt @@ -15,19 +15,25 @@ import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.getADForKV +import com.stevesoltys.seedvault.transport.backup.BackupPlugin +import com.stevesoltys.seedvault.transport.backup.KVDb +import com.stevesoltys.seedvault.transport.backup.KvDbManager import libcore.io.IoUtils.closeQuietly import java.io.IOException import java.security.GeneralSecurityException import java.util.ArrayList +import java.util.zip.GZIPInputStream import javax.crypto.AEADBadTagException private class KVRestoreState( val version: Byte, val token: Long, + val name: String, val packageInfo: PackageInfo, /** * Optional [PackageInfo] for single package restore, optimizes restore of @pm@ */ + @Deprecated("TODO remove?") val pmPackageInfo: PackageInfo? ) @@ -35,20 +41,25 @@ private val TAG = KVRestore::class.java.simpleName @Suppress("BlockingMethodInNonBlockingContext") internal class KVRestore( - private val plugin: KVRestorePlugin, + private val plugin: BackupPlugin, + private val legacyPlugin: KVRestorePlugin, private val outputFactory: OutputFactory, private val headerReader: HeaderReader, - private val crypto: Crypto + private val crypto: Crypto, + private val dbManager: KvDbManager ) { private var state: KVRestoreState? = null /** * Return true if there are records stored for the given package. + * + * Deprecated. Use only for v0 backups. */ @Throws(IOException::class) + @Deprecated("Use BackupPlugin#hasData() instead") suspend fun hasDataForPackage(token: Long, packageInfo: PackageInfo): Boolean { - return plugin.hasDataForPackage(token, packageInfo) + return legacyPlugin.hasDataForPackage(token, packageInfo) } /** @@ -62,10 +73,11 @@ internal class KVRestore( fun initializeState( version: Byte, token: Long, + name: String, packageInfo: PackageInfo, pmPackageInfo: PackageInfo? = null ) { - state = KVRestoreState(version, token, packageInfo, pmPackageInfo) + state = KVRestoreState(version, token, name, packageInfo, pmPackageInfo) } /** @@ -78,12 +90,66 @@ internal class KVRestore( suspend fun getRestoreData(data: ParcelFileDescriptor): Int { val state = this.state ?: throw IllegalStateException("no state") + // take legacy path for version 0 + if (state.version == 0x00.toByte()) return getRestoreDataV0(state, data) + + return try { + val db = getRestoreDb(state) + val out = outputFactory.getBackupDataOutput(data) + db.getAll().sortedBy { it.first }.forEach { (key, value) -> + val size = value.size + Log.v(TAG, " ... key=$key size=$size") + out.writeEntityHeader(key, size) + out.writeEntityData(value, size) + } + TRANSPORT_OK + } catch (e: UnsupportedVersionException) { + Log.e(TAG, "Unsupported version in backup: ${e.version}", e) + TRANSPORT_ERROR + } catch (e: IOException) { + Log.e(TAG, "Unable to process K/V backup database", e) + TRANSPORT_ERROR + } catch (e: GeneralSecurityException) { + Log.e(TAG, "General security exception while reading backup database", e) + TRANSPORT_ERROR + } catch (e: AEADBadTagException) { + Log.e(TAG, "Decryption failed", e) + TRANSPORT_ERROR + } finally { + dbManager.deleteDb(state.packageInfo.packageName, true) + this.state = null + closeQuietly(data) + } + } + + @Throws(IOException::class, GeneralSecurityException::class, UnsupportedVersionException::class) + private suspend fun getRestoreDb(state: KVRestoreState): KVDb { + val packageName = state.packageInfo.packageName + plugin.getInputStream(state.token, state.name).use { inputStream -> + headerReader.readVersion(inputStream, state.version) + val ad = getADForKV(VERSION, packageName) + crypto.newDecryptingStream(inputStream, ad).use { decryptedStream -> + GZIPInputStream(decryptedStream).use { gzipStream -> + dbManager.getDbOutputStream(packageName).use { outputStream -> + gzipStream.copyTo(outputStream) + } + } + } + } + return dbManager.getDb(packageName, true) + } + + // + // v0 restore legacy code below + // + + private suspend fun getRestoreDataV0(state: KVRestoreState, data: ParcelFileDescriptor): Int { // The restore set is the concatenation of the individual record blobs, // each of which is a file in the package's directory. // We return the data in lexical order sorted by key, // so that apps which use synthetic keys like BLOB_1, BLOB_2, etc // will see the date in the most obvious order. - val sortedKeys = getSortedKeys(state.token, state.packageInfo) + val sortedKeys = getSortedKeysV0(state.token, state.packageInfo) if (sortedKeys == null) { // nextRestorePackage() ensures the dir exists, so this is an error Log.e(TAG, "No keys for package: ${state.packageInfo.packageName}") @@ -96,7 +162,7 @@ internal class KVRestore( return try { val dataOutput = outputFactory.getBackupDataOutput(data) for (keyEntry in sortedKeys) { - readAndWriteValue(state, keyEntry, dataOutput) + readAndWriteValueV0(state, keyEntry, dataOutput) } TRANSPORT_OK } catch (e: IOException) { @@ -105,9 +171,6 @@ internal class KVRestore( } catch (e: SecurityException) { Log.e(TAG, "Security exception while reading backup records", e) TRANSPORT_ERROR - } catch (e: GeneralSecurityException) { - Log.e(TAG, "General security exception while reading backup records", e) - TRANSPORT_ERROR } catch (e: UnsupportedVersionException) { Log.e(TAG, "Unsupported version in backup: ${e.version}", e) TRANSPORT_ERROR @@ -124,9 +187,9 @@ internal class KVRestore( * Return a list of the records (represented by key files) in the given directory, * sorted lexically by the Base64-decoded key file name, not by the on-disk filename. */ - private suspend fun getSortedKeys(token: Long, packageInfo: PackageInfo): List? { + private suspend fun getSortedKeysV0(token: Long, packageInfo: PackageInfo): List? { val records: List = try { - plugin.listRecords(token, packageInfo) + legacyPlugin.listRecords(token, packageInfo) } catch (e: IOException) { return null } @@ -150,24 +213,18 @@ internal class KVRestore( /** * Read the encrypted value for the given key and write it to the given [BackupDataOutput]. */ + @Suppress("Deprecation") @Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class) - private suspend fun readAndWriteValue( + private suspend fun readAndWriteValueV0( state: KVRestoreState, dKey: DecodedKey, out: BackupDataOutput - ) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key) + ) = legacyPlugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key) .use { inputStream -> val version = headerReader.readVersion(inputStream, state.version) val packageName = state.packageInfo.packageName - val value = if (version == 0.toByte()) { - crypto.decryptHeader(inputStream, version, packageName, dKey.key) - crypto.decryptMultipleSegments(inputStream) - } else { - val ad = getADForKV(VERSION, packageName) - crypto.newDecryptingStream(inputStream, ad).use { decryptedStream -> - decryptedStream.readBytes() - } - } + crypto.decryptHeader(inputStream, version, packageName, dKey.key) + val value = crypto.decryptMultipleSegments(inputStream) val size = value.size Log.v(TAG, " ... key=${dKey.key} size=$size") diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt index 0859cdac..586a846f 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinator.kt @@ -207,7 +207,13 @@ internal class RestoreCoordinator( val name = crypto.getNameForPackage(state.backupMetadata.salt, packageName) if (plugin.hasData(state.token, name)) { Log.i(TAG, "Found K/V data for $packageName.") - kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo) + kv.initializeState( + version = version, + token = state.token, + name = name, + packageInfo = packageInfo, + pmPackageInfo = state.pmPackageInfo + ) state.currentPackage = packageName TYPE_KEY_VALUE } else throw IOException("No data found for $packageName. Skipping.") @@ -243,7 +249,7 @@ internal class RestoreCoordinator( // check key/value data first and if available, don't even check for full data kv.hasDataForPackage(state.token, packageInfo) -> { Log.i(TAG, "Found K/V data for $packageName.") - kv.initializeState(0x00, state.token, packageInfo, state.pmPackageInfo) + kv.initializeState(0x00, state.token, "", packageInfo, state.pmPackageInfo) state.currentPackage = packageName TYPE_KEY_VALUE } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt index 62756a84..7ef3c4ca 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/RestoreModule.kt @@ -5,7 +5,7 @@ import org.koin.dsl.module val restoreModule = module { single { OutputFactory() } - single { KVRestore(get().kvRestorePlugin, get(), get(), get()) } + single { KVRestore(get(), get().kvRestorePlugin, get(), get(), get(), get()) } single { FullRestore(get(), get().fullRestorePlugin, get(), get(), get()) } single { RestoreCoordinator(androidContext(), get(), get(), get(), get(), get(), get(), get(), get()) diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt index 73b25085..282a6bac 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt @@ -10,7 +10,6 @@ import android.os.ParcelFileDescriptor import com.stevesoltys.seedvault.crypto.CipherFactoryImpl import com.stevesoltys.seedvault.crypto.CryptoImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl -import com.stevesoltys.seedvault.encodeBase64 import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH import com.stevesoltys.seedvault.metadata.BackupType @@ -39,6 +38,7 @@ import io.mockk.coEvery import io.mockk.every import io.mockk.just import io.mockk.mockk +import io.mockk.verify import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals @@ -81,7 +81,14 @@ internal class CoordinatorIntegrationTest : TransportTest() { ) private val kvRestorePlugin = mockk() - private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl) + private val kvRestore = KVRestore( + backupPlugin, + kvRestorePlugin, + outputFactory, + headerReader, + cryptoImpl, + dbManager + ) private val fullRestorePlugin = mockk() private val fullRestore = FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl) @@ -104,9 +111,7 @@ internal class CoordinatorIntegrationTest : TransportTest() { private val metadataOutputStream = ByteArrayOutputStream() private val packageMetadata = PackageMetadata(time = 0L) private val key = "RestoreKey" - private val key64 = key.encodeBase64() private val key2 = "RestoreKey2" - private val key264 = key2.encodeBase64() // as we use real crypto, we need a real name for packageInfo private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName) @@ -116,7 +121,6 @@ internal class CoordinatorIntegrationTest : TransportTest() { val value = CapturingSlot() val value2 = CapturingSlot() val bOutputStream = ByteArrayOutputStream() - val bOutputStream2 = ByteArrayOutputStream() every { settingsManager.getToken() } returns token every { metadataManager.salt } returns salt @@ -170,29 +174,21 @@ internal class CoordinatorIntegrationTest : TransportTest() { // restore finds the backed up key and writes the decrypted value val backupDataOutput = mockk() val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray()) - val rInputStream2 = ByteArrayInputStream(bOutputStream2.toByteArray()) - coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64, key264) + coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput - coEvery { - kvRestorePlugin.getInputStreamForRecord( - token, - packageInfo, - key64 - ) - } returns rInputStream every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137 every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size - coEvery { - kvRestorePlugin.getInputStreamForRecord( - token, - packageInfo, - key264 - ) - } returns rInputStream2 every { backupDataOutput.writeEntityHeader(key2, appData2.size) } returns 1137 every { backupDataOutput.writeEntityData(appData2, appData2.size) } returns appData2.size assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) + + verify { + backupDataOutput.writeEntityHeader(key, appData.size) + backupDataOutput.writeEntityData(appData, appData.size) + backupDataOutput.writeEntityHeader(key2, appData2.size) + backupDataOutput.writeEntityData(appData2, appData2.size) + } } @Test @@ -246,19 +242,17 @@ internal class CoordinatorIntegrationTest : TransportTest() { // restore finds the backed up key and writes the decrypted value val backupDataOutput = mockk() val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray()) - coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64) + coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput - coEvery { - kvRestorePlugin.getInputStreamForRecord( - token, - packageInfo, - key64 - ) - } returns rInputStream every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137 every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) + + verify { + backupDataOutput.writeEntityHeader(key, appData.size) + backupDataOutput.writeEntityData(appData, appData.size) + } } @Test diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt index 34a5e0d4..7173f2ff 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/TestKvDbManager.kt @@ -3,22 +3,29 @@ package com.stevesoltys.seedvault.transport.backup import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.toHexString -import junit.framework.Assert.assertEquals -import junit.framework.Assert.assertFalse -import junit.framework.Assert.assertNull -import junit.framework.Assert.assertTrue import org.json.JSONObject import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream import java.io.InputStream +import java.io.OutputStream import kotlin.random.Random class TestKvDbManager : KvDbManager { private var db: TestKVDb? = null + private val outputStream = ByteArrayOutputStream() - override fun getDb(packageName: String): KVDb { + override fun getDb(packageName: String, isRestore: Boolean): KVDb { + if (isRestore) { + readDbFromStream(ByteArrayInputStream(outputStream.toByteArray())) + return this.db!! + } return TestKVDb().apply { db = this } } @@ -26,11 +33,16 @@ class TestKvDbManager : KvDbManager { return ByteArrayInputStream(db!!.serialize().toByteArray()) } + override fun getDbOutputStream(packageName: String): OutputStream { + outputStream.reset() + return outputStream + } + override fun existsDb(packageName: String): Boolean { return db != null } - override fun deleteDb(packageName: String): Boolean { + override fun deleteDb(packageName: String, isRestore: Boolean): Boolean { clearDb() return true } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt index 05a8337d..a2ae238f 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt @@ -10,48 +10,57 @@ import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.getADForKV +import com.stevesoltys.seedvault.transport.backup.BackupPlugin +import com.stevesoltys.seedvault.transport.backup.KVDb +import com.stevesoltys.seedvault.transport.backup.KvDbManager import io.mockk.Runs import io.mockk.coEvery import io.mockk.every import io.mockk.just import io.mockk.mockk import io.mockk.mockkStatic +import io.mockk.verify import io.mockk.verifyAll import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream import java.io.IOException import java.io.InputStream import java.security.GeneralSecurityException +import java.util.zip.GZIPOutputStream import kotlin.random.Random @Suppress("BlockingMethodInNonBlockingContext") internal class KVRestoreTest : RestoreTest() { - private val plugin = mockk() + private val plugin = mockk() + private val legacyPlugin = mockk() + private val dbManager = mockk() private val output = mockk() - private val restore = KVRestore(plugin, outputFactory, headerReader, crypto) + private val restore = + KVRestore(plugin, legacyPlugin, outputFactory, headerReader, crypto, dbManager) + + private val db = mockk() private val ad = getADForKV(VERSION, packageInfo.packageName) private val key = "Restore Key" private val key64 = key.encodeBase64() private val key2 = "Restore Key2" private val key264 = key2.encodeBase64() + private val data2 = getRandomByteArray() + + private val outputStream = ByteArrayOutputStream().apply { + GZIPOutputStream(this).close() + } + private val decryptInputStream = ByteArrayInputStream(outputStream.toByteArray()) init { // for InputStream#readBytes() mockkStatic("kotlin.io.ByteStreamsKt") } - @Test - fun `hasDataForPackage() delegates to plugin`() = runBlocking { - val result = Random.nextBoolean() - - coEvery { plugin.hasDataForPackage(token, packageInfo) } returns result - - assertEquals(result, restore.hasDataForPackage(token, packageInfo)) - } - @Test fun `getRestoreData() throws without initializing state`() { coAssertThrows(IllegalStateException::class.java) { @@ -60,22 +69,133 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `listing records throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + fun `unexpected version aborts with error`() = runBlocking { + restore.initializeState(VERSION, token, name, packageInfo) - coEvery { plugin.listRecords(token, packageInfo) } throws IOException() + coEvery { plugin.getInputStream(token, name) } returns inputStream + every { + headerReader.readVersion(inputStream, VERSION) + } throws UnsupportedVersionException(Byte.MAX_VALUE) + every { dbManager.deleteDb(packageInfo.packageName, true) } returns true + streamsGetClosed() + + assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) + verifyStreamWasClosed() + } + + @Test + fun `newDecryptingStream throws`() = runBlocking { + restore.initializeState(VERSION, token, name, packageInfo) + + coEvery { plugin.getInputStream(token, name) } returns inputStream + every { headerReader.readVersion(inputStream, VERSION) } returns VERSION + every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException() + every { dbManager.deleteDb(packageInfo.packageName, true) } returns true + streamsGetClosed() + + assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) + verifyStreamWasClosed() + + verifyAll { + dbManager.deleteDb(packageInfo.packageName, true) + } + } + + @Test + fun `writeEntityHeader throws`() = runBlocking { + restore.initializeState(VERSION, token, name, packageInfo) + + coEvery { plugin.getInputStream(token, name) } returns inputStream + every { headerReader.readVersion(inputStream, VERSION) } returns VERSION + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream + every { + dbManager.getDbOutputStream(packageInfo.packageName) + } returns ByteArrayOutputStream() + every { dbManager.getDb(packageInfo.packageName, true) } returns db + every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output + every { db.getAll() } returns listOf(Pair(key, data)) + every { output.writeEntityHeader(key, data.size) } throws IOException() + every { dbManager.deleteDb(packageInfo.packageName, true) } returns true + streamsGetClosed() + + assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) + verifyStreamWasClosed() + + verify { + dbManager.deleteDb(packageInfo.packageName, true) + } + } + + @Test + fun `two records get restored`() = runBlocking { + restore.initializeState(VERSION, token, name, packageInfo) + + coEvery { plugin.getInputStream(token, name) } returns inputStream + every { headerReader.readVersion(inputStream, VERSION) } returns VERSION + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream + every { + dbManager.getDbOutputStream(packageInfo.packageName) + } returns ByteArrayOutputStream() + every { dbManager.getDb(packageInfo.packageName, true) } returns db + every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output + every { db.getAll() } returns listOf( + Pair(key, data), + Pair(key2, data2) + ) + every { output.writeEntityHeader(key, data.size) } returns 42 + every { output.writeEntityData(data, data.size) } returns data.size + every { output.writeEntityHeader(key2, data2.size) } returns 42 + every { output.writeEntityData(data2, data2.size) } returns data2.size + + every { dbManager.deleteDb(packageInfo.packageName, true) } returns true + streamsGetClosed() + + assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) + verifyStreamWasClosed() + + verify { + output.writeEntityHeader(key, data.size) + output.writeEntityData(data, data.size) + output.writeEntityHeader(key2, data2.size) + output.writeEntityData(data2, data2.size) + dbManager.deleteDb(packageInfo.packageName, true) + } + } + + // + // v0 legacy tests below + // + + @Test + @Suppress("Deprecation") + fun `v0 hasDataForPackage() delegates to plugin`() = runBlocking { + val result = Random.nextBoolean() + + coEvery { legacyPlugin.hasDataForPackage(token, packageInfo) } returns result + + assertEquals(result, restore.hasDataForPackage(token, packageInfo)) + } + + @Test + @Suppress("Deprecation") + fun `v0 listing records throws`() = runBlocking { + restore.initializeState(0x00, token, name, packageInfo) + + coEvery { legacyPlugin.listRecords(token, packageInfo) } throws IOException() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) } @Test - fun `reading VersionHeader with unsupported version throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + fun `v0 reading VersionHeader with unsupported version throws`() = runBlocking { + restore.initializeState(0x00, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream every { - headerReader.readVersion(inputStream, VERSION) + headerReader.readVersion(inputStream, 0x00) } throws UnsupportedVersionException(unsupportedVersion) streamsGetClosed() @@ -84,12 +204,14 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `error reading VersionHeader throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + fun `v0 error reading VersionHeader throws`() = runBlocking { + restore.initializeState(0x00, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } throws IOException() + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0x00) } throws IOException() streamsGetClosed() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) @@ -97,13 +219,18 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `decrypting stream throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + @Suppress("deprecation") + fun `v0 decrypting stream throws`() = runBlocking { + restore.initializeState(0x00, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0x00) } returns 0x00 + every { + crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key) + } throws IOException() streamsGetClosed() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) @@ -111,13 +238,19 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `decrypting stream throws security exception`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + @Suppress("deprecation") + fun `v0 decrypting stream throws security exception`() = runBlocking { + restore.initializeState(0x00, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException() + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0x00) } returns 0x00 + every { + crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key) + } returns VersionHeader(0x00, packageInfo.packageName, key) + every { crypto.decryptMultipleSegments(inputStream) } throws IOException() streamsGetClosed() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) @@ -125,14 +258,19 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `writing header throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + @Suppress("Deprecation") + fun `v0 writing header throws`() = runBlocking { + restore.initializeState(0, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream - every { decryptedInputStream.readBytes() } returns data + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0) } returns 0 + every { + crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key) + } returns VersionHeader(0x00, packageInfo.packageName, key) + every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } throws IOException() streamsGetClosed() @@ -141,14 +279,19 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `writing value throws`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + @Suppress("deprecation") + fun `v0 writing value throws`() = runBlocking { + restore.initializeState(0, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream - every { decryptedInputStream.readBytes() } returns data + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0) } returns 0 + every { + crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key) + } returns VersionHeader(0, packageInfo.packageName, key) + every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } throws IOException() streamsGetClosed() @@ -158,14 +301,19 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `writing value succeeds`() = runBlocking { - restore.initializeState(VERSION, token, packageInfo) + @Suppress("deprecation") + fun `v0 writing value succeeds`() = runBlocking { + restore.initializeState(0, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream - every { decryptedInputStream.readBytes() } returns data + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0) } returns 0 + every { + crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key) + } returns VersionHeader(0, packageInfo.packageName, key) + every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } returns data.size streamsGetClosed() @@ -175,14 +323,17 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `writing value uses old v0 code`() = runBlocking { - restore.initializeState(0.toByte(), token, packageInfo) + @Suppress("deprecation") + fun `v0 writing value uses old v0 code`() = runBlocking { + restore.initializeState(0, token, name, packageInfo) getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte() + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0) } returns 0 every { - crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key) + crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key) } returns VersionHeader(VERSION, packageInfo.packageName, key) every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } returns 42 @@ -194,43 +345,35 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `unexpected version aborts with error`() = runBlocking { - restore.initializeState(Byte.MAX_VALUE, token, packageInfo) - - getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { - headerReader.readVersion(inputStream, Byte.MAX_VALUE) - } throws GeneralSecurityException() - streamsGetClosed() - - assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) - verifyStreamWasClosed() - } - - @Test - fun `writing two values succeeds`() = runBlocking { + @Suppress("Deprecation") + fun `v0 writing two values succeeds`() = runBlocking { val data2 = getRandomByteArray() val inputStream2 = mockk() - val decryptedInputStream2 = mockk() - restore.initializeState(VERSION, token, packageInfo) + restore.initializeState(0, token, name, packageInfo) getRecordsAndOutput(listOf(key64, key264)) // first key/value - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream - every { decryptedInputStream.readBytes() } returns data + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key64) + } returns inputStream + every { headerReader.readVersion(inputStream, 0) } returns 0 + every { + crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key) + } returns VersionHeader(0, packageInfo.packageName, key) + every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } returns data.size // second key/value - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2 - every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION - every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2 - every { decryptedInputStream2.readBytes() } returns data2 + coEvery { + legacyPlugin.getInputStreamForRecord(token, packageInfo, key264) + } returns inputStream2 + every { headerReader.readVersion(inputStream2, 0) } returns 0 + every { + crypto.decryptHeader(inputStream2, 0, packageInfo.packageName, key2) + } returns VersionHeader(0, packageInfo.packageName, key2) + every { crypto.decryptMultipleSegments(inputStream2) } returns data2 every { output.writeEntityHeader(key2, data2.size) } returns 42 every { output.writeEntityData(data2, data2.size) } returns data2.size - every { decryptedInputStream2.close() } just Runs every { inputStream2.close() } just Runs streamsGetClosed() @@ -238,12 +381,11 @@ internal class KVRestoreTest : RestoreTest() { } private fun getRecordsAndOutput(recordKeys: List = listOf(key64)) { - coEvery { plugin.listRecords(token, packageInfo) } returns recordKeys + coEvery { legacyPlugin.listRecords(token, packageInfo) } returns recordKeys every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output } private fun streamsGetClosed() { - every { decryptedInputStream.close() } just Runs every { inputStream.close() } just Runs every { fileDescriptor.close() } just Runs } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt index 72fe53a3..357098f6 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreCoordinatorTest.kt @@ -223,19 +223,20 @@ internal class RestoreCoordinatorTest : TransportTest() { every { crypto.getNameForPackage(metadata.salt, packageName) } returns name coEvery { plugin.hasData(token, name) } returns true - every { kv.initializeState(VERSION, token, packageInfo) } just Runs + every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs val expected = RestoreDescription(packageName, TYPE_KEY_VALUE) assertEquals(expected, restore.nextRestorePackage()) } @Test + @Suppress("Deprecation") fun `v0 nextRestorePackage() returns KV description and takes precedence`() = runBlocking { restore.beforeStartRestore(metadata.copy(version = 0x00)) restore.startRestore(token, packageInfoArray) coEvery { kv.hasDataForPackage(token, packageInfo) } returns true - every { kv.initializeState(0x00, token, packageInfo) } just Runs + every { kv.initializeState(0x00, token, "", packageInfo) } just Runs val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) assertEquals(expected, restore.nextRestorePackage()) @@ -292,7 +293,7 @@ internal class RestoreCoordinatorTest : TransportTest() { every { crypto.getNameForPackage(metadata.salt, packageName) } returns name coEvery { plugin.hasData(token, name) } returns true - every { kv.initializeState(VERSION, token, packageInfo) } just Runs + every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) assertEquals(expected, restore.nextRestorePackage()) @@ -315,7 +316,7 @@ internal class RestoreCoordinatorTest : TransportTest() { restore.startRestore(token, packageInfoArray2) coEvery { kv.hasDataForPackage(token, packageInfo) } returns true - every { kv.initializeState(0.toByte(), token, packageInfo) } just Runs + every { kv.initializeState(0.toByte(), token, "", packageInfo) } just Runs val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) assertEquals(expected, restore.nextRestorePackage()) @@ -331,6 +332,7 @@ internal class RestoreCoordinatorTest : TransportTest() { } @Test + @Suppress("Deprecation") fun `v0 when kv#hasDataForPackage() throws, it tries next package`() = runBlocking { restore.beforeStartRestore(metadata.copy(version = 0x00)) restore.startRestore(token, packageInfoArray) diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt index 1b06ed76..07d4b5a6 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreV0IntegrationTest.kt @@ -16,6 +16,7 @@ import com.stevesoltys.seedvault.metadata.MetadataReaderImpl import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.transport.TransportTest import com.stevesoltys.seedvault.transport.backup.BackupPlugin +import com.stevesoltys.seedvault.transport.backup.KvDbManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import io.mockk.coEvery import io.mockk.every @@ -44,12 +45,20 @@ internal class RestoreV0IntegrationTest : TransportTest() { private val cipherFactory = CipherFactoryImpl(keyManager) private val headerReader = HeaderReaderImpl() private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader) + private val dbManager = mockk() private val metadataReader = MetadataReaderImpl(cryptoImpl) private val notificationManager = mockk() private val backupPlugin = mockk() private val kvRestorePlugin = mockk() - private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl) + private val kvRestore = KVRestore( + backupPlugin, + kvRestorePlugin, + outputFactory, + headerReader, + cryptoImpl, + dbManager + ) private val fullRestorePlugin = mockk() private val fullRestore = FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)