K/V restore using single file

This commit is contained in:
Torsten Grote 2021-09-22 14:41:49 +02:00 committed by Chirayu Desai
parent 0c915e5eb8
commit a0f3c6b45f
10 changed files with 405 additions and 158 deletions

View file

@ -272,6 +272,7 @@ internal class BackupCoordinator(
val salt = metadataManager.salt val salt = metadataManager.salt
val result = kv.performBackup(packageInfo, data, flags, token, salt) val result = kv.performBackup(packageInfo, data, flags, token, salt)
if (result == TRANSPORT_OK && packageName == MAGIC_PACKAGE_MANAGER) { if (result == TRANSPORT_OK && packageName == MAGIC_PACKAGE_MANAGER) {
// TODO move to finish backup of @pm@ so we can upload the DB before
// hook in here to back up APKs of apps that are otherwise not allowed for backup // hook in here to back up APKs of apps that are otherwise not allowed for backup
backUpApksOfNotBackedUpPackages() backUpApksOfNotBackedUpPackages()
} }
@ -392,7 +393,9 @@ internal class BackupCoordinator(
} }
// getCurrentPackage() not-null because we have state // getCurrentPackage() not-null because we have state
onPackageBackedUp(kv.getCurrentPackage()!!, BackupType.KV) onPackageBackedUp(kv.getCurrentPackage()!!, BackupType.KV)
val isPmBackup = kv.getCurrentPackage()!!.packageName == MAGIC_PACKAGE_MANAGER
kv.finishBackup() kv.finishBackup()
// TODO move @pm@ backup hook here
} }
full.hasState() -> { full.hasState() -> {
check(!kv.hasState()) { check(!kv.hasState()) {

View file

@ -8,37 +8,59 @@ import android.database.sqlite.SQLiteOpenHelper
import android.provider.BaseColumns import android.provider.BaseColumns
import java.io.File import java.io.File
import java.io.FileInputStream import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
interface KvDbManager { interface KvDbManager {
fun getDb(packageName: String): KVDb fun getDb(packageName: String, isRestore: Boolean = false): KVDb
/**
* Use only for backup.
*/
fun getDbInputStream(packageName: String): InputStream fun getDbInputStream(packageName: String): InputStream
/**
* Use only for restore.
*/
fun getDbOutputStream(packageName: String): OutputStream
/**
* Use only for backup.
*/
fun existsDb(packageName: String): Boolean fun existsDb(packageName: String): Boolean
fun deleteDb(packageName: String): Boolean fun deleteDb(packageName: String, isRestore: Boolean = false): Boolean
} }
class KvDbManagerImpl(private val context: Context) : KvDbManager { class KvDbManagerImpl(private val context: Context) : KvDbManager {
override fun getDb(packageName: String): KVDb { override fun getDb(packageName: String, isRestore: Boolean): KVDb {
return KVDbImpl(context, getFileName(packageName)) return KVDbImpl(context, getFileName(packageName, isRestore))
} }
private fun getFileName(packageName: String) = "kv_$packageName.db" private fun getFileName(packageName: String, isRestore: Boolean): String {
val prefix = if (isRestore) "restore_" else ""
return "${prefix}kv_$packageName.db"
}
private fun getDbFile(packageName: String): File { private fun getDbFile(packageName: String, isRestore: Boolean = false): File {
return context.getDatabasePath(getFileName(packageName)) return context.getDatabasePath(getFileName(packageName, isRestore))
} }
override fun getDbInputStream(packageName: String): InputStream { override fun getDbInputStream(packageName: String): InputStream {
return FileInputStream(getDbFile(packageName)) return FileInputStream(getDbFile(packageName))
} }
override fun getDbOutputStream(packageName: String): OutputStream {
return FileOutputStream(getDbFile(packageName, true))
}
override fun existsDb(packageName: String): Boolean { override fun existsDb(packageName: String): Boolean {
return getDbFile(packageName).isFile return getDbFile(packageName).isFile
} }
override fun deleteDb(packageName: String): Boolean { override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
return getDbFile(packageName).delete() return getDbFile(packageName, isRestore).delete()
} }
} }

View file

@ -15,19 +15,25 @@ import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KVDb
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
import java.io.IOException import java.io.IOException
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
import java.util.ArrayList import java.util.ArrayList
import java.util.zip.GZIPInputStream
import javax.crypto.AEADBadTagException import javax.crypto.AEADBadTagException
private class KVRestoreState( private class KVRestoreState(
val version: Byte, val version: Byte,
val token: Long, val token: Long,
val name: String,
val packageInfo: PackageInfo, val packageInfo: PackageInfo,
/** /**
* Optional [PackageInfo] for single package restore, optimizes restore of @pm@ * Optional [PackageInfo] for single package restore, optimizes restore of @pm@
*/ */
@Deprecated("TODO remove?")
val pmPackageInfo: PackageInfo? val pmPackageInfo: PackageInfo?
) )
@ -35,20 +41,25 @@ private val TAG = KVRestore::class.java.simpleName
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestore( internal class KVRestore(
private val plugin: KVRestorePlugin, private val plugin: BackupPlugin,
private val legacyPlugin: KVRestorePlugin,
private val outputFactory: OutputFactory, private val outputFactory: OutputFactory,
private val headerReader: HeaderReader, private val headerReader: HeaderReader,
private val crypto: Crypto private val crypto: Crypto,
private val dbManager: KvDbManager
) { ) {
private var state: KVRestoreState? = null private var state: KVRestoreState? = null
/** /**
* Return true if there are records stored for the given package. * Return true if there are records stored for the given package.
*
* Deprecated. Use only for v0 backups.
*/ */
@Throws(IOException::class) @Throws(IOException::class)
@Deprecated("Use BackupPlugin#hasData() instead")
suspend fun hasDataForPackage(token: Long, packageInfo: PackageInfo): Boolean { suspend fun hasDataForPackage(token: Long, packageInfo: PackageInfo): Boolean {
return plugin.hasDataForPackage(token, packageInfo) return legacyPlugin.hasDataForPackage(token, packageInfo)
} }
/** /**
@ -62,10 +73,11 @@ internal class KVRestore(
fun initializeState( fun initializeState(
version: Byte, version: Byte,
token: Long, token: Long,
name: String,
packageInfo: PackageInfo, packageInfo: PackageInfo,
pmPackageInfo: PackageInfo? = null pmPackageInfo: PackageInfo? = null
) { ) {
state = KVRestoreState(version, token, packageInfo, pmPackageInfo) state = KVRestoreState(version, token, name, packageInfo, pmPackageInfo)
} }
/** /**
@ -78,12 +90,66 @@ internal class KVRestore(
suspend fun getRestoreData(data: ParcelFileDescriptor): Int { suspend fun getRestoreData(data: ParcelFileDescriptor): Int {
val state = this.state ?: throw IllegalStateException("no state") val state = this.state ?: throw IllegalStateException("no state")
// take legacy path for version 0
if (state.version == 0x00.toByte()) return getRestoreDataV0(state, data)
return try {
val db = getRestoreDb(state)
val out = outputFactory.getBackupDataOutput(data)
db.getAll().sortedBy { it.first }.forEach { (key, value) ->
val size = value.size
Log.v(TAG, " ... key=$key size=$size")
out.writeEntityHeader(key, size)
out.writeEntityData(value, size)
}
TRANSPORT_OK
} catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR
} catch (e: IOException) {
Log.e(TAG, "Unable to process K/V backup database", e)
TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "General security exception while reading backup database", e)
TRANSPORT_ERROR
} catch (e: AEADBadTagException) {
Log.e(TAG, "Decryption failed", e)
TRANSPORT_ERROR
} finally {
dbManager.deleteDb(state.packageInfo.packageName, true)
this.state = null
closeQuietly(data)
}
}
@Throws(IOException::class, GeneralSecurityException::class, UnsupportedVersionException::class)
private suspend fun getRestoreDb(state: KVRestoreState): KVDb {
val packageName = state.packageInfo.packageName
plugin.getInputStream(state.token, state.name).use { inputStream ->
headerReader.readVersion(inputStream, state.version)
val ad = getADForKV(VERSION, packageName)
crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
GZIPInputStream(decryptedStream).use { gzipStream ->
dbManager.getDbOutputStream(packageName).use { outputStream ->
gzipStream.copyTo(outputStream)
}
}
}
}
return dbManager.getDb(packageName, true)
}
//
// v0 restore legacy code below
//
private suspend fun getRestoreDataV0(state: KVRestoreState, data: ParcelFileDescriptor): Int {
// The restore set is the concatenation of the individual record blobs, // The restore set is the concatenation of the individual record blobs,
// each of which is a file in the package's directory. // each of which is a file in the package's directory.
// We return the data in lexical order sorted by key, // We return the data in lexical order sorted by key,
// so that apps which use synthetic keys like BLOB_1, BLOB_2, etc // so that apps which use synthetic keys like BLOB_1, BLOB_2, etc
// will see the date in the most obvious order. // will see the date in the most obvious order.
val sortedKeys = getSortedKeys(state.token, state.packageInfo) val sortedKeys = getSortedKeysV0(state.token, state.packageInfo)
if (sortedKeys == null) { if (sortedKeys == null) {
// nextRestorePackage() ensures the dir exists, so this is an error // nextRestorePackage() ensures the dir exists, so this is an error
Log.e(TAG, "No keys for package: ${state.packageInfo.packageName}") Log.e(TAG, "No keys for package: ${state.packageInfo.packageName}")
@ -96,7 +162,7 @@ internal class KVRestore(
return try { return try {
val dataOutput = outputFactory.getBackupDataOutput(data) val dataOutput = outputFactory.getBackupDataOutput(data)
for (keyEntry in sortedKeys) { for (keyEntry in sortedKeys) {
readAndWriteValue(state, keyEntry, dataOutput) readAndWriteValueV0(state, keyEntry, dataOutput)
} }
TRANSPORT_OK TRANSPORT_OK
} catch (e: IOException) { } catch (e: IOException) {
@ -105,9 +171,6 @@ internal class KVRestore(
} catch (e: SecurityException) { } catch (e: SecurityException) {
Log.e(TAG, "Security exception while reading backup records", e) Log.e(TAG, "Security exception while reading backup records", e)
TRANSPORT_ERROR TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "General security exception while reading backup records", e)
TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) { } catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e) Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR TRANSPORT_ERROR
@ -124,9 +187,9 @@ internal class KVRestore(
* Return a list of the records (represented by key files) in the given directory, * Return a list of the records (represented by key files) in the given directory,
* sorted lexically by the Base64-decoded key file name, not by the on-disk filename. * sorted lexically by the Base64-decoded key file name, not by the on-disk filename.
*/ */
private suspend fun getSortedKeys(token: Long, packageInfo: PackageInfo): List<DecodedKey>? { private suspend fun getSortedKeysV0(token: Long, packageInfo: PackageInfo): List<DecodedKey>? {
val records: List<String> = try { val records: List<String> = try {
plugin.listRecords(token, packageInfo) legacyPlugin.listRecords(token, packageInfo)
} catch (e: IOException) { } catch (e: IOException) {
return null return null
} }
@ -150,24 +213,18 @@ internal class KVRestore(
/** /**
* Read the encrypted value for the given key and write it to the given [BackupDataOutput]. * Read the encrypted value for the given key and write it to the given [BackupDataOutput].
*/ */
@Suppress("Deprecation")
@Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class) @Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
private suspend fun readAndWriteValue( private suspend fun readAndWriteValueV0(
state: KVRestoreState, state: KVRestoreState,
dKey: DecodedKey, dKey: DecodedKey,
out: BackupDataOutput out: BackupDataOutput
) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key) ) = legacyPlugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
.use { inputStream -> .use { inputStream ->
val version = headerReader.readVersion(inputStream, state.version) val version = headerReader.readVersion(inputStream, state.version)
val packageName = state.packageInfo.packageName val packageName = state.packageInfo.packageName
val value = if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName, dKey.key) crypto.decryptHeader(inputStream, version, packageName, dKey.key)
crypto.decryptMultipleSegments(inputStream) val value = crypto.decryptMultipleSegments(inputStream)
} else {
val ad = getADForKV(VERSION, packageName)
crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
decryptedStream.readBytes()
}
}
val size = value.size val size = value.size
Log.v(TAG, " ... key=${dKey.key} size=$size") Log.v(TAG, " ... key=${dKey.key} size=$size")

View file

@ -207,7 +207,13 @@ internal class RestoreCoordinator(
val name = crypto.getNameForPackage(state.backupMetadata.salt, packageName) val name = crypto.getNameForPackage(state.backupMetadata.salt, packageName)
if (plugin.hasData(state.token, name)) { if (plugin.hasData(state.token, name)) {
Log.i(TAG, "Found K/V data for $packageName.") Log.i(TAG, "Found K/V data for $packageName.")
kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo) kv.initializeState(
version = version,
token = state.token,
name = name,
packageInfo = packageInfo,
pmPackageInfo = state.pmPackageInfo
)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_KEY_VALUE TYPE_KEY_VALUE
} else throw IOException("No data found for $packageName. Skipping.") } else throw IOException("No data found for $packageName. Skipping.")
@ -243,7 +249,7 @@ internal class RestoreCoordinator(
// check key/value data first and if available, don't even check for full data // check key/value data first and if available, don't even check for full data
kv.hasDataForPackage(state.token, packageInfo) -> { kv.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found K/V data for $packageName.") Log.i(TAG, "Found K/V data for $packageName.")
kv.initializeState(0x00, state.token, packageInfo, state.pmPackageInfo) kv.initializeState(0x00, state.token, "", packageInfo, state.pmPackageInfo)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_KEY_VALUE TYPE_KEY_VALUE
} }

View file

@ -5,7 +5,7 @@ import org.koin.dsl.module
val restoreModule = module { val restoreModule = module {
single { OutputFactory() } single { OutputFactory() }
single { KVRestore(get<RestorePlugin>().kvRestorePlugin, get(), get(), get()) } single { KVRestore(get(), get<RestorePlugin>().kvRestorePlugin, get(), get(), get(), get()) }
single { FullRestore(get(), get<RestorePlugin>().fullRestorePlugin, get(), get(), get()) } single { FullRestore(get(), get<RestorePlugin>().fullRestorePlugin, get(), get(), get()) }
single { single {
RestoreCoordinator(androidContext(), get(), get(), get(), get(), get(), get(), get(), get()) RestoreCoordinator(androidContext(), get(), get(), get(), get(), get(), get(), get(), get())

View file

@ -10,7 +10,6 @@ import android.os.ParcelFileDescriptor
import com.stevesoltys.seedvault.crypto.CipherFactoryImpl import com.stevesoltys.seedvault.crypto.CipherFactoryImpl
import com.stevesoltys.seedvault.crypto.CryptoImpl import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.BackupType import com.stevesoltys.seedvault.metadata.BackupType
@ -39,6 +38,7 @@ import io.mockk.coEvery
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
@ -81,7 +81,14 @@ internal class CoordinatorIntegrationTest : TransportTest() {
) )
private val kvRestorePlugin = mockk<KVRestorePlugin>() private val kvRestorePlugin = mockk<KVRestorePlugin>()
private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl) private val kvRestore = KVRestore(
backupPlugin,
kvRestorePlugin,
outputFactory,
headerReader,
cryptoImpl,
dbManager
)
private val fullRestorePlugin = mockk<FullRestorePlugin>() private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore = private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl) FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)
@ -104,9 +111,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
private val metadataOutputStream = ByteArrayOutputStream() private val metadataOutputStream = ByteArrayOutputStream()
private val packageMetadata = PackageMetadata(time = 0L) private val packageMetadata = PackageMetadata(time = 0L)
private val key = "RestoreKey" private val key = "RestoreKey"
private val key64 = key.encodeBase64()
private val key2 = "RestoreKey2" private val key2 = "RestoreKey2"
private val key264 = key2.encodeBase64()
// as we use real crypto, we need a real name for packageInfo // as we use real crypto, we need a real name for packageInfo
private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName) private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName)
@ -116,7 +121,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
val value = CapturingSlot<ByteArray>() val value = CapturingSlot<ByteArray>()
val value2 = CapturingSlot<ByteArray>() val value2 = CapturingSlot<ByteArray>()
val bOutputStream = ByteArrayOutputStream() val bOutputStream = ByteArrayOutputStream()
val bOutputStream2 = ByteArrayOutputStream()
every { settingsManager.getToken() } returns token every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt every { metadataManager.salt } returns salt
@ -170,29 +174,21 @@ internal class CoordinatorIntegrationTest : TransportTest() {
// restore finds the backed up key and writes the decrypted value // restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>() val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray()) val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
val rInputStream2 = ByteArrayInputStream(bOutputStream2.toByteArray()) coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64, key264)
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key64
)
} returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137 every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key264
)
} returns rInputStream2
every { backupDataOutput.writeEntityHeader(key2, appData2.size) } returns 1137 every { backupDataOutput.writeEntityHeader(key2, appData2.size) } returns 1137
every { backupDataOutput.writeEntityData(appData2, appData2.size) } returns appData2.size every { backupDataOutput.writeEntityData(appData2, appData2.size) } returns appData2.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verify {
backupDataOutput.writeEntityHeader(key, appData.size)
backupDataOutput.writeEntityData(appData, appData.size)
backupDataOutput.writeEntityHeader(key2, appData2.size)
backupDataOutput.writeEntityData(appData2, appData2.size)
}
} }
@Test @Test
@ -246,19 +242,17 @@ internal class CoordinatorIntegrationTest : TransportTest() {
// restore finds the backed up key and writes the decrypted value // restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>() val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray()) val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64) coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key64
)
} returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137 every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verify {
backupDataOutput.writeEntityHeader(key, appData.size)
backupDataOutput.writeEntityData(appData, appData.size)
}
} }
@Test @Test

