K/V restore using single file

This commit is contained in:
Torsten Grote 2021-09-22 14:41:49 +02:00 committed by Chirayu Desai
parent 0c915e5eb8
commit a0f3c6b45f
10 changed files with 405 additions and 158 deletions

View file

@ -272,6 +272,7 @@ internal class BackupCoordinator(
val salt = metadataManager.salt
val result = kv.performBackup(packageInfo, data, flags, token, salt)
if (result == TRANSPORT_OK && packageName == MAGIC_PACKAGE_MANAGER) {
// TODO move to finish backup of @pm@ so we can upload the DB before
// hook in here to back up APKs of apps that are otherwise not allowed for backup
backUpApksOfNotBackedUpPackages()
}
@ -392,7 +393,9 @@ internal class BackupCoordinator(
}
// getCurrentPackage() not-null because we have state
onPackageBackedUp(kv.getCurrentPackage()!!, BackupType.KV)
val isPmBackup = kv.getCurrentPackage()!!.packageName == MAGIC_PACKAGE_MANAGER
kv.finishBackup()
// TODO move @pm@ backup hook here
}
full.hasState() -> {
check(!kv.hasState()) {

View file

@ -8,37 +8,59 @@ import android.database.sqlite.SQLiteOpenHelper
import android.provider.BaseColumns
import java.io.File
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.InputStream
import java.io.OutputStream
interface KvDbManager {
fun getDb(packageName: String): KVDb
fun getDb(packageName: String, isRestore: Boolean = false): KVDb
/**
* Use only for backup.
*/
fun getDbInputStream(packageName: String): InputStream
/**
* Use only for restore.
*/
fun getDbOutputStream(packageName: String): OutputStream
/**
* Use only for backup.
*/
fun existsDb(packageName: String): Boolean
fun deleteDb(packageName: String): Boolean
fun deleteDb(packageName: String, isRestore: Boolean = false): Boolean
}
class KvDbManagerImpl(private val context: Context) : KvDbManager {
override fun getDb(packageName: String): KVDb {
return KVDbImpl(context, getFileName(packageName))
override fun getDb(packageName: String, isRestore: Boolean): KVDb {
return KVDbImpl(context, getFileName(packageName, isRestore))
}
private fun getFileName(packageName: String) = "kv_$packageName.db"
private fun getFileName(packageName: String, isRestore: Boolean): String {
val prefix = if (isRestore) "restore_" else ""
return "${prefix}kv_$packageName.db"
}
private fun getDbFile(packageName: String): File {
return context.getDatabasePath(getFileName(packageName))
private fun getDbFile(packageName: String, isRestore: Boolean = false): File {
return context.getDatabasePath(getFileName(packageName, isRestore))
}
override fun getDbInputStream(packageName: String): InputStream {
return FileInputStream(getDbFile(packageName))
}
override fun getDbOutputStream(packageName: String): OutputStream {
return FileOutputStream(getDbFile(packageName, true))
}
override fun existsDb(packageName: String): Boolean {
return getDbFile(packageName).isFile
}
override fun deleteDb(packageName: String): Boolean {
return getDbFile(packageName).delete()
override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
return getDbFile(packageName, isRestore).delete()
}
}

View file

@ -15,19 +15,25 @@ import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KVDb
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import libcore.io.IoUtils.closeQuietly
import java.io.IOException
import java.security.GeneralSecurityException
import java.util.ArrayList
import java.util.zip.GZIPInputStream
import javax.crypto.AEADBadTagException
private class KVRestoreState(
val version: Byte,
val token: Long,
val name: String,
val packageInfo: PackageInfo,
/**
* Optional [PackageInfo] for single package restore, optimizes restore of @pm@
*/
@Deprecated("TODO remove?")
val pmPackageInfo: PackageInfo?
)
@ -35,20 +41,25 @@ private val TAG = KVRestore::class.java.simpleName
@Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestore(
private val plugin: KVRestorePlugin,
private val plugin: BackupPlugin,
private val legacyPlugin: KVRestorePlugin,
private val outputFactory: OutputFactory,
private val headerReader: HeaderReader,
private val crypto: Crypto
private val crypto: Crypto,
private val dbManager: KvDbManager
) {
private var state: KVRestoreState? = null
/**
* Return true if there are records stored for the given package.
*
* Deprecated. Use only for v0 backups.
*/
@Throws(IOException::class)
@Deprecated("Use BackupPlugin#hasData() instead")
suspend fun hasDataForPackage(token: Long, packageInfo: PackageInfo): Boolean {
return plugin.hasDataForPackage(token, packageInfo)
return legacyPlugin.hasDataForPackage(token, packageInfo)
}
/**
@ -62,10 +73,11 @@ internal class KVRestore(
fun initializeState(
version: Byte,
token: Long,
name: String,
packageInfo: PackageInfo,
pmPackageInfo: PackageInfo? = null
) {
state = KVRestoreState(version, token, packageInfo, pmPackageInfo)
state = KVRestoreState(version, token, name, packageInfo, pmPackageInfo)
}
/**
@ -78,12 +90,66 @@ internal class KVRestore(
suspend fun getRestoreData(data: ParcelFileDescriptor): Int {
val state = this.state ?: throw IllegalStateException("no state")
// take legacy path for version 0
if (state.version == 0x00.toByte()) return getRestoreDataV0(state, data)
return try {
val db = getRestoreDb(state)
val out = outputFactory.getBackupDataOutput(data)
db.getAll().sortedBy { it.first }.forEach { (key, value) ->
val size = value.size
Log.v(TAG, " ... key=$key size=$size")
out.writeEntityHeader(key, size)
out.writeEntityData(value, size)
}
TRANSPORT_OK
} catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR
} catch (e: IOException) {
Log.e(TAG, "Unable to process K/V backup database", e)
TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "General security exception while reading backup database", e)
TRANSPORT_ERROR
} catch (e: AEADBadTagException) {
Log.e(TAG, "Decryption failed", e)
TRANSPORT_ERROR
} finally {
dbManager.deleteDb(state.packageInfo.packageName, true)
this.state = null
closeQuietly(data)
}
}
@Throws(IOException::class, GeneralSecurityException::class, UnsupportedVersionException::class)
private suspend fun getRestoreDb(state: KVRestoreState): KVDb {
val packageName = state.packageInfo.packageName
plugin.getInputStream(state.token, state.name).use { inputStream ->
headerReader.readVersion(inputStream, state.version)
val ad = getADForKV(VERSION, packageName)
crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
GZIPInputStream(decryptedStream).use { gzipStream ->
dbManager.getDbOutputStream(packageName).use { outputStream ->
gzipStream.copyTo(outputStream)
}
}
}
}
return dbManager.getDb(packageName, true)
}
//
// v0 restore legacy code below
//
private suspend fun getRestoreDataV0(state: KVRestoreState, data: ParcelFileDescriptor): Int {
// The restore set is the concatenation of the individual record blobs,
// each of which is a file in the package's directory.
// We return the data in lexical order sorted by key,
// so that apps which use synthetic keys like BLOB_1, BLOB_2, etc
// will see the date in the most obvious order.
val sortedKeys = getSortedKeys(state.token, state.packageInfo)
val sortedKeys = getSortedKeysV0(state.token, state.packageInfo)
if (sortedKeys == null) {
// nextRestorePackage() ensures the dir exists, so this is an error
Log.e(TAG, "No keys for package: ${state.packageInfo.packageName}")
@ -96,7 +162,7 @@ internal class KVRestore(
return try {
val dataOutput = outputFactory.getBackupDataOutput(data)
for (keyEntry in sortedKeys) {
readAndWriteValue(state, keyEntry, dataOutput)
readAndWriteValueV0(state, keyEntry, dataOutput)
}
TRANSPORT_OK
} catch (e: IOException) {
@ -105,9 +171,6 @@ internal class KVRestore(
} catch (e: SecurityException) {
Log.e(TAG, "Security exception while reading backup records", e)
TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "General security exception while reading backup records", e)
TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR
@ -124,9 +187,9 @@ internal class KVRestore(
* Return a list of the records (represented by key files) in the given directory,
* sorted lexically by the Base64-decoded key file name, not by the on-disk filename.
*/
private suspend fun getSortedKeys(token: Long, packageInfo: PackageInfo): List<DecodedKey>? {
private suspend fun getSortedKeysV0(token: Long, packageInfo: PackageInfo): List<DecodedKey>? {
val records: List<String> = try {
plugin.listRecords(token, packageInfo)
legacyPlugin.listRecords(token, packageInfo)
} catch (e: IOException) {
return null
}
@ -150,24 +213,18 @@ internal class KVRestore(
/**
* Read the encrypted value for the given key and write it to the given [BackupDataOutput].
*/
@Suppress("Deprecation")
@Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
private suspend fun readAndWriteValue(
private suspend fun readAndWriteValueV0(
state: KVRestoreState,
dKey: DecodedKey,
out: BackupDataOutput
) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
) = legacyPlugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
.use { inputStream ->
val version = headerReader.readVersion(inputStream, state.version)
val packageName = state.packageInfo.packageName
val value = if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName, dKey.key)
crypto.decryptMultipleSegments(inputStream)
} else {
val ad = getADForKV(VERSION, packageName)
crypto.newDecryptingStream(inputStream, ad).use { decryptedStream ->
decryptedStream.readBytes()
}
}
val value = crypto.decryptMultipleSegments(inputStream)
val size = value.size
Log.v(TAG, " ... key=${dKey.key} size=$size")

View file

@ -207,7 +207,13 @@ internal class RestoreCoordinator(
val name = crypto.getNameForPackage(state.backupMetadata.salt, packageName)
if (plugin.hasData(state.token, name)) {
Log.i(TAG, "Found K/V data for $packageName.")
kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo)
kv.initializeState(
version = version,
token = state.token,
name = name,
packageInfo = packageInfo,
pmPackageInfo = state.pmPackageInfo
)
state.currentPackage = packageName
TYPE_KEY_VALUE
} else throw IOException("No data found for $packageName. Skipping.")
@ -243,7 +249,7 @@ internal class RestoreCoordinator(
// check key/value data first and if available, don't even check for full data
kv.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found K/V data for $packageName.")
kv.initializeState(0x00, state.token, packageInfo, state.pmPackageInfo)
kv.initializeState(0x00, state.token, "", packageInfo, state.pmPackageInfo)
state.currentPackage = packageName
TYPE_KEY_VALUE
}

View file

@ -5,7 +5,7 @@ import org.koin.dsl.module
val restoreModule = module {
single { OutputFactory() }
single { KVRestore(get<RestorePlugin>().kvRestorePlugin, get(), get(), get()) }
single { KVRestore(get(), get<RestorePlugin>().kvRestorePlugin, get(), get(), get(), get()) }
single { FullRestore(get(), get<RestorePlugin>().fullRestorePlugin, get(), get(), get()) }
single {
RestoreCoordinator(androidContext(), get(), get(), get(), get(), get(), get(), get(), get())

View file

@ -10,7 +10,6 @@ import android.os.ParcelFileDescriptor
import com.stevesoltys.seedvault.crypto.CipherFactoryImpl
import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.BackupType
@ -39,6 +38,7 @@ import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
@ -81,7 +81,14 @@ internal class CoordinatorIntegrationTest : TransportTest() {
)
private val kvRestorePlugin = mockk<KVRestorePlugin>()
private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl)
private val kvRestore = KVRestore(
backupPlugin,
kvRestorePlugin,
outputFactory,
headerReader,
cryptoImpl,
dbManager
)
private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)
@ -104,9 +111,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
private val metadataOutputStream = ByteArrayOutputStream()
private val packageMetadata = PackageMetadata(time = 0L)
private val key = "RestoreKey"
private val key64 = key.encodeBase64()
private val key2 = "RestoreKey2"
private val key264 = key2.encodeBase64()
// as we use real crypto, we need a real name for packageInfo
private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName)
@ -116,7 +121,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
val value = CapturingSlot<ByteArray>()
val value2 = CapturingSlot<ByteArray>()
val bOutputStream = ByteArrayOutputStream()
val bOutputStream2 = ByteArrayOutputStream()
every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt
@ -170,29 +174,21 @@ internal class CoordinatorIntegrationTest : TransportTest() {
// restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
val rInputStream2 = ByteArrayInputStream(bOutputStream2.toByteArray())
coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64, key264)
coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key64
)
} returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key264
)
} returns rInputStream2
every { backupDataOutput.writeEntityHeader(key2, appData2.size) } returns 1137
every { backupDataOutput.writeEntityData(appData2, appData2.size) } returns appData2.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verify {
backupDataOutput.writeEntityHeader(key, appData.size)
backupDataOutput.writeEntityData(appData, appData.size)
backupDataOutput.writeEntityHeader(key2, appData2.size)
backupDataOutput.writeEntityData(appData2, appData2.size)
}
}
@Test
@ -246,19 +242,17 @@ internal class CoordinatorIntegrationTest : TransportTest() {
// restore finds the backed up key and writes the decrypted value
val backupDataOutput = mockk<BackupDataOutput>()
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
coEvery { kvRestorePlugin.listRecords(token, packageInfo) } returns listOf(key64)
coEvery { backupPlugin.getInputStream(token, name) } returns rInputStream
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns backupDataOutput
coEvery {
kvRestorePlugin.getInputStreamForRecord(
token,
packageInfo,
key64
)
} returns rInputStream
every { backupDataOutput.writeEntityHeader(key, appData.size) } returns 1137
every { backupDataOutput.writeEntityData(appData, appData.size) } returns appData.size
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verify {
backupDataOutput.writeEntityHeader(key, appData.size)
backupDataOutput.writeEntityData(appData, appData.size)
}
}
@Test

View file

@ -3,22 +3,29 @@ package com.stevesoltys.seedvault.transport.backup
import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.toHexString
import junit.framework.Assert.assertEquals
import junit.framework.Assert.assertFalse
import junit.framework.Assert.assertNull
import junit.framework.Assert.assertTrue
import org.json.JSONObject
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.io.OutputStream
import kotlin.random.Random
class TestKvDbManager : KvDbManager {
private var db: TestKVDb? = null
private val outputStream = ByteArrayOutputStream()
override fun getDb(packageName: String): KVDb {
override fun getDb(packageName: String, isRestore: Boolean): KVDb {
if (isRestore) {
readDbFromStream(ByteArrayInputStream(outputStream.toByteArray()))
return this.db!!
}
return TestKVDb().apply { db = this }
}
@ -26,11 +33,16 @@ class TestKvDbManager : KvDbManager {
return ByteArrayInputStream(db!!.serialize().toByteArray())
}
override fun getDbOutputStream(packageName: String): OutputStream {
outputStream.reset()
return outputStream
}
override fun existsDb(packageName: String): Boolean {
return db != null
}
override fun deleteDb(packageName: String): Boolean {
override fun deleteDb(packageName: String, isRestore: Boolean): Boolean {
clearDb()
return true
}

View file

@ -10,48 +10,57 @@ import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KVDb
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.verify
import io.mockk.verifyAll
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.security.GeneralSecurityException
import java.util.zip.GZIPOutputStream
import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext")
internal class KVRestoreTest : RestoreTest() {
private val plugin = mockk<KVRestorePlugin>()
private val plugin = mockk<BackupPlugin>()
private val legacyPlugin = mockk<KVRestorePlugin>()
private val dbManager = mockk<KvDbManager>()
private val output = mockk<BackupDataOutput>()
private val restore = KVRestore(plugin, outputFactory, headerReader, crypto)
private val restore =
KVRestore(plugin, legacyPlugin, outputFactory, headerReader, crypto, dbManager)
private val db = mockk<KVDb>()
private val ad = getADForKV(VERSION, packageInfo.packageName)
private val key = "Restore Key"
private val key64 = key.encodeBase64()
private val key2 = "Restore Key2"
private val key264 = key2.encodeBase64()
private val data2 = getRandomByteArray()
private val outputStream = ByteArrayOutputStream().apply {
GZIPOutputStream(this).close()
}
private val decryptInputStream = ByteArrayInputStream(outputStream.toByteArray())
init {
// for InputStream#readBytes()
mockkStatic("kotlin.io.ByteStreamsKt")
}
@Test
fun `hasDataForPackage() delegates to plugin`() = runBlocking {
val result = Random.nextBoolean()
coEvery { plugin.hasDataForPackage(token, packageInfo) } returns result
assertEquals(result, restore.hasDataForPackage(token, packageInfo))
}
@Test
fun `getRestoreData() throws without initializing state`() {
coAssertThrows(IllegalStateException::class.java) {
@ -60,22 +69,133 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `listing records throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.listRecords(token, packageInfo) } throws IOException()
coEvery { plugin.getInputStream(token, name) } returns inputStream
every {
headerReader.readVersion(inputStream, VERSION)
} throws UnsupportedVersionException(Byte.MAX_VALUE)
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
}
@Test
fun `newDecryptingStream throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verifyAll {
dbManager.deleteDb(packageInfo.packageName, true)
}
}
@Test
fun `writeEntityHeader throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
every {
dbManager.getDbOutputStream(packageInfo.packageName)
} returns ByteArrayOutputStream()
every { dbManager.getDb(packageInfo.packageName, true) } returns db
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
every { db.getAll() } returns listOf(Pair(key, data))
every { output.writeEntityHeader(key, data.size) } throws IOException()
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verify {
dbManager.deleteDb(packageInfo.packageName, true)
}
}
@Test
fun `two records get restored`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { plugin.getInputStream(token, name) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptInputStream
every {
dbManager.getDbOutputStream(packageInfo.packageName)
} returns ByteArrayOutputStream()
every { dbManager.getDb(packageInfo.packageName, true) } returns db
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
every { db.getAll() } returns listOf(
Pair(key, data),
Pair(key2, data2)
)
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
every { output.writeEntityHeader(key2, data2.size) } returns 42
every { output.writeEntityData(data2, data2.size) } returns data2.size
every { dbManager.deleteDb(packageInfo.packageName, true) } returns true
streamsGetClosed()
assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
verify {
output.writeEntityHeader(key, data.size)
output.writeEntityData(data, data.size)
output.writeEntityHeader(key2, data2.size)
output.writeEntityData(data2, data2.size)
dbManager.deleteDb(packageInfo.packageName, true)
}
}
//
// v0 legacy tests below
//
@Test
@Suppress("Deprecation")
fun `v0 hasDataForPackage() delegates to plugin`() = runBlocking {
val result = Random.nextBoolean()
coEvery { legacyPlugin.hasDataForPackage(token, packageInfo) } returns result
assertEquals(result, restore.hasDataForPackage(token, packageInfo))
}
@Test
@Suppress("Deprecation")
fun `v0 listing records throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
coEvery { legacyPlugin.listRecords(token, packageInfo) } throws IOException()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
}
@Test
fun `reading VersionHeader with unsupported version throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
fun `v0 reading VersionHeader with unsupported version throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every {
headerReader.readVersion(inputStream, VERSION)
headerReader.readVersion(inputStream, 0x00)
} throws UnsupportedVersionException(unsupportedVersion)
streamsGetClosed()
@ -84,12 +204,14 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `error reading VersionHeader throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
fun `v0 error reading VersionHeader throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -97,13 +219,18 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `decrypting stream throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
@Suppress("deprecation")
fun `v0 decrypting stream throws`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -111,13 +238,19 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `decrypting stream throws security exception`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
@Suppress("deprecation")
fun `v0 decrypting stream throws security exception`() = runBlocking {
restore.initializeState(0x00, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException()
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0x00) } returns 0x00
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} returns VersionHeader(0x00, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } throws IOException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -125,14 +258,19 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `writing header throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
@Suppress("Deprecation")
fun `v0 writing header throws`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0x00, packageInfo.packageName, key)
} returns VersionHeader(0x00, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } throws IOException()
streamsGetClosed()
@ -141,14 +279,19 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `writing value throws`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
@Suppress("deprecation")
fun `v0 writing value throws`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } throws IOException()
streamsGetClosed()
@ -158,14 +301,19 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `writing value succeeds`() = runBlocking {
restore.initializeState(VERSION, token, packageInfo)
@Suppress("deprecation")
fun `v0 writing value succeeds`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
streamsGetClosed()
@ -175,14 +323,17 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `writing value uses old v0 code`() = runBlocking {
restore.initializeState(0.toByte(), token, packageInfo)
@Suppress("deprecation")
fun `v0 writing value uses old v0 code`() = runBlocking {
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key)
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(VERSION, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
@ -194,43 +345,35 @@ internal class KVRestoreTest : RestoreTest() {
}
@Test
fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every {
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
} throws GeneralSecurityException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
}
@Test
fun `writing two values succeeds`() = runBlocking {
@Suppress("Deprecation")
fun `v0 writing two values succeeds`() = runBlocking {
val data2 = getRandomByteArray()
val inputStream2 = mockk<InputStream>()
val decryptedInputStream2 = mockk<InputStream>()
restore.initializeState(VERSION, token, packageInfo)
restore.initializeState(0, token, name, packageInfo)
getRecordsAndOutput(listOf(key64, key264))
// first key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key64)
} returns inputStream
every { headerReader.readVersion(inputStream, 0) } returns 0
every {
crypto.decryptHeader(inputStream, 0, packageInfo.packageName, key)
} returns VersionHeader(0, packageInfo.packageName, key)
every { crypto.decryptMultipleSegments(inputStream) } returns data
every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size
// second key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2
every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2
every { decryptedInputStream2.readBytes() } returns data2
coEvery {
legacyPlugin.getInputStreamForRecord(token, packageInfo, key264)
} returns inputStream2
every { headerReader.readVersion(inputStream2, 0) } returns 0
every {
crypto.decryptHeader(inputStream2, 0, packageInfo.packageName, key2)
} returns VersionHeader(0, packageInfo.packageName, key2)
every { crypto.decryptMultipleSegments(inputStream2) } returns data2
every { output.writeEntityHeader(key2, data2.size) } returns 42
every { output.writeEntityData(data2, data2.size) } returns data2.size
every { decryptedInputStream2.close() } just Runs
every { inputStream2.close() } just Runs
streamsGetClosed()
@ -238,12 +381,11 @@ internal class KVRestoreTest : RestoreTest() {
}
private fun getRecordsAndOutput(recordKeys: List<String> = listOf(key64)) {
coEvery { plugin.listRecords(token, packageInfo) } returns recordKeys
coEvery { legacyPlugin.listRecords(token, packageInfo) } returns recordKeys
every { outputFactory.getBackupDataOutput(fileDescriptor) } returns output
}
private fun streamsGetClosed() {
every { decryptedInputStream.close() } just Runs
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
}

View file

@ -223,19 +223,20 @@ internal class RestoreCoordinatorTest : TransportTest() {
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true
every { kv.initializeState(VERSION, token, packageInfo) } just Runs
every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
}
@Test
@Suppress("Deprecation")
fun `v0 nextRestorePackage() returns KV description and takes precedence`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(0x00, token, packageInfo) } just Runs
every { kv.initializeState(0x00, token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@ -292,7 +293,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
every { crypto.getNameForPackage(metadata.salt, packageName) } returns name
coEvery { plugin.hasData(token, name) } returns true
every { kv.initializeState(VERSION, token, packageInfo) } just Runs
every { kv.initializeState(VERSION, token, name, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@ -315,7 +316,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.startRestore(token, packageInfoArray2)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(0.toByte(), token, packageInfo) } just Runs
every { kv.initializeState(0.toByte(), token, "", packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage())
@ -331,6 +332,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
}
@Test
@Suppress("Deprecation")
fun `v0 when kv#hasDataForPackage() throws, it tries next package`() = runBlocking {
restore.beforeStartRestore(metadata.copy(version = 0x00))
restore.startRestore(token, packageInfoArray)

View file

@ -16,6 +16,7 @@ import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.transport.TransportTest
import com.stevesoltys.seedvault.transport.backup.BackupPlugin
import com.stevesoltys.seedvault.transport.backup.KvDbManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.coEvery
import io.mockk.every
@ -44,12 +45,20 @@ internal class RestoreV0IntegrationTest : TransportTest() {
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val dbManager = mockk<KvDbManager>()
private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>()
private val backupPlugin = mockk<BackupPlugin>()
private val kvRestorePlugin = mockk<KVRestorePlugin>()
private val kvRestore = KVRestore(kvRestorePlugin, outputFactory, headerReader, cryptoImpl)
private val kvRestore = KVRestore(
backupPlugin,
kvRestorePlugin,
outputFactory,
headerReader,
cryptoImpl,
dbManager
)
private val fullRestorePlugin = mockk<FullRestorePlugin>()
private val fullRestore =
FullRestore(backupPlugin, fullRestorePlugin, outputFactory, headerReader, cryptoImpl)