View file

@ -3,22 +3,29 @@ package com.stevesoltys.seedvault.transport.backup
import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.toHexString import com.stevesoltys.seedvault.toHexString
import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertFalse
import junit.framework.Assert.assertNull
import junit.framework.Assert.assertTrue
import org.json.JSONObject import org.json.JSONObject
import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
import kotlin.random.Random import kotlin.random.Random
class TestKvDbManager : KvDbManager { class TestKvDbManager : KvDbManager {
private var db: TestKVDb? = null private var db: TestKVDb? = null
private val outputStream = ByteArrayOutputStream()
override fun getDb(packageName: String): KVDb { override fun getDb(packageName: String, isRestore: Boolean): KVDb {
if (isRestore) {
readDbFromStream(ByteArrayInputStream(outputStream.toByteArray()))
return this.db!!
}
return TestKVDb().apply { db = this } return TestKVDb().apply { db = this }
} }
@ -26,11 +33,16 @@ class TestKvDbManager : KvDbManager {
return ByteArrayInputStream(db!!.serialize().toByteArray()) return ByteArrayInputStream(db!!.serialize().toByteArray())
} }
override fun getDbOutputStream(packageName: String): OutputStream {
outputStream.reset()
return outputStream
}
override fun existsDb(packageName: String): Boolean { override fun existsDb(packageName: String): Boolean {
return db != null return db != null
} }
override fun deleteDb(packageName: String): Boolean { override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
clearDb() clearDb()
return true return true
} }

View file

@ -10,48 +10,57 @@ import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KVDb
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.mockkStatic import io.mockk.mockkStatic
import io.mockk.verify
import io.mockk.verifyAll import io.mockk.verifyAll
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
import java.util.zip.GZIPOutputStream
import kotlin.random.Random import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestoreTest : RestoreTest() { internal class KVRestoreTest : RestoreTest() {
private val plugin = mockk<KVRestorePlugin>() private val plugin = mockk<BackupPlugin>()
private val legacyPlugin = mockk<KVRestorePlugin>()
private val dbManager = mockk<KvDbManager>()
private val output = mockk<BackupDataOutput>() private val output = mockk<BackupDataOutput>()
private val restore = KVRestore(plugin, outputFactory, headerReader, crypto) private val restore =
KVRestore(plugin, legacyPlugin, outputFactory, headerReader, crypto, dbManager)
private val db = mockk<KVDb>()
private val ad = getADForKV(VERSION, packageInfo.packageName) private val ad = getADForKV(VERSION, packageInfo.packageName)
private val key = "Restore Key" private val key = "Restore Key"
private val key64 = key.encodeBase64() private val key64 = key.encodeBase64()
private val key2 = "Restore Key2" private val key2 = "Restore Key2"
private val key264 = key2.encodeBase64() private val key264 = key2.encodeBase64()
private val data2 = getRandomByteArray()
private val outputStream = ByteArrayOutputStream().apply {
GZIPOutputStream(this).close()
}
private val decryptInputStream = ByteArrayInputStream(outputStream.toByteArray())
init { init {
// for InputStream#readBytes() // for InputStream#readBytes()
mockkStatic("kotlin.io.ByteStreamsKt") mockkStatic("kotlin.io.ByteStreamsKt")
} }
@Test
fun `hasDataForPackage() delegates to plugin`() = runBlocking {
val result = Random.nextBoolean()
coEvery { plugin.hasDataForPackage(token, packageInfo) } returns result
assertEquals(result, restore.hasDataForPackage(token, packageInfo))
}
@Test @Test
fun `getRestoreData() throws without initializing state`() { fun `getRestoreData() throws without initializing state`() {
coAssertThrows(IllegalStateException::class.java) { coAssertThrows(IllegalStateException::class.java) {
@ -60,22 +69,133 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `listing records throws`() = runBlocking { fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo) restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.listRecords(token, packageInfo) } throws IOException() coEvery { plugin.getInputStream(token, name) } returns inputStream
every {
headerReader.readVersion(inputStream, VERSION)
} throws UnsupportedVersionException(Byte.MAX_VALUE)
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
}
@Test
fun `newDecryptingStream throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verifyAll {
dbManager.deleteDb(packageInfo.packageName, true)
}
}
@Test
fun `writeEntityHeader throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
every {
dbManager.getDbOutputStream(packageInfo.packageName)
} returns ByteArrayOutputStream()
every { dbManager.getDb(packageInfo.packageName, true) } returns db
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
every { db.getAll() } returns listOf(Pair(key, data))
every { output.writeEntityHeader(key, data.size) } throws IOException()
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verify {
dbManager.deleteDb(packageInfo.packageName, true)
}
}
@Test
fun `two records get restored`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
every {
dbManager.getDbOutputStream(packageInfo.packageName)
} returns ByteArrayOutputStream()
every { dbManager.getDb(packageInfo.packageName, true) } returns db
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
every { db.getAll() } returns listOf(
Pair(key, data),
Pair(key2, data2)
)
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
every { output.writeEntityHeader(key2, data2.size) } returns 42
every { output.writeEntityData(data2, data2.size) } returns data2.size
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verify {
output.writeEntityHeader(key, data.size)
output.writeEntityData(data, data.size)
output.writeEntityHeader(key2, data2.size)
output.writeEntityData(data2, data2.size)
dbManager.deleteDb(packageInfo.packageName, true)
}
}
//
// v0 legacy tests below
//
@Test
@Suppress("Deprecation")
fun `v0 hasDataForPackage() delegates to plugin`() = runBlocking {
val result = Random.nextBoolean()
coEvery { legacyPlugin.hasDataForPackage(token, packageInfo) } returns result
assertEquals(result, restore.hasDataForPackage(token, packageInfo))
}
@Test
@Suppress("Deprecation")
fun `v0 listing records throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
coEvery { legacyPlugin.listRecords(token, packageInfo) } throws IOException()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
} }
@Test @Test
fun `reading VersionHeader with unsupported version throws`() = runBlocking { fun `v0 reading VersionHeader with unsupported version throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo) restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { every {
headerReader.readVersion(inputStream, VERSION) headerReader.readVersion(inputStream, 0x00)
} throws UnsupportedVersionException(unsupportedVersion) } throws UnsupportedVersionException(unsupportedVersion)
streamsGetClosed() streamsGetClosed()
@ -84,12 +204,14 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `error reading VersionHeader throws`() = runBlocking { fun `v0 error reading VersionHeader throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo) restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } throws IOException() legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } throws IOException()
streamsGetClosed() streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -97,13 +219,18 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `decrypting stream throws`() = runBlocking { @Suppress("deprecation")
restore.initializeState(VERSION, token, packageInfo) fun `v0 decrypting stream throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() } returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} throws IOException()
streamsGetClosed() streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -111,13 +238,19 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `decrypting stream throws security exception`() = runBlocking { @Suppress("deprecation")
restore.initializeState(VERSION, token, packageInfo) fun `v0 decrypting stream throws security exception`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException() } returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} returns VersionHeader(0x00, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } throws IOException()
streamsGetClosed() streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -125,14 +258,19 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `writing header throws`() = runBlocking { @Suppress("Deprecation")
restore.initializeState(VERSION, token, packageInfo) fun `v0 writing header throws`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream } returns inputStream
every { decryptedInputStream.readBytes() } returns data every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} returns VersionHeader(0x00, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } throws IOException() every { output.writeEntityHeader(key, data.size) } throws IOException()
streamsGetClosed() streamsGetClosed()
@ -141,14 +279,19 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `writing value throws`() = runBlocking { @Suppress("deprecation")
restore.initializeState(VERSION, token, packageInfo) fun `v0 writing value throws`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream } returns inputStream
every { decryptedInputStream.readBytes() } returns data every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } throws IOException() every { output.writeEntityData(data, data.size) } throws IOException()
streamsGetClosed() streamsGetClosed()
@ -158,14 +301,19 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `writing value succeeds`() = runBlocking { @Suppress("deprecation")
restore.initializeState(VERSION, token, packageInfo) fun `v0 writing value succeeds`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream } returns inputStream
every { decryptedInputStream.readBytes() } returns data every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size every { output.writeEntityData(data, data.size) } returns data.size
streamsGetClosed() streamsGetClosed()
@ -175,14 +323,17 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `writing value uses old v0 code`() = runBlocking { @Suppress("deprecation")
restore.initializeState(0.toByte(), token, packageInfo) fun `v0 writing value uses old v0 code`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte() legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every { every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key) crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(VERSION, packageInfo.packageName, key) } returns VersionHeader(VERSION, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
@ -194,43 +345,35 @@ internal class KVRestoreTest : RestoreTest() {
} }
@Test @Test
fun `unexpected version aborts with error`() = runBlocking { @Suppress("Deprecation")
restore.initializeState(Byte.MAX_VALUE, token, packageInfo) fun `v0 writing two values succeeds`() = runBlocking {
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every {
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
} throws GeneralSecurityException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
}
@Test
fun `writing two values succeeds`() = runBlocking {
val data2 = getRandomByteArray() val data2 = getRandomByteArray()
val inputStream2 = mockk<InputStream>() val inputStream2 = mockk<InputStream>()
val decryptedInputStream2 = mockk<InputStream>() restore.initializeState(0, token, name, packageInfo)
restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput(listOf(key64, key264)) getRecordsAndOutput(listOf(key64, key264))
// first key/value // first key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery {
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream } returns inputStream
every { decryptedInputStream.readBytes() } returns data every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size every { output.writeEntityData(data, data.size) } returns data.size
// second key/value // second key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2 coEvery {
every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION legacyPlugin.getInputStreamForRecord(token, packageInfo, key264)
every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2 } returns inputStream2
every { decryptedInputStream2.readBytes() } returns data2 every { headerReader.readVersion(inputStream2, 0) } returns 0
every {
crypto.decryptHeader(inputStream2, 0, packageInfo.packageName, key2)
} returns VersionHeader(0, packageInfo.packageName, key2)
every { crypto.decryptMultipleSegments(inputStream2) } returns data2
every { output.writeEntityHeader(key2, data2.size) } returns 42 every { output.writeEntityHeader(key2, data2.size) } returns 42
every { output.writeEntityData(data2, data2.size) } returns data2.size every { output.writeEntityData(data2, data2.size) } returns data2.size
every { decryptedInputStream2.close() } just Runs
every { inputStream2.close() } just Runs every { inputStream2.close() } just Runs
streamsGetClosed() streamsGetClosed()
@ -238,12 +381,11 @@ internal class KVRestoreTest : RestoreTest() {
} }
private fun getRecordsAndOutput(recordKeys: List<String> = listOf(key64)) { private fun getRecordsAndOutput(recordKeys: List<String> = listOf(key64)) {
coEvery { plugin.listRecords(token, packageInfo) } returns recordKeys coEvery { legacyPlugin.listRecords(token, packageInfo) } returns recordKeys
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
} }
private fun streamsGetClosed() { private fun streamsGetClosed() {
every { decryptedInputStream.close() } just Runs
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
} }

View file

@ -223,19 +223,20 @@ internal class RestoreCoordinatorTest : TransportTest() {
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true coEvery { plugin.hasData(token, name) } returns true
every { kv.initializeState(VERSION, token, packageInfo) } just Runs every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
} }
@Test @Test
@Suppress("Deprecation")
fun `v0 nextRestorePackage() returns KV description and takes precedence`() = runBlocking { fun `v0 nextRestorePackage() returns KV description and takes precedence`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00)) restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray) restore.startRestore(token, packageInfoArray)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(0x00, token, packageInfo) } just Runs every { kv.initializeState(0x00, token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -292,7 +293,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true coEvery { plugin.hasData(token, name) } returns true
every { kv.initializeState(VERSION, token, packageInfo) } just Runs every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -315,7 +316,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.startRestore(token, packageInfoArray2) restore.startRestore(token, packageInfoArray2)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(0.toByte(), token, packageInfo) } just Runs every { kv.initializeState(0.toByte(), token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -331,6 +332,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
} }
@Test @Test
@Suppress("Deprecation")
fun `v0 when kv#hasDataForPackage() throws, it tries next package`() = runBlocking { fun `v0 when kv#hasDataForPackage() throws, it tries next package`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00)) restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray) restore.startRestore(token, packageInfoArray)

View file

@ -16,6 +16,7 @@ import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.transport.TransportTest import com.stevesoltys.seedvault.transport.TransportTest
import com.stevesoltys.seedvault.transport.backup.BackupPlugin import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
@ -44,12 +45,20 @@ internal class RestoreV0IntegrationTest : TransportTest() {
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader) private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val dbManager = mockk<KvDbManager>()
private val metadataReader = MetadataReaderImpl(cryptoImpl) private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>() private val notificationManager = mockk<BackupNotificationManager>()
private val backupPlugin = mockk<BackupPlugin>() private val backupPlugin = mockk<BackupPlugin>()
private val kvRestorePlugin = mockk<KVRestorePlugin>() private val kvRestorePlugin = mockk<KVRestorePlugin>()
private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl) private val kvRestore = KVRestore(
backupPlugin,
kvRestorePlugin,
outputFactory,
headerReader,
cryptoImpl,
dbManager
)
private val fullRestorePlugin = mockk<FullRestorePlugin>() private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore = private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl) FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)