Full backup and restore using v2

while maintaining support for v0 and v1
This commit is contained in:
Torsten Grote 2024-09-06 16:27:37 -03:00
parent 83708d9403
commit 7c7ea5fcd7
No known key found for this signature in database
GPG key ID: 3E5F77D92CF891FF
19 changed files with 695 additions and 610 deletions

View file

@ -36,11 +36,11 @@ class KoinInstrumentationTestApp : App() {
single { spyk(SettingsManager(context)) } single { spyk(SettingsManager(context)) }
single { spyk(BackupNotificationManager(context)) } single { spyk(BackupNotificationManager(context)) }
single { spyk(FullBackup(get(), get(), get(), get(), get())) } single { spyk(FullBackup(get(), get(), get(), get())) }
single { spyk(KVBackup(get(), get(), get(), get(), get(), get())) } single { spyk(KVBackup(get(), get(), get(), get(), get(), get())) }
single { spyk(InputFactory()) } single { spyk(InputFactory()) }
single { spyk(FullRestore(get(), get(), get(), get(), get())) } single { spyk(FullRestore(get(), get(), get(), get(), get(), get())) }
single { spyk(KVRestore(get(), get(), get(), get(), get(), get())) } single { spyk(KVRestore(get(), get(), get(), get(), get(), get())) }
single { spyk(OutputFactory()) } single { spyk(OutputFactory()) }

View file

@ -157,7 +157,7 @@ internal interface LargeBackupTestBase : LargeTestBase {
var dataIntercept = ByteArrayOutputStream() var dataIntercept = ByteArrayOutputStream()
coEvery { coEvery {
spyFullBackup.performFullBackup(any(), any(), any(), any(), any()) spyFullBackup.performFullBackup(any(), any(), any())
} answers { } answers {
packageName = firstArg<PackageInfo>().packageName packageName = firstArg<PackageInfo>().packageName
callOriginal() callOriginal()
@ -172,7 +172,7 @@ internal interface LargeBackupTestBase : LargeTestBase {
) )
} }
every { coEvery {
spyFullBackup.finishBackup() spyFullBackup.finishBackup()
} answers { } answers {
val result = callOriginal() val result = callOriginal()

View file

@ -189,7 +189,7 @@ internal interface LargeRestoreTestBase : LargeTestBase {
clearMocks(spyFullRestore) clearMocks(spyFullRestore)
coEvery { coEvery {
spyFullRestore.initializeState(any(), any(), any(), any()) spyFullRestore.initializeState(any(), any(), any())
} answers { } answers {
packageName?.let { packageName?.let {
restoreResult.full[it] = dataIntercept.toByteArray().sha256() restoreResult.full[it] = dataIntercept.toByteArray().sha256()

View file

@ -145,10 +145,9 @@ internal class MetadataManager(
packageInfo: PackageInfo, packageInfo: PackageInfo,
type: BackupType, type: BackupType,
size: Long?, size: Long?,
metadataOutputStream: OutputStream,
) { ) {
val packageName = packageInfo.packageName val packageName = packageInfo.packageName
modifyMetadata(metadataOutputStream) { modifyCachedMetadata {
val now = clock.time() val now = clock.time()
metadata.time = now metadata.time = now
metadata.d2dBackup = settingsManager.d2dBackupsEnabled() metadata.d2dBackup = settingsManager.d2dBackupsEnabled()

View file

@ -23,15 +23,15 @@ import android.util.Log
import androidx.annotation.WorkerThread import androidx.annotation.WorkerThread
import com.stevesoltys.seedvault.Clock import com.stevesoltys.seedvault.Clock
import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.getMetadataOutputStream
import com.stevesoltys.seedvault.backend.isOutOfSpace
import com.stevesoltys.seedvault.metadata.BackupType import com.stevesoltys.seedvault.metadata.BackupType
import com.stevesoltys.seedvault.metadata.MetadataManager import com.stevesoltys.seedvault.metadata.MetadataManager
import com.stevesoltys.seedvault.metadata.PackageState import com.stevesoltys.seedvault.metadata.PackageState
import com.stevesoltys.seedvault.metadata.PackageState.NO_DATA import com.stevesoltys.seedvault.metadata.PackageState.NO_DATA
import com.stevesoltys.seedvault.metadata.PackageState.QUOTA_EXCEEDED import com.stevesoltys.seedvault.metadata.PackageState.QUOTA_EXCEEDED
import com.stevesoltys.seedvault.metadata.PackageState.UNKNOWN_ERROR import com.stevesoltys.seedvault.metadata.PackageState.UNKNOWN_ERROR
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.getMetadataOutputStream
import com.stevesoltys.seedvault.backend.isOutOfSpace
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import java.io.IOException import java.io.IOException
@ -63,6 +63,7 @@ private class CoordinatorState(
internal class BackupCoordinator( internal class BackupCoordinator(
private val context: Context, private val context: Context,
private val backendManager: BackendManager, private val backendManager: BackendManager,
private val appBackupManager: AppBackupManager,
private val kv: KVBackup, private val kv: KVBackup,
private val full: FullBackup, private val full: FullBackup,
private val clock: Clock, private val clock: Clock,
@ -73,6 +74,8 @@ internal class BackupCoordinator(
) { ) {
private val backend get() = backendManager.backend private val backend get() = backendManager.backend
private val snapshotCreator
get() = appBackupManager.snapshotCreator ?: error("No SnapshotCreator")
private val state = CoordinatorState( private val state = CoordinatorState(
calledInitialize = false, calledInitialize = false,
calledClearBackupData = false, calledClearBackupData = false,
@ -154,7 +157,7 @@ internal class BackupCoordinator(
fun getBackupQuota(packageName: String, isFullBackup: Boolean): Long { fun getBackupQuota(packageName: String, isFullBackup: Boolean): Long {
// report back quota // report back quota
Log.i(TAG, "Get backup quota for $packageName. Is full backup: $isFullBackup.") Log.i(TAG, "Get backup quota for $packageName. Is full backup: $isFullBackup.")
val quota = if (isFullBackup) full.getQuota() else kv.getQuota() val quota = if (isFullBackup) full.quota else kv.getQuota()
Log.i(TAG, "Reported quota of $quota bytes.") Log.i(TAG, "Reported quota of $quota bytes.")
return quota return quota
} }
@ -262,15 +265,13 @@ internal class BackupCoordinator(
return result return result
} }
suspend fun performFullBackup( fun performFullBackup(
targetPackage: PackageInfo, targetPackage: PackageInfo,
fileDescriptor: ParcelFileDescriptor, fileDescriptor: ParcelFileDescriptor,
flags: Int, flags: Int,
): Int { ): Int {
state.cancelReason = UNKNOWN_ERROR state.cancelReason = UNKNOWN_ERROR
val token = settingsManager.getToken() ?: error("no token in performFullBackup") return full.performFullBackup(targetPackage, fileDescriptor, flags)
val salt = metadataManager.salt
return full.performFullBackup(targetPackage, fileDescriptor, flags, token, salt)
} }
/** /**
@ -299,8 +300,8 @@ internal class BackupCoordinator(
* It needs to tear down any ongoing backup state here. * It needs to tear down any ongoing backup state here.
*/ */
suspend fun cancelFullBackup() { suspend fun cancelFullBackup() {
val packageInfo = full.getCurrentPackage() val packageInfo = full.currentPackageInfo
?: throw AssertionError("Cancelling full backup, but no current package") ?: error("Cancelling full backup, but no current package")
Log.i( Log.i(
TAG, "Cancel full backup of ${packageInfo.packageName}" + TAG, "Cancel full backup of ${packageInfo.packageName}" +
" because of ${state.cancelReason}" " because of ${state.cancelReason}"
@ -308,9 +309,7 @@ internal class BackupCoordinator(
// don't bother with system apps that have no data // don't bother with system apps that have no data
val ignoreApp = state.cancelReason == NO_DATA && packageInfo.isSystemApp() val ignoreApp = state.cancelReason == NO_DATA && packageInfo.isSystemApp()
if (!ignoreApp) onPackageBackupError(packageInfo, BackupType.FULL) if (!ignoreApp) onPackageBackupError(packageInfo, BackupType.FULL)
val token = settingsManager.getToken() ?: error("no token in cancelFullBackup") full.cancelFullBackup()
val salt = metadataManager.salt
full.cancelFullBackup(token, salt, ignoreApp)
} }
// Clear and Finish // Clear and Finish
@ -335,12 +334,7 @@ internal class BackupCoordinator(
Log.w(TAG, "Error clearing K/V backup data for $packageName", e) Log.w(TAG, "Error clearing K/V backup data for $packageName", e)
return TRANSPORT_ERROR return TRANSPORT_ERROR
} }
try { // we don't clear backup data anymore, we have snapshots and those old ones stay valid
full.clearBackupData(packageInfo, token, salt)
} catch (e: IOException) {
Log.w(TAG, "Error clearing full backup data for $packageName", e)
return TRANSPORT_ERROR
}
state.calledClearBackupData = true state.calledClearBackupData = true
return TRANSPORT_OK return TRANSPORT_OK
} }
@ -355,7 +349,7 @@ internal class BackupCoordinator(
*/ */
suspend fun finishBackup(): Int = when { suspend fun finishBackup(): Int = when {
kv.hasState() -> { kv.hasState() -> {
check(!full.hasState()) { check(!full.hasState) {
"K/V backup has state, but full backup has dangling state as well" "K/V backup has state, but full backup has dangling state as well"
} }
// getCurrentPackage() not-null because we have state, call before finishing // getCurrentPackage() not-null because we have state, call before finishing
@ -369,7 +363,7 @@ internal class BackupCoordinator(
// call onPackageBackedUp for @pm@ only if we can do backups right now // call onPackageBackedUp for @pm@ only if we can do backups right now
if (isNormalBackup || backendManager.canDoBackupNow()) { if (isNormalBackup || backendManager.canDoBackupNow()) {
try { try {
onPackageBackedUp(packageInfo, BackupType.KV, size) metadataManager.onPackageBackedUp(packageInfo, BackupType.KV, size)
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error calling onPackageBackedUp for $packageName", e) Log.e(TAG, "Error calling onPackageBackedUp for $packageName", e)
if (e.isOutOfSpace()) nm.onInsufficientSpaceError() if (e.isOutOfSpace()) nm.onInsufficientSpaceError()
@ -379,24 +373,25 @@ internal class BackupCoordinator(
} }
result result
} }
full.hasState() -> { full.hasState -> {
check(!kv.hasState()) { check(!kv.hasState()) {
"Full backup has state, but K/V backup has dangling state as well" "Full backup has state, but K/V backup has dangling state as well"
} }
// getCurrentPackage() not-null because we have state // getCurrentPackage() not-null because we have state
val packageInfo = full.getCurrentPackage()!! val packageInfo = full.currentPackageInfo!!
val packageName = packageInfo.packageName val packageName = packageInfo.packageName
val size = full.getCurrentSize()
// tell full backup to finish // tell full backup to finish
var result = full.finishBackup()
try { try {
onPackageBackedUp(packageInfo, BackupType.FULL, size) val backupData = full.finishBackup()
snapshotCreator.onPackageBackedUp(packageInfo, BackupType.FULL, backupData)
// TODO unify both calls
metadataManager.onPackageBackedUp(packageInfo, BackupType.FULL, backupData.size)
TRANSPORT_OK
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error calling onPackageBackedUp for $packageName", e) Log.e(TAG, "Error calling onPackageBackedUp for $packageName", e)
if (e.isOutOfSpace()) nm.onInsufficientSpaceError() if (e.isOutOfSpace()) nm.onInsufficientSpaceError()
result = TRANSPORT_PACKAGE_REJECTED TRANSPORT_PACKAGE_REJECTED
} }
result
} }
state.expectFinish -> { state.expectFinish -> {
state.onFinish() state.onFinish()
@ -405,13 +400,6 @@ internal class BackupCoordinator(
else -> throw IllegalStateException("Unexpected state in finishBackup()") else -> throw IllegalStateException("Unexpected state in finishBackup()")
} }
private suspend fun onPackageBackedUp(packageInfo: PackageInfo, type: BackupType, size: Long?) {
val token = settingsManager.getToken() ?: error("no token")
backend.getMetadataOutputStream(token).use {
metadataManager.onPackageBackedUp(packageInfo, type, size, it)
}
}
private suspend fun onPackageBackupError(packageInfo: PackageInfo, type: BackupType) { private suspend fun onPackageBackupError(packageInfo: PackageInfo, type: BackupType) {
val packageName = packageInfo.packageName val packageName = packageInfo.packageName
try { try {

View file

@ -38,17 +38,17 @@ val backupModule = module {
} }
single { single {
FullBackup( FullBackup(
backendManager = get(),
settingsManager = get(), settingsManager = get(),
nm = get(), nm = get(),
backupReceiver = get(),
inputFactory = get(), inputFactory = get(),
crypto = get(),
) )
} }
single { single {
BackupCoordinator( BackupCoordinator(
context = androidContext(), context = androidContext(),
backendManager = get(), backendManager = get(),
appBackupManager = get(),
kv = get(), kv = get(),
full = get(), full = get(),
clock = get(), clock = get(),

View file

@ -16,7 +16,9 @@ import java.io.InputStream
data class BackupData( data class BackupData(
val chunks: List<String>, val chunks: List<String>,
val chunkMap: Map<String, Blob>, val chunkMap: Map<String, Blob>,
) ) {
val size get() = chunkMap.values.sumOf { it.uncompressedLength }.toLong()
}
internal class BackupReceiver( internal class BackupReceiver(
private val blobsCache: BlobsCache, private val blobsCache: BlobsCache,
@ -40,8 +42,10 @@ internal class BackupReceiver(
} }
private val chunks = mutableListOf<String>() private val chunks = mutableListOf<String>()
private val chunkMap = mutableMapOf<String, Blob>() private val chunkMap = mutableMapOf<String, Blob>()
private var addedBytes = false
suspend fun addBytes(bytes: ByteArray) { suspend fun addBytes(bytes: ByteArray) {
addedBytes = true
chunker.addBytes(bytes).forEach { chunk -> chunker.addBytes(bytes).forEach { chunk ->
onNewChunk(chunk) onNewChunk(chunk)
} }
@ -73,9 +77,15 @@ internal class BackupReceiver(
val backupData = BackupData(chunks.toList(), chunkMap.toMap()) val backupData = BackupData(chunks.toList(), chunkMap.toMap())
chunks.clear() chunks.clear()
chunkMap.clear() chunkMap.clear()
addedBytes = false
return backupData return backupData
} }
fun assertFinalized() {
// TODO maybe even use a userTag and throw also above if that doesn't match
check(!addedBytes) { "Re-used non-finalized BackupReceiver" }
}
private suspend fun onNewChunk(chunk: Chunk) { private suspend fun onNewChunk(chunk: Chunk) {
chunks.add(chunk.hash) chunks.add(chunk.hash)

View file

@ -13,30 +13,19 @@ import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
import android.content.pm.PackageInfo import android.content.pm.PackageInfo
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.util.Log import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.isOutOfSpace import com.stevesoltys.seedvault.backend.isOutOfSpace
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import org.calyxos.seedvault.core.backends.LegacyAppBackupFile
import java.io.Closeable import java.io.Closeable
import java.io.EOFException import java.io.EOFException
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
private class FullBackupState( private class FullBackupState(
val packageInfo: PackageInfo, val packageInfo: PackageInfo,
val inputFileDescriptor: ParcelFileDescriptor, val inputFileDescriptor: ParcelFileDescriptor,
val inputStream: InputStream, val inputStream: InputStream,
var outputStreamInit: (suspend () -> OutputStream)?,
) { ) {
/**
* This is an encrypted stream that can be written to directly.
*/
var outputStream: OutputStream? = null
val packageName: String = packageInfo.packageName val packageName: String = packageInfo.packageName
var size: Long = 0 var size: Long = 0
} }
@ -47,31 +36,28 @@ private val TAG = FullBackup::class.java.simpleName
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
internal class FullBackup( internal class FullBackup(
private val backendManager: BackendManager,
private val settingsManager: SettingsManager, private val settingsManager: SettingsManager,
private val nm: BackupNotificationManager, private val nm: BackupNotificationManager,
private val backupReceiver: BackupReceiver,
private val inputFactory: InputFactory, private val inputFactory: InputFactory,
private val crypto: Crypto,
) { ) {
private val backend get() = backendManager.backend
private var state: FullBackupState? = null private var state: FullBackupState? = null
fun hasState() = state != null val hasState: Boolean get() = state != null
val currentPackageInfo get() = state?.packageInfo
fun getCurrentPackage() = state?.packageInfo val quota
get() = if (settingsManager.isQuotaUnlimited()) {
fun getCurrentSize() = state?.size Long.MAX_VALUE
} else {
fun getQuota(): Long { DEFAULT_QUOTA_FULL_BACKUP
return if (settingsManager.isQuotaUnlimited()) Long.MAX_VALUE else DEFAULT_QUOTA_FULL_BACKUP
} }
fun checkFullBackupSize(size: Long): Int { fun checkFullBackupSize(size: Long): Int {
Log.i(TAG, "Check full backup size of $size bytes.") Log.i(TAG, "Check full backup size of $size bytes.")
return when { return when {
size <= 0 -> TRANSPORT_PACKAGE_REJECTED size <= 0 -> TRANSPORT_PACKAGE_REJECTED
size > getQuota() -> TRANSPORT_QUOTA_EXCEEDED size > quota -> TRANSPORT_QUOTA_EXCEEDED
else -> TRANSPORT_OK else -> TRANSPORT_OK
} }
} }
@ -111,71 +97,42 @@ internal class FullBackup(
* [TRANSPORT_OK] to indicate that the OS may proceed with delivering backup data; * [TRANSPORT_OK] to indicate that the OS may proceed with delivering backup data;
* [TRANSPORT_ERROR] to indicate an error that precludes performing a backup at this time. * [TRANSPORT_ERROR] to indicate an error that precludes performing a backup at this time.
*/ */
suspend fun performFullBackup( fun performFullBackup(
targetPackage: PackageInfo, targetPackage: PackageInfo,
socket: ParcelFileDescriptor, socket: ParcelFileDescriptor,
@Suppress("UNUSED_PARAMETER") flags: Int = 0, @Suppress("UNUSED_PARAMETER") flags: Int = 0,
token: Long,
salt: String,
): Int { ): Int {
if (state != null) throw AssertionError() if (state != null) error("state wasn't initialized for $targetPackage")
val packageName = targetPackage.packageName val packageName = targetPackage.packageName
Log.i(TAG, "Perform full backup for $packageName.") Log.i(TAG, "Perform full backup for $packageName.")
// create new state // create new state
val inputStream = inputFactory.getInputStream(socket) val inputStream = inputFactory.getInputStream(socket)
state = FullBackupState(targetPackage, socket, inputStream) { state = FullBackupState(targetPackage, socket, inputStream)
Log.d(TAG, "Initializing OutputStream for $packageName.") backupReceiver.assertFinalized()
val name = crypto.getNameForPackage(salt, packageName)
// get OutputStream to write backup data into
val outputStream = try {
backend.save(LegacyAppBackupFile.Blob(token, name))
} catch (e: IOException) {
"Error getting OutputStream for full backup of $packageName".let {
Log.e(TAG, it, e)
}
throw(e)
}
// store version header
val state = this.state ?: throw AssertionError()
outputStream.write(ByteArray(1) { VERSION })
crypto.newEncryptingStreamV1(outputStream, getADForFull(VERSION, state.packageName))
} // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK return TRANSPORT_OK
} }
suspend fun sendBackupData(numBytes: Int): Int { suspend fun sendBackupData(numBytes: Int): Int {
val state = this.state val state = this.state ?: error("Attempted sendBackupData before performFullBackup")
?: throw AssertionError("Attempted sendBackupData before performFullBackup")
// check if size fits quota // check if size fits quota
state.size += numBytes val newSize = state.size + numBytes
val quota = getQuota() if (newSize > quota) {
if (state.size > quota) {
Log.w( Log.w(
TAG, TAG,
"Full backup of additional $numBytes exceeds quota of $quota with ${state.size}." "Full backup of additional $numBytes exceeds quota of $quota with $newSize."
) )
return TRANSPORT_QUOTA_EXCEEDED return TRANSPORT_QUOTA_EXCEEDED
} }
return try { return try {
// get output stream or initialize it, if it does not yet exist
check((state.outputStream != null) xor (state.outputStreamInit != null)) {
"No OutputStream xor no StreamGetter"
}
val outputStream = state.outputStream ?: suspend {
val stream = state.outputStreamInit!!() // not-null due to check above
state.outputStream = stream
stream
}()
state.outputStreamInit = null // the stream init lambda is not needed beyond that point
// read backup data and write it to encrypted output stream // read backup data and write it to encrypted output stream
val payload = ByteArray(numBytes) val payload = ByteArray(numBytes)
val read = state.inputStream.read(payload, 0, numBytes) val read = state.inputStream.read(payload, 0, numBytes)
if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.") if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.")
outputStream.write(payload) backupReceiver.addBytes(payload)
state.size += numBytes
TRANSPORT_OK TRANSPORT_OK
} catch (e: IOException) { } catch (e: IOException) {
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e) Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)
@ -184,44 +141,41 @@ internal class FullBackup(
} }
} }
@Throws(IOException::class) suspend fun cancelFullBackup() {
suspend fun clearBackupData(packageInfo: PackageInfo, token: Long, salt: String) { val state = this.state ?: error("No state when canceling")
val name = crypto.getNameForPackage(salt, packageInfo.packageName) Log.i(TAG, "Cancel full backup for ${state.packageName}")
backend.remove(LegacyAppBackupFile.Blob(token, name)) // TODO check if worth keeping the blobs. they've been uploaded already and may be re-usable
} // so we could add them to the snapshot's blobMap or just let prune remove them at the end
suspend fun cancelFullBackup(token: Long, salt: String, ignoreApp: Boolean) {
Log.i(TAG, "Cancel full backup")
val state = this.state ?: throw AssertionError("No state when canceling")
try { try {
if (!ignoreApp) clearBackupData(state.packageInfo, token, salt) backupReceiver.finalize()
} catch (e: IOException) { } catch (e: Exception) {
Log.w(TAG, "Error cancelling full backup for ${state.packageName}", e) // as the backup was cancelled anyway, we don't care if finalizing had an error
Log.e(TAG, "Error finalizing backup in cancelFullBackup().", e)
} }
clearState() clearState()
// TODO roll back to the previous known-good archive
} }
fun finishBackup(): Int { /**
Log.i(TAG, "Finish full backup of ${state!!.packageName}. Wrote ${state!!.size} bytes") * Returns a pair of the [BackupData] after finalizing last chunks and the total backup size.
return clearState() */
@Throws(IOException::class)
suspend fun finishBackup(): BackupData {
val state = this.state ?: error("No state when finishing")
Log.i(TAG, "Finish full backup of ${state.packageName}. Wrote ${state.size} bytes")
val result = try {
backupReceiver.finalize()
} finally {
clearState()
}
return result
} }
private fun clearState(): Int { private fun clearState() {
val state = this.state ?: throw AssertionError("Trying to clear empty state.") val state = this.state ?: error("Trying to clear empty state.")
return try {
state.outputStream?.flush()
closeLogging(state.outputStream)
closeLogging(state.inputStream) closeLogging(state.inputStream)
closeLogging(state.inputFileDescriptor) closeLogging(state.inputFileDescriptor)
TRANSPORT_OK
} catch (e: IOException) {
Log.w(TAG, "Error when clearing state", e)
TRANSPORT_ERROR
} finally {
this.state = null this.state = null
} }
}
private fun closeLogging(closable: Closeable?) = try { private fun closeLogging(closable: Closeable?) = try {
closable?.close() closable?.close()

View file

@ -12,14 +12,15 @@ import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import android.content.pm.PackageInfo import android.content.pm.PackageInfo
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.util.Log import android.util.Log
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.LegacyStoragePlugin
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.getADForFull import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.LegacyStoragePlugin
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
import org.calyxos.seedvault.core.backends.AppBackupFileType.Blob
import org.calyxos.seedvault.core.backends.LegacyAppBackupFile import org.calyxos.seedvault.core.backends.LegacyAppBackupFile
import java.io.EOFException import java.io.EOFException
import java.io.IOException import java.io.IOException
@ -29,9 +30,10 @@ import java.security.GeneralSecurityException
private class FullRestoreState( private class FullRestoreState(
val version: Byte, val version: Byte,
val token: Long,
val name: String,
val packageInfo: PackageInfo, val packageInfo: PackageInfo,
val blobHandles: List<Blob>? = null,
val token: Long? = null,
val name: String? = null,
) { ) {
var inputStream: InputStream? = null var inputStream: InputStream? = null
} }
@ -40,6 +42,7 @@ private val TAG = FullRestore::class.java.simpleName
internal class FullRestore( internal class FullRestore(
private val backendManager: BackendManager, private val backendManager: BackendManager,
private val loader: Loader,
@Suppress("Deprecation") @Suppress("Deprecation")
private val legacyPlugin: LegacyStoragePlugin, private val legacyPlugin: LegacyStoragePlugin,
private val outputFactory: OutputFactory, private val outputFactory: OutputFactory,
@ -50,7 +53,7 @@ internal class FullRestore(
private val backend get() = backendManager.backend private val backend get() = backendManager.backend
private var state: FullRestoreState? = null private var state: FullRestoreState? = null
fun hasState() = state != null val hasState get() = state != null
/** /**
* Return true if there is data stored for the given package. * Return true if there is data stored for the given package.
@ -69,8 +72,16 @@ internal class FullRestore(
* It is possible that the system decides to not restore the package. * It is possible that the system decides to not restore the package.
* Then a new state will be initialized right away without calling other methods. * Then a new state will be initialized right away without calling other methods.
*/ */
fun initializeState(version: Byte, token: Long, name: String, packageInfo: PackageInfo) { fun initializeState(version: Byte, packageInfo: PackageInfo, blobHandles: List<Blob>) {
state = FullRestoreState(version, token, name, packageInfo) state = FullRestoreState(version, packageInfo, blobHandles)
}
fun initializeStateV1(token: Long, name: String, packageInfo: PackageInfo) {
state = FullRestoreState(1, packageInfo, null, token, name)
}
fun initializeStateV0(token: Long, packageInfo: PackageInfo) {
state = FullRestoreState(0x00, packageInfo, null, token)
} }
/** /**
@ -107,20 +118,30 @@ internal class FullRestore(
if (state.inputStream == null) { if (state.inputStream == null) {
Log.i(TAG, "First Chunk, initializing package input stream.") Log.i(TAG, "First Chunk, initializing package input stream.")
try { try {
if (state.version == 0.toByte()) { when (state.version) {
0.toByte() -> {
val token = state.token ?: error("no token for v0 backup")
val inputStream = val inputStream =
legacyPlugin.getInputStreamForPackage(state.token, state.packageInfo) legacyPlugin.getInputStreamForPackage(token, state.packageInfo)
val version = headerReader.readVersion(inputStream, state.version) val version = headerReader.readVersion(inputStream, state.version)
@Suppress("deprecation") @Suppress("deprecation")
crypto.decryptHeader(inputStream, version, packageName) crypto.decryptHeader(inputStream, version, packageName)
state.inputStream = inputStream state.inputStream = inputStream
} else { }
val handle = LegacyAppBackupFile.Blob(state.token, state.name) 1.toByte() -> {
val token = state.token ?: error("no token for v1 backup")
val name = state.name ?: error("no name for v1 backup")
val handle = LegacyAppBackupFile.Blob(token, name)
val inputStream = backend.load(handle) val inputStream = backend.load(handle)
val version = headerReader.readVersion(inputStream, state.version) val version = headerReader.readVersion(inputStream, state.version)
val ad = getADForFull(version, packageName) val ad = getADForFull(version, packageName)
state.inputStream = crypto.newDecryptingStreamV1(inputStream, ad) state.inputStream = crypto.newDecryptingStreamV1(inputStream, ad)
} }
else -> {
val handles = state.blobHandles ?: error("no blob handles for v2")
state.inputStream = loader.loadFiles(handles)
}
}
} catch (e: IOException) { } catch (e: IOException) {
Log.w(TAG, "Error getting input stream for $packageName", e) Log.w(TAG, "Error getting input stream for $packageName", e)
return TRANSPORT_PACKAGE_REJECTED return TRANSPORT_PACKAGE_REJECTED

View file

@ -30,6 +30,7 @@ import com.stevesoltys.seedvault.proto.Snapshot
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import com.stevesoltys.seedvault.transport.D2D_TRANSPORT_FLAGS import com.stevesoltys.seedvault.transport.D2D_TRANSPORT_FLAGS
import com.stevesoltys.seedvault.transport.DEFAULT_TRANSPORT_FLAGS import com.stevesoltys.seedvault.transport.DEFAULT_TRANSPORT_FLAGS
import com.stevesoltys.seedvault.transport.backup.getBlobHandles
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.backends.Backend import org.calyxos.seedvault.core.backends.Backend
@ -261,8 +262,11 @@ internal class RestoreCoordinator(
val packageInfo = state.packages.next() val packageInfo = state.packages.next()
val version = state.backup.version val version = state.backup.version
if (version == 0.toByte()) return nextRestorePackageV0(state, packageInfo) if (version == 0.toByte()) return nextRestorePackageV0(state, packageInfo)
if (version == 1.toByte()) return nextRestorePackageV1(state, packageInfo)
val packageName = packageInfo.packageName val packageName = packageInfo.packageName
val repoId = state.backup.repoId ?: error("No repoId in v2 backup")
val snapshot = state.backup.snapshot ?: error("No snapshot in v2 backup")
val type = when (state.backup.packageMetadataMap[packageName]?.backupType) { val type = when (state.backup.packageMetadataMap[packageName]?.backupType) {
BackupType.KV -> { BackupType.KV -> {
val name = crypto.getNameForPackage(state.backup.salt, packageName) val name = crypto.getNameForPackage(state.backup.salt, packageName)
@ -278,8 +282,57 @@ internal class RestoreCoordinator(
} }
BackupType.FULL -> { BackupType.FULL -> {
val chunkIds = state.backup.packageMetadataMap[packageName]?.chunkIds
?: error("no metadata or chunkIds")
val blobHandles = try {
snapshot.getBlobHandles(repoId, chunkIds)
} catch (e: Exception) {
Log.e(TAG, "Error getting blob handles: ", e)
failedPackages.add(packageName)
// abort here as this is close to an assertion error
return null
}
full.initializeState(version, packageInfo, blobHandles)
state.currentPackage = packageName
TYPE_FULL_STREAM
}
null -> {
Log.i(TAG, "No backup type found for $packageName. Skipping...")
state.backup.packageMetadataMap[packageName]?.backupType?.let { s ->
Log.w(TAG, "State was ${s.name}")
}
failedPackages.add(packageName)
// don't return null and cause abort here, but try next package
return nextRestorePackage()
}
}
return RestoreDescription(packageName, type)
}
@Suppress("deprecation")
private suspend fun nextRestorePackageV1(
state: RestoreCoordinatorState,
packageInfo: PackageInfo,
): RestoreDescription? {
val packageName = packageInfo.packageName
val type = when (state.backup.packageMetadataMap[packageName]?.backupType) {
BackupType.KV -> {
val name = crypto.getNameForPackage(state.backup.salt, packageName) val name = crypto.getNameForPackage(state.backup.salt, packageName)
full.initializeState(version, state.token, name, packageInfo) kv.initializeState(
version = 1,
token = state.token,
name = name,
packageInfo = packageInfo,
autoRestorePackageInfo = state.autoRestorePackageInfo
)
state.currentPackage = packageName
TYPE_KEY_VALUE
}
BackupType.FULL -> {
val name = crypto.getNameForPackage(state.backup.salt, packageName)
full.initializeStateV1(state.token, name, packageInfo)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_FULL_STREAM TYPE_FULL_STREAM
} }
@ -315,7 +368,7 @@ internal class RestoreCoordinator(
full.hasDataForPackage(state.token, packageInfo) -> { full.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found full backup data for $packageName.") Log.i(TAG, "Found full backup data for $packageName.")
full.initializeState(0x00, state.token, "", packageInfo) full.initializeStateV0(state.token, packageInfo)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_FULL_STREAM TYPE_FULL_STREAM
} }
@ -380,7 +433,7 @@ internal class RestoreCoordinator(
*/ */
fun finishRestore() { fun finishRestore() {
Log.d(TAG, "finishRestore") Log.d(TAG, "finishRestore")
if (full.hasState()) full.finishRestore() if (full.hasState) full.finishRestore()
state = null state = null
} }

View file

@ -12,7 +12,7 @@ val restoreModule = module {
single { OutputFactory() } single { OutputFactory() }
single { Loader(get(), get()) } single { Loader(get(), get()) }
single { KVRestore(get(), get(), get(), get(), get(), get()) } single { KVRestore(get(), get(), get(), get(), get(), get()) }
single { FullRestore(get(), get(), get(), get(), get()) } single { FullRestore(get(), get(), get(), get(), get(), get()) }
single { single {
RestoreCoordinator( RestoreCoordinator(
context = androidContext(), context = androidContext(),

View file

@ -40,7 +40,6 @@ import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse import org.junit.Assert.assertFalse
import org.junit.Assert.assertNull import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.assertThrows
@ -358,7 +357,7 @@ class MetadataManagerTest {
every { clock.time() } returns time every { clock.time() } returns time
expectModifyMetadata(initialMetadata) expectModifyMetadata(initialMetadata)
manager.onPackageBackedUp(packageInfo, BackupType.FULL, size, storageOutputStream) manager.onPackageBackedUp(packageInfo, BackupType.FULL, size)
assertEquals( assertEquals(
packageMetadata.copy( packageMetadata.copy(
@ -388,7 +387,7 @@ class MetadataManagerTest {
every { settingsManager.d2dBackupsEnabled() } returns true every { settingsManager.d2dBackupsEnabled() } returns true
every { context.packageManager } returns packageManager every { context.packageManager } returns packageManager
manager.onPackageBackedUp(packageInfo, BackupType.FULL, 0L, storageOutputStream) manager.onPackageBackedUp(packageInfo, BackupType.FULL, 0L)
assertTrue(initialMetadata.d2dBackup) assertTrue(initialMetadata.d2dBackup)
verify { verify {
@ -397,35 +396,6 @@ class MetadataManagerTest {
} }
} }
@Test
fun `test onPackageBackedUp() fails to write to storage`() {
val updateTime = time + 1
val size = Random.nextLong()
val updatedMetadata = initialMetadata.copy(
time = updateTime,
packageMetadataMap = PackageMetadataMap() // otherwise this isn't copied, but referenced
)
updatedMetadata.packageMetadataMap[packageName] =
PackageMetadata(updateTime, APK_AND_DATA, BackupType.KV, size)
every { context.packageManager } returns packageManager
expectReadFromCache()
every { clock.time() } returns updateTime
every { metadataWriter.write(updatedMetadata, storageOutputStream) } throws IOException()
try {
manager.onPackageBackedUp(packageInfo, BackupType.KV, size, storageOutputStream)
fail()
} catch (e: IOException) {
// expected
}
assertEquals(0L, manager.getLastBackupTime()) // time was reverted
assertNull(manager.getPackageMetadata(packageName)) // no package metadata got added
verify { cacheInputStream.close() }
}
@Test @Test
fun `test onPackageBackedUp() with filled cache`() { fun `test onPackageBackedUp() with filled cache`() {
val cachedPackageName = getRandomString() val cachedPackageName = getRandomString()
@ -445,7 +415,7 @@ class MetadataManagerTest {
every { clock.time() } returns time every { clock.time() } returns time
expectModifyMetadata(updatedMetadata) expectModifyMetadata(updatedMetadata)
manager.onPackageBackedUp(packageInfo, BackupType.FULL, 0L, storageOutputStream) manager.onPackageBackedUp(packageInfo, BackupType.FULL, 0L)
assertEquals(time, manager.getLastBackupTime()) assertEquals(time, manager.getLastBackupTime())
assertEquals(PackageMetadata(time), manager.getPackageMetadata(cachedPackageName)) assertEquals(PackageMetadata(time), manager.getPackageMetadata(cachedPackageName))

View file

@ -22,11 +22,14 @@ import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.BackupType import com.stevesoltys.seedvault.metadata.BackupType
import com.stevesoltys.seedvault.metadata.MetadataReaderImpl import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.metadata.PackageMetadata import com.stevesoltys.seedvault.metadata.PackageMetadata
import com.stevesoltys.seedvault.transport.backup.AppBackupManager
import com.stevesoltys.seedvault.transport.backup.BackupCoordinator import com.stevesoltys.seedvault.transport.backup.BackupCoordinator
import com.stevesoltys.seedvault.transport.backup.BackupReceiver
import com.stevesoltys.seedvault.transport.backup.FullBackup import com.stevesoltys.seedvault.transport.backup.FullBackup
import com.stevesoltys.seedvault.transport.backup.InputFactory import com.stevesoltys.seedvault.transport.backup.InputFactory
import com.stevesoltys.seedvault.transport.backup.KVBackup import com.stevesoltys.seedvault.transport.backup.KVBackup
import com.stevesoltys.seedvault.transport.backup.PackageService import com.stevesoltys.seedvault.transport.backup.PackageService
import com.stevesoltys.seedvault.transport.backup.SnapshotCreator
import com.stevesoltys.seedvault.transport.backup.TestKvDbManager import com.stevesoltys.seedvault.transport.backup.TestKvDbManager
import com.stevesoltys.seedvault.transport.restore.FullRestore import com.stevesoltys.seedvault.transport.restore.FullRestore
import com.stevesoltys.seedvault.transport.restore.KVRestore import com.stevesoltys.seedvault.transport.restore.KVRestore
@ -35,13 +38,13 @@ import com.stevesoltys.seedvault.transport.restore.OutputFactory
import com.stevesoltys.seedvault.transport.restore.RestorableBackup import com.stevesoltys.seedvault.transport.restore.RestorableBackup
import com.stevesoltys.seedvault.transport.restore.RestoreCoordinator import com.stevesoltys.seedvault.transport.restore.RestoreCoordinator
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import com.stevesoltys.seedvault.worker.ApkBackup
import io.mockk.CapturingSlot import io.mockk.CapturingSlot
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.slot
import io.mockk.verify import io.mockk.verify
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.Backend import org.calyxos.seedvault.core.backends.Backend
@ -66,11 +69,14 @@ internal class CoordinatorIntegrationTest : TransportTest() {
private val notificationManager = mockk<BackupNotificationManager>() private val notificationManager = mockk<BackupNotificationManager>()
private val dbManager = TestKvDbManager() private val dbManager = TestKvDbManager()
private val backendManager: BackendManager = mockk() private val backendManager: BackendManager = mockk()
private val appBackupManager: AppBackupManager = mockk()
private val snapshotCreator: SnapshotCreator = mockk()
@Suppress("Deprecation") @Suppress("Deprecation")
private val legacyPlugin = mockk<LegacyStoragePlugin>() private val legacyPlugin = mockk<LegacyStoragePlugin>()
private val backend = mockk<Backend>() private val backend = mockk<Backend>()
private val loader = mockk<Loader>() private val loader = mockk<Loader>()
private val backupReceiver = mockk<BackupReceiver>()
private val kvBackup = KVBackup( private val kvBackup = KVBackup(
backendManager = backendManager, backendManager = backendManager,
settingsManager = settingsManager, settingsManager = settingsManager,
@ -80,17 +86,16 @@ internal class CoordinatorIntegrationTest : TransportTest() {
dbManager = dbManager, dbManager = dbManager,
) )
private val fullBackup = FullBackup( private val fullBackup = FullBackup(
backendManager = backendManager,
settingsManager = settingsManager, settingsManager = settingsManager,
nm = notificationManager, nm = notificationManager,
backupReceiver = backupReceiver,
inputFactory = inputFactory, inputFactory = inputFactory,
crypto = cryptoImpl,
) )
private val apkBackup = mockk<ApkBackup>()
private val packageService: PackageService = mockk() private val packageService: PackageService = mockk()
private val backup = BackupCoordinator( private val backup = BackupCoordinator(
context, context,
backendManager, backendManager,
appBackupManager,
kvBackup, kvBackup,
fullBackup, fullBackup,
clock, clock,
@ -109,7 +114,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
dbManager dbManager
) )
private val fullRestore = private val fullRestore =
FullRestore(backendManager, legacyPlugin, outputFactory, headerReader, cryptoImpl) FullRestore(backendManager, loader, legacyPlugin, outputFactory, headerReader, cryptoImpl)
private val restore = RestoreCoordinator( private val restore = RestoreCoordinator(
context, context,
crypto, crypto,
@ -123,21 +128,21 @@ internal class CoordinatorIntegrationTest : TransportTest() {
metadataReader metadataReader
) )
private val restorableBackup = RestorableBackup(metadata) private val restorableBackup = RestorableBackup(metadata, repoId, snapshot)
private val backupDataInput = mockk<BackupDataInput>() private val backupDataInput = mockk<BackupDataInput>()
private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true) private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true)
private val appData = ByteArray(42).apply { Random.nextBytes(this) } private val appData = ByteArray(42).apply { Random.nextBytes(this) }
private val appData2 = ByteArray(1337).apply { Random.nextBytes(this) } private val appData2 = ByteArray(1337).apply { Random.nextBytes(this) }
private val metadataOutputStream = ByteArrayOutputStream() private val metadataOutputStream = ByteArrayOutputStream()
private val packageMetadata = PackageMetadata(time = 0L)
private val key = "RestoreKey" private val key = "RestoreKey"
private val key2 = "RestoreKey2" private val key2 = "RestoreKey2"
// as we use real crypto, we need a real name for packageInfo // as we use real crypto, we need a real name for packageInfo
private val realName = cryptoImpl.getNameForPackage(salt, packageInfo.packageName) private val realName = cryptoImpl.getNameForPackage(salt, packageName)
init { init {
every { backendManager.backend } returns backend every { backendManager.backend } returns backend
every { appBackupManager.snapshotCreator } returns snapshotCreator
} }
@Test @Test
@ -162,19 +167,11 @@ internal class CoordinatorIntegrationTest : TransportTest() {
appData2.copyInto(value2.captured) // write the app data into the passed ByteArray appData2.copyInto(value2.captured) // write the app data into the passed ByteArray
appData2.size appData2.size
} }
coEvery { apkBackup.backupApkIfNecessary(packageInfo) } just Runs
coEvery {
backend.save(LegacyAppBackupFile.Metadata(token))
} returns metadataOutputStream
every {
metadataManager.onApkBackedUp(packageInfo, packageMetadata)
} just Runs
every { every {
metadataManager.onPackageBackedUp( metadataManager.onPackageBackedUp(
packageInfo = packageInfo, packageInfo = packageInfo,
type = BackupType.KV, type = BackupType.KV,
size = more((appData.size + appData2.size).toLong()), // more because DB overhead size = more((appData.size + appData2.size).toLong()), // more because DB overhead
metadataOutputStream = metadataOutputStream,
) )
} just Runs } just Runs
@ -241,7 +238,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
appData.copyInto(value.captured) // write the app data into the passed ByteArray appData.copyInto(value.captured) // write the app data into the passed ByteArray
appData.size appData.size
} }
coEvery { apkBackup.backupApkIfNecessary(packageInfo) } just Runs
every { settingsManager.getToken() } returns token every { settingsManager.getToken() } returns token
coEvery { coEvery {
backend.save(LegacyAppBackupFile.Metadata(token)) backend.save(LegacyAppBackupFile.Metadata(token))
@ -251,7 +247,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
packageInfo = packageInfo, packageInfo = packageInfo,
type = BackupType.KV, type = BackupType.KV,
size = more(size.toLong()), // more than $size, because DB overhead size = more(size.toLong()), // more than $size, because DB overhead
metadataOutputStream = metadataOutputStream,
) )
} just Runs } just Runs
@ -297,34 +292,38 @@ internal class CoordinatorIntegrationTest : TransportTest() {
@Test @Test
fun `test full backup and restore with two chunks`() = runBlocking { fun `test full backup and restore with two chunks`() = runBlocking {
metadata.packageMetadataMap[packageName] = PackageMetadata(
backupType = BackupType.FULL,
chunkIds = listOf(apkChunkId),
)
// package is of type FULL // package is of type FULL
val packageMetadata = metadata.packageMetadataMap[packageInfo.packageName]!! val packageMetadata = metadata.packageMetadataMap[packageInfo.packageName]!!
metadata.packageMetadataMap[packageInfo.packageName] = metadata.packageMetadataMap[packageInfo.packageName] =
packageMetadata.copy(backupType = BackupType.FULL) packageMetadata.copy(backupType = BackupType.FULL)
// return streams from plugin and app data // return streams from plugin and app data
val byteSlot = slot<ByteArray>()
val bOutputStream = ByteArrayOutputStream() val bOutputStream = ByteArrayOutputStream()
val bInputStream = ByteArrayInputStream(appData) val bInputStream = ByteArrayInputStream(appData)
coEvery {
backend.save(LegacyAppBackupFile.Blob(token, realName))
} returns bOutputStream
every { inputFactory.getInputStream(fileDescriptor) } returns bInputStream every { inputFactory.getInputStream(fileDescriptor) } returns bInputStream
every { backupReceiver.assertFinalized() } just Runs
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
coEvery { apkBackup.backupApkIfNecessary(packageInfo) } just Runs coEvery { backupReceiver.addBytes(capture(byteSlot)) } answers {
every { settingsManager.getToken() } returns token bOutputStream.writeBytes(byteSlot.captured)
every { metadataManager.salt } returns salt }
coEvery { every {
backend.save(LegacyAppBackupFile.Metadata(token)) snapshotCreator.onPackageBackedUp(packageInfo, BackupType.FULL, apkBackupData)
} returns metadataOutputStream } just Runs
every { metadataManager.onApkBackedUp(packageInfo, packageMetadata) } just Runs
every { every {
metadataManager.onPackageBackedUp( metadataManager.onPackageBackedUp(
packageInfo = packageInfo, packageInfo = packageInfo,
type = BackupType.FULL, type = BackupType.FULL,
size = appData.size.toLong(), size = apkBackupData.size,
metadataOutputStream = metadataOutputStream,
) )
} just Runs } just Runs
coEvery { backupReceiver.finalize() } returns apkBackupData // just some backupData
// perform backup to output stream // perform backup to output stream
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, fileDescriptor, 0)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, fileDescriptor, 0))
@ -336,9 +335,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
restore.beforeStartRestore(restorableBackup) restore.beforeStartRestore(restorableBackup)
assertEquals(TRANSPORT_OK, restore.startRestore(token, arrayOf(packageInfo))) assertEquals(TRANSPORT_OK, restore.startRestore(token, arrayOf(packageInfo)))
// finds data for full backup
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name
val restoreDescription = restore.nextRestorePackage() ?: fail() val restoreDescription = restore.nextRestorePackage() ?: fail()
assertEquals(packageInfo.packageName, restoreDescription.packageName) assertEquals(packageInfo.packageName, restoreDescription.packageName)
assertEquals(TYPE_FULL_STREAM, restoreDescription.dataType) assertEquals(TYPE_FULL_STREAM, restoreDescription.dataType)
@ -346,9 +342,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
// reverse the backup streams into restore input // reverse the backup streams into restore input
val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray()) val rInputStream = ByteArrayInputStream(bOutputStream.toByteArray())
val rOutputStream = ByteArrayOutputStream() val rOutputStream = ByteArrayOutputStream()
coEvery { coEvery { loader.loadFiles(listOf(apkBlobHandle)) } returns rInputStream
backend.load(LegacyAppBackupFile.Blob(token, name))
} returns rInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream
// restore data // restore data

View file

@ -42,6 +42,7 @@ import kotlin.random.Random
internal class BackupCoordinatorTest : BackupTest() { internal class BackupCoordinatorTest : BackupTest() {
private val backendManager = mockk<BackendManager>() private val backendManager = mockk<BackendManager>()
private val appBackupManager = mockk<AppBackupManager>()
private val kv = mockk<KVBackup>() private val kv = mockk<KVBackup>()
private val full = mockk<FullBackup>() private val full = mockk<FullBackup>()
private val apkBackup = mockk<ApkBackup>() private val apkBackup = mockk<ApkBackup>()
@ -51,6 +52,7 @@ internal class BackupCoordinatorTest : BackupTest() {
private val backup = BackupCoordinator( private val backup = BackupCoordinator(
context = context, context = context,
backendManager = backendManager, backendManager = backendManager,
appBackupManager = appBackupManager,
kv = kv, kv = kv,
full = full, full = full,
clock = clock, clock = clock,
@ -80,7 +82,7 @@ internal class BackupCoordinatorTest : BackupTest() {
fun `device initialization succeeds and delegates to plugin`() = runBlocking { fun `device initialization succeeds and delegates to plugin`() = runBlocking {
expectStartNewRestoreSet() expectStartNewRestoreSet()
every { kv.hasState() } returns false every { kv.hasState() } returns false
every { full.hasState() } returns false every { full.hasState } returns false
assertEquals(TRANSPORT_OK, backup.initializeDevice()) assertEquals(TRANSPORT_OK, backup.initializeDevice())
assertEquals(TRANSPORT_OK, backup.finishBackup()) assertEquals(TRANSPORT_OK, backup.finishBackup())
@ -107,7 +109,7 @@ internal class BackupCoordinatorTest : BackupTest() {
// finish will only be called when TRANSPORT_OK is returned, so it should throw // finish will only be called when TRANSPORT_OK is returned, so it should throw
every { kv.hasState() } returns false every { kv.hasState() } returns false
every { full.hasState() } returns false every { full.hasState } returns false
coAssertThrows(IllegalStateException::class.java) { coAssertThrows(IllegalStateException::class.java) {
backup.finishBackup() backup.finishBackup()
} }
@ -126,7 +128,7 @@ internal class BackupCoordinatorTest : BackupTest() {
// finish will only be called when TRANSPORT_OK is returned, so it should throw // finish will only be called when TRANSPORT_OK is returned, so it should throw
every { kv.hasState() } returns false every { kv.hasState() } returns false
every { full.hasState() } returns false every { full.hasState } returns false
coAssertThrows(IllegalStateException::class.java) { coAssertThrows(IllegalStateException::class.java) {
backup.finishBackup() backup.finishBackup()
} }
@ -159,7 +161,7 @@ internal class BackupCoordinatorTest : BackupTest() {
val quota = Random.nextLong() val quota = Random.nextLong()
if (isFullBackup) { if (isFullBackup) {
every { full.getQuota() } returns quota every { full.quota } returns quota
} else { } else {
every { kv.getQuota() } returns quota every { kv.getQuota() } returns quota
} }
@ -175,61 +177,30 @@ internal class BackupCoordinatorTest : BackupTest() {
assertEquals(TRANSPORT_ERROR, backup.clearBackupData(packageInfo)) assertEquals(TRANSPORT_ERROR, backup.clearBackupData(packageInfo))
} }
@Test
fun `clearing full backup data throws`() = runBlocking {
every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt
coEvery { kv.clearBackupData(packageInfo, token, salt) } just Runs
coEvery { full.clearBackupData(packageInfo, token, salt) } throws IOException()
assertEquals(TRANSPORT_ERROR, backup.clearBackupData(packageInfo))
}
@Test
fun `clearing backup data succeeds`() = runBlocking {
every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt
coEvery { kv.clearBackupData(packageInfo, token, salt) } just Runs
coEvery { full.clearBackupData(packageInfo, token, salt) } just Runs
assertEquals(TRANSPORT_OK, backup.clearBackupData(packageInfo))
every { kv.hasState() } returns false
every { full.hasState() } returns false
assertEquals(TRANSPORT_OK, backup.finishBackup())
}
@Test @Test
fun `finish backup delegates to KV plugin if it has state`() = runBlocking { fun `finish backup delegates to KV plugin if it has state`() = runBlocking {
val size = 0L val size = 0L
every { kv.hasState() } returns true every { kv.hasState() } returns true
every { full.hasState() } returns false every { full.hasState } returns false
every { kv.getCurrentPackage() } returns packageInfo every { kv.getCurrentPackage() } returns packageInfo
coEvery { kv.finishBackup() } returns TRANSPORT_OK coEvery { kv.finishBackup() } returns TRANSPORT_OK
every { settingsManager.getToken() } returns token
coEvery { backend.save(LegacyAppBackupFile.Metadata(token)) } returns metadataOutputStream
every { kv.getCurrentSize() } returns size every { kv.getCurrentSize() } returns size
every { every {
metadataManager.onPackageBackedUp( metadataManager.onPackageBackedUp(
packageInfo = packageInfo, packageInfo = packageInfo,
type = BackupType.KV, type = BackupType.KV,
size = size, size = size,
metadataOutputStream = metadataOutputStream,
) )
} just Runs } just Runs
every { metadataOutputStream.close() } just Runs
assertEquals(TRANSPORT_OK, backup.finishBackup()) assertEquals(TRANSPORT_OK, backup.finishBackup())
verify { metadataOutputStream.close() }
} }
@Test @Test
fun `finish backup does not upload @pm@ metadata, if it can't do backups`() = runBlocking { fun `finish backup does not upload @pm@ metadata, if it can't do backups`() = runBlocking {
every { kv.hasState() } returns true every { kv.hasState() } returns true
every { full.hasState() } returns false every { full.hasState } returns false
every { kv.getCurrentPackage() } returns pmPackageInfo every { kv.getCurrentPackage() } returns pmPackageInfo
every { kv.getCurrentSize() } returns 42L every { kv.getCurrentSize() } returns 42L
@ -241,29 +212,26 @@ internal class BackupCoordinatorTest : BackupTest() {
@Test @Test
fun `finish backup delegates to full plugin if it has state`() = runBlocking { fun `finish backup delegates to full plugin if it has state`() = runBlocking {
val result = Random.nextInt() val snapshotCreator: SnapshotCreator = mockk()
val size: Long? = null val size: Long = 2345
every { kv.hasState() } returns false every { kv.hasState() } returns false
every { full.hasState() } returns true every { full.hasState } returns true
every { full.getCurrentPackage() } returns packageInfo every { full.currentPackageInfo } returns packageInfo
every { full.finishBackup() } returns result coEvery { full.finishBackup() } returns apkBackupData
every { settingsManager.getToken() } returns token every { appBackupManager.snapshotCreator } returns snapshotCreator
coEvery { backend.save(LegacyAppBackupFile.Metadata(token)) } returns metadataOutputStream every {
every { full.getCurrentSize() } returns size snapshotCreator.onPackageBackedUp(packageInfo, BackupType.FULL, apkBackupData)
} just Runs
every { every {
metadataManager.onPackageBackedUp( metadataManager.onPackageBackedUp(
packageInfo = packageInfo, packageInfo = packageInfo,
type = BackupType.FULL, type = BackupType.FULL,
size = size, size = apkBackupData.size,
metadataOutputStream = metadataOutputStream,
) )
} just Runs } just Runs
every { metadataOutputStream.close() } just Runs
assertEquals(result, backup.finishBackup()) assertEquals(TRANSPORT_OK, backup.finishBackup())
verify { metadataOutputStream.close() }
} }
@Test @Test
@ -271,7 +239,7 @@ internal class BackupCoordinatorTest : BackupTest() {
every { settingsManager.getToken() } returns token every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt every { metadataManager.salt } returns salt
coEvery { coEvery {
full.performFullBackup(packageInfo, fileDescriptor, 0, token, salt) full.performFullBackup(packageInfo, fileDescriptor, 0)
} returns TRANSPORT_OK } returns TRANSPORT_OK
coEvery { apkBackup.backupApkIfNecessary(packageInfo) } just Runs coEvery { apkBackup.backupApkIfNecessary(packageInfo) } just Runs
@ -283,14 +251,14 @@ internal class BackupCoordinatorTest : BackupTest() {
every { settingsManager.getToken() } returns token every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt every { metadataManager.salt } returns salt
coEvery { coEvery {
full.performFullBackup(packageInfo, fileDescriptor, 0, token, salt) full.performFullBackup(packageInfo, fileDescriptor, 0)
} returns TRANSPORT_OK } returns TRANSPORT_OK
expectApkBackupAndMetadataWrite() expectApkBackupAndMetadataWrite()
every { full.getQuota() } returns DEFAULT_QUOTA_FULL_BACKUP every { full.quota } returns DEFAULT_QUOTA_FULL_BACKUP
every { every {
full.checkFullBackupSize(DEFAULT_QUOTA_FULL_BACKUP + 1) full.checkFullBackupSize(DEFAULT_QUOTA_FULL_BACKUP + 1)
} returns TRANSPORT_QUOTA_EXCEEDED } returns TRANSPORT_QUOTA_EXCEEDED
every { full.getCurrentPackage() } returns packageInfo every { full.currentPackageInfo } returns packageInfo
every { every {
metadataManager.onPackageBackupError( metadataManager.onPackageBackupError(
packageInfo, packageInfo,
@ -299,7 +267,7 @@ internal class BackupCoordinatorTest : BackupTest() {
BackupType.FULL BackupType.FULL
) )
} just Runs } just Runs
coEvery { full.cancelFullBackup(token, metadata.salt, false) } just Runs coEvery { full.cancelFullBackup() } just Runs
every { backendManager.backendProperties } returns safProperties every { backendManager.backendProperties } returns safProperties
every { settingsManager.useMeteredNetwork } returns false every { settingsManager.useMeteredNetwork } returns false
every { metadataOutputStream.close() } just Runs every { metadataOutputStream.close() } just Runs
@ -335,12 +303,12 @@ internal class BackupCoordinatorTest : BackupTest() {
every { settingsManager.getToken() } returns token every { settingsManager.getToken() } returns token
every { metadataManager.salt } returns salt every { metadataManager.salt } returns salt
coEvery { coEvery {
full.performFullBackup(packageInfo, fileDescriptor, 0, token, salt) full.performFullBackup(packageInfo, fileDescriptor, 0)
} returns TRANSPORT_OK } returns TRANSPORT_OK
expectApkBackupAndMetadataWrite() expectApkBackupAndMetadataWrite()
every { full.getQuota() } returns DEFAULT_QUOTA_FULL_BACKUP every { full.quota } returns DEFAULT_QUOTA_FULL_BACKUP
every { full.checkFullBackupSize(0) } returns TRANSPORT_PACKAGE_REJECTED every { full.checkFullBackupSize(0) } returns TRANSPORT_PACKAGE_REJECTED
every { full.getCurrentPackage() } returns packageInfo every { full.currentPackageInfo } returns packageInfo
every { every {
metadataManager.onPackageBackupError( metadataManager.onPackageBackupError(
packageInfo, packageInfo,
@ -349,7 +317,7 @@ internal class BackupCoordinatorTest : BackupTest() {
BackupType.FULL BackupType.FULL
) )
} just Runs } just Runs
coEvery { full.cancelFullBackup(token, metadata.salt, false) } just Runs coEvery { full.cancelFullBackup() } just Runs
every { backendManager.backendProperties } returns safProperties every { backendManager.backendProperties } returns safProperties
every { settingsManager.useMeteredNetwork } returns false every { settingsManager.useMeteredNetwork } returns false
every { metadataOutputStream.close() } just Runs every { metadataOutputStream.close() } just Runs

View file

@ -9,50 +9,41 @@ import android.app.backup.BackupTransport.TRANSPORT_ERROR
import android.app.backup.BackupTransport.TRANSPORT_OK import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.Backend
import org.calyxos.seedvault.core.backends.LegacyAppBackupFile
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.io.FileInputStream import java.io.FileInputStream
import java.io.IOException import java.io.IOException
import kotlin.random.Random import kotlin.random.Random
internal class FullBackupTest : BackupTest() { internal class FullBackupTest : BackupTest() {
private val backendManager: BackendManager = mockk() private val backupReceiver = mockk<BackupReceiver>()
private val backend = mockk<Backend>()
private val notificationManager = mockk<BackupNotificationManager>() private val notificationManager = mockk<BackupNotificationManager>()
private val backup = FullBackup( private val backup = FullBackup(
backendManager = backendManager,
settingsManager = settingsManager, settingsManager = settingsManager,
nm = notificationManager, nm = notificationManager,
backupReceiver = backupReceiver,
inputFactory = inputFactory, inputFactory = inputFactory,
crypto = crypto,
) )
private val bytes = ByteArray(23).apply { Random.nextBytes(this) } private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
private val inputStream = mockk<FileInputStream>() private val inputStream = mockk<FileInputStream>()
private val ad = getADForFull(VERSION, packageInfo.packageName) private val backupData = apkBackupData
init {
every { backendManager.backend } returns backend
}
@Test @Test
fun `has no initial state`() { fun `has no initial state`() {
assertFalse(backup.hasState()) assertFalse(backup.hasState)
} }
@Test @Test
@ -99,254 +90,229 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `performFullBackup runs ok`() = runBlocking { fun `performFullBackup runs ok`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData first call over quota`() = runBlocking { fun `sendBackupData first call over quota`() = runBlocking {
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
val numBytes = (quota + 1).toInt() val numBytes = (quota + 1).toInt()
expectSendData(numBytes) expectSendData(numBytes)
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
assertEquals(TRANSPORT_QUOTA_EXCEEDED, backup.sendBackupData(numBytes))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_QUOTA_EXCEEDED, backup.sendBackupData(numBytes))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData second call over quota`() = runBlocking { fun `sendBackupData subsequent calls over quota`() = runBlocking {
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
val numBytes1 = quota.toInt()
expectSendData(numBytes1) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
val numBytes2 = 1 assertTrue(backup.hasState)
expectSendData(numBytes2)
// split up sending data in smaller chunks, so we don't run out of heap space
var sendResult: Int = TRANSPORT_ERROR
val numBytes = (quota / 1024).toInt()
for (i in 0..1024) {
expectSendData(numBytes)
sendResult = backup.sendBackupData(numBytes)
assertTrue(backup.hasState)
if (sendResult == TRANSPORT_QUOTA_EXCEEDED) break
}
assertEquals(TRANSPORT_QUOTA_EXCEEDED, sendResult)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) // in reality, this may not call finishBackup(), but cancelBackup()
assertTrue(backup.hasState()) assertEquals(backupData, backup.finishBackup())
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes1)) assertFalse(backup.hasState)
assertTrue(backup.hasState())
assertEquals(TRANSPORT_QUOTA_EXCEEDED, backup.sendBackupData(numBytes2))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData throws exception when reading from InputStream`() = runBlocking { fun `sendBackupData throws exception when reading from InputStream`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { crypto.newEncryptingStreamV1(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } throws IOException() every { inputStream.read(any(), any(), bytes.size) } throws IOException()
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData throws exception when getting outputStream`() = runBlocking { fun `sendBackupData throws exception when sending data`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name every { inputStream.read(any(), 0, bytes.size) } returns bytes.size
coEvery { backend.save(handle) } throws IOException() coEvery { backupReceiver.addBytes(any()) } throws IOException()
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData throws exception when writing header`() = runBlocking { fun `sendBackupData throws exception when finalizing`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name expectSendData(bytes.size)
coEvery { backend.save(handle) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream assertEquals(TRANSPORT_OK, backup.sendBackupData(bytes.size))
every { outputStream.write(ByteArray(1) { VERSION }) } throws IOException() assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } throws IOException()
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertThrows<IOException> {
assertTrue(backup.hasState()) backup.finishBackup()
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
assertFalse(backup.hasState)
@Test verify { data.close() }
fun `sendBackupData throws exception when writing encrypted data to OutputStream`() =
runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { settingsManager.isQuotaUnlimited() } returns false
every { crypto.newEncryptingStreamV1(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
every { encryptedOutputStream.write(any<ByteArray>()) } throws IOException()
expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `sendBackupData runs ok`() = runBlocking { fun `sendBackupData runs ok`() = runBlocking {
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
val numBytes1 = (quota / 2).toInt()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
val numBytes1 = 2342
expectSendData(numBytes1) expectSendData(numBytes1)
val numBytes2 = (quota / 2).toInt() assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes1))
assertTrue(backup.hasState)
val numBytes2 = 4223
expectSendData(numBytes2) expectSendData(numBytes2)
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes2))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes1))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes2))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
}
@Test
fun `clearBackupData delegates to plugin`() = runBlocking {
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name
coEvery { backend.remove(handle) } just Runs
backup.clearBackupData(packageInfo, token, salt)
} }
@Test @Test
fun `cancel full backup runs ok`() = runBlocking { fun `cancel full backup runs ok`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
expectClearState() expectClearState()
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name
coEvery { backend.remove(handle) } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) backup.cancelFullBackup()
assertTrue(backup.hasState()) assertFalse(backup.hasState)
backup.cancelFullBackup(token, salt, false)
assertFalse(backup.hasState())
} }
@Test @Test
fun `cancel full backup ignores exception when calling plugin`() = runBlocking { fun `cancel full backup throws exception when finalizing`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } throws IOException()
expectClearState() expectClearState()
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name
coEvery { backend.remove(handle) } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) backup.cancelFullBackup()
assertTrue(backup.hasState()) assertFalse(backup.hasState)
backup.cancelFullBackup(token, salt, false)
assertFalse(backup.hasState())
}
@Test
fun `clearState throws exception when flushing OutputStream`() = runBlocking {
every { settingsManager.isQuotaUnlimited() } returns false
every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
val numBytes = 42
expectSendData(numBytes)
every { encryptedOutputStream.flush() } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes))
assertEquals(TRANSPORT_ERROR, backup.finishBackup())
assertFalse(backup.hasState())
}
@Test
fun `clearState ignores exception when closing OutputStream`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { outputStream.flush() } just Runs
every { outputStream.close() } throws IOException()
every { inputStream.close() } just Runs
every { data.close() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `clearState ignores exception when closing InputStream`() = runBlocking { fun `clearState ignores exception when closing InputStream`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
every { outputStream.flush() } just Runs every { outputStream.flush() } just Runs
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { inputStream.close() } throws IOException() every { inputStream.close() } throws IOException()
every { data.close() } just Runs every { data.close() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
} }
@Test @Test
fun `clearState ignores exception when closing ParcelFileDescriptor`() = runBlocking { fun `clearState ignores exception when closing ParcelFileDescriptor`() = runBlocking {
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream() every { backupReceiver.assertFinalized() } just Runs
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0))
assertTrue(backup.hasState)
coEvery { backupReceiver.finalize() } returns backupData
every { outputStream.flush() } just Runs every { outputStream.flush() } just Runs
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
every { data.close() } throws IOException() every { data.close() } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data, 0, token, salt)) assertEquals(backupData, backup.finishBackup())
assertTrue(backup.hasState()) assertFalse(backup.hasState)
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
}
private fun expectInitializeOutputStream() {
every { crypto.getNameForPackage(salt, packageInfo.packageName) } returns name
coEvery {
backend.save(LegacyAppBackupFile.Blob(token, name))
} returns outputStream
every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
} }
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) { private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {
every { inputStream.read(any(), any(), numBytes) } returns readBytes every { inputStream.read(any(), any(), numBytes) } returns readBytes
every { crypto.newEncryptingStreamV1(outputStream, ad) } returns encryptedOutputStream coEvery { backupReceiver.addBytes(any()) } just Runs
every { encryptedOutputStream.write(any<ByteArray>()) } just Runs
} }
private fun expectClearState() { private fun expectClearState() {

View file

@ -14,16 +14,14 @@ import com.stevesoltys.seedvault.backend.LegacyStoragePlugin
import com.stevesoltys.seedvault.coAssertThrows import com.stevesoltys.seedvault.coAssertThrows
import com.stevesoltys.seedvault.getRandomByteArray import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.CapturingSlot import io.mockk.CapturingSlot
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.Backend import org.calyxos.seedvault.core.backends.Backend
import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertArrayEquals
@ -41,9 +39,11 @@ internal class FullRestoreTest : RestoreTest() {
private val backendManager: BackendManager = mockk() private val backendManager: BackendManager = mockk()
private val backend = mockk<Backend>() private val backend = mockk<Backend>()
private val loader = mockk<Loader>()
private val legacyPlugin = mockk<LegacyStoragePlugin>() private val legacyPlugin = mockk<LegacyStoragePlugin>()
private val restore = FullRestore( private val restore = FullRestore(
backendManager = backendManager, backendManager = backendManager,
loader = loader,
legacyPlugin = legacyPlugin, legacyPlugin = legacyPlugin,
outputFactory = outputFactory, outputFactory = outputFactory,
headerReader = headerReader, headerReader = headerReader,
@ -52,7 +52,7 @@ internal class FullRestoreTest : RestoreTest() {
private val encrypted = getRandomByteArray() private val encrypted = getRandomByteArray()
private val outputStream = ByteArrayOutputStream() private val outputStream = ByteArrayOutputStream()
private val ad = getADForFull(VERSION, packageInfo.packageName) private val blobHandles = listOf(apkBlobHandle)
init { init {
every { backendManager.backend } returns backend every { backendManager.backend } returns backend
@ -60,7 +60,7 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `has no initial state`() { fun `has no initial state`() {
assertFalse(restore.hasState()) assertFalse(restore.hasState)
} }
@Test @Test
@ -73,14 +73,14 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `initializing state leaves a state`() { fun `initializing state leaves a state`() {
assertFalse(restore.hasState()) assertFalse(restore.hasState)
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
assertTrue(restore.hasState()) assertTrue(restore.hasState)
} }
@Test @Test
fun `getting chunks without initializing state throws`() { fun `getting chunks without initializing state throws`() {
assertFalse(restore.hasState()) assertFalse(restore.hasState)
coAssertThrows(IllegalStateException::class.java) { coAssertThrows(IllegalStateException::class.java) {
restore.getNextFullRestoreDataChunk(fileDescriptor) restore.getNextFullRestoreDataChunk(fileDescriptor)
} }
@ -88,138 +88,52 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `getting InputStream for package when getting first chunk throws`() = runBlocking { fun `getting InputStream for package when getting first chunk throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
coEvery { backend.load(handle) } throws IOException() coEvery { loader.loadFiles(blobHandles) } throws IOException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
assertEquals( assertEquals(
TRANSPORT_PACKAGE_REJECTED, TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor) restore.getNextFullRestoreDataChunk(fileDescriptor)
) )
verify { fileDescriptor.close() }
} }
@Test @Test
fun `reading version header when getting first chunk throws`() = runBlocking { fun `reading from stream throws general security exception`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
coEvery { backend.load(handle) } returns inputStream coEvery { loader.loadFiles(blobHandles) } throws GeneralSecurityException()
every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `reading unsupported version when getting first chunk`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every {
headerReader.readVersion(inputStream, VERSION)
} throws UnsupportedVersionException(unsupportedVersion)
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStreamV1(inputStream, ad) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `getting decrypted stream when getting first chunk throws general security exception`() =
runBlocking {
restore.initializeState(VERSION, token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every {
crypto.newDecryptingStreamV1(inputStream, ad)
} throws GeneralSecurityException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
verify { fileDescriptor.close() }
} }
@Test @Test
fun `full chunk gets decrypted`() = runBlocking { fun `full chunk gets decrypted`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
initInputStream() coEvery { loader.loadFiles(blobHandles) } returns inputStream
readAndEncryptInputStream(encrypted) readInputStream(encrypted)
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray()) assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore() restore.finishRestore()
assertFalse(restore.hasState()) assertFalse(restore.hasState)
} }
@Test @Test
@Suppress("deprecation") fun `larger data gets decrypted and then return no more data`() = runBlocking {
fun `full chunk gets decrypted from version 0`() = runBlocking {
restore.initializeState(0.toByte(), token, name, packageInfo)
coEvery { legacyPlugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
} returns VersionHeader(0.toByte(), packageInfo.packageName)
every { crypto.decryptSegment(inputStream) } returns encrypted
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState())
}
@Test
fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(Byte.MAX_VALUE, token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every {
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
} throws GeneralSecurityException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
restore.abortFullRestore()
assertFalse(restore.hasState())
}
@Test
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1) val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
val decryptedInputStream = ByteArrayInputStream(encryptedBytes) val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
coEvery { backend.load(handle) } returns inputStream coEvery { loader.loadFiles(blobHandles) } returns decryptedInputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStreamV1(inputStream, ad) } returns decryptedInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
@ -231,38 +145,32 @@ internal class FullRestoreTest : RestoreTest() {
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encryptedBytes, outputStream.toByteArray()) assertArrayEquals(encryptedBytes, outputStream.toByteArray())
restore.finishRestore() restore.finishRestore()
assertFalse(restore.hasState()) assertFalse(restore.hasState)
} }
@Test @Test
fun `aborting full restore closes stream, resets state`() = runBlocking { fun `aborting full restore closes stream, resets state`() = runBlocking {
restore.initializeState(VERSION, token, name, packageInfo) restore.initializeState(VERSION, packageInfo, blobHandles)
initInputStream() coEvery { loader.loadFiles(blobHandles) } returns inputStream
readAndEncryptInputStream(encrypted) readInputStream(encrypted)
restore.getNextFullRestoreDataChunk(fileDescriptor) restore.getNextFullRestoreDataChunk(fileDescriptor)
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
assertEquals(TRANSPORT_OK, restore.abortFullRestore()) assertEquals(TRANSPORT_OK, restore.abortFullRestore())
assertFalse(restore.hasState()) assertFalse(restore.hasState)
} }
private fun initInputStream() { private fun readInputStream(encryptedBytes: ByteArray) {
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStreamV1(inputStream, ad) } returns decryptedInputStream
}
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
val slot = CapturingSlot<ByteArray>() val slot = CapturingSlot<ByteArray>()
every { decryptedInputStream.read(capture(slot)) } answers { every { inputStream.read(capture(slot)) } answers {
encryptedBytes.copyInto(slot.captured) encryptedBytes.copyInto(slot.captured)
encryptedBytes.size encryptedBytes.size
} }
every { decryptedInputStream.close() } just Runs every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
} }

View file

@ -0,0 +1,254 @@
/*
* SPDX-FileCopyrightText: 2024 The Calyx Institute
* SPDX-License-Identifier: Apache-2.0
*/
package com.stevesoltys.seedvault.transport.restore
import android.app.backup.BackupTransport.NO_MORE_DATA
import android.app.backup.BackupTransport.TRANSPORT_ERROR
import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.backend.LegacyStoragePlugin
import com.stevesoltys.seedvault.coAssertThrows
import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.CapturingSlot
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.Backend
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.security.GeneralSecurityException
import kotlin.random.Random
@Suppress("DEPRECATION")
internal class FullRestoreV1Test : RestoreTest() {
private val backendManager: BackendManager = mockk()
private val backend = mockk<Backend>()
private val legacyPlugin = mockk<LegacyStoragePlugin>()
private val restore = FullRestore(
backendManager = backendManager,
loader = mockk(),
legacyPlugin = legacyPlugin,
outputFactory = outputFactory,
headerReader = headerReader,
crypto = crypto,
)
private val encrypted = getRandomByteArray()
private val outputStream = ByteArrayOutputStream()
private val ad = getADForFull(1, packageInfo.packageName)
init {
every { backendManager.backend } returns backend
}
@Test
fun `has no initial state`() {
assertFalse(restore.hasState)
}
@Test
@Suppress("deprecation")
fun `v0 hasDataForPackage() delegates to plugin`() = runBlocking {
val result = Random.nextBoolean()
coEvery { legacyPlugin.hasDataForFullPackage(token, packageInfo) } returns result
assertEquals(result, restore.hasDataForPackage(token, packageInfo))
}
@Test
fun `initializing state leaves a state`() {
assertFalse(restore.hasState)
restore.initializeStateV1(token, name, packageInfo)
assertTrue(restore.hasState)
}
@Test
fun `getting chunks without initializing state throws`() {
assertFalse(restore.hasState)
coAssertThrows(IllegalStateException::class.java) {
restore.getNextFullRestoreDataChunk(fileDescriptor)
}
}
@Test
fun `getting InputStream for package when getting first chunk throws`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `reading version header when getting first chunk throws`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, 1) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `reading unsupported version when getting first chunk`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every {
headerReader.readVersion(inputStream, 1)
} throws UnsupportedVersionException(unsupportedVersion)
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, 1) } returns 1
every { crypto.newDecryptingStreamV1(inputStream, ad) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `getting decrypted stream when getting first chunk throws general security exception`() =
runBlocking {
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, 1) } returns 1
every {
crypto.newDecryptingStreamV1(inputStream, ad)
} throws GeneralSecurityException()
every { fileDescriptor.close() } just Runs
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
}
@Test
fun `full chunk gets decrypted`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
initInputStream()
readAndEncryptInputStream(encrypted)
every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState)
}
@Test
@Suppress("deprecation")
fun `full chunk gets decrypted from version 0`() = runBlocking {
restore.initializeStateV0(token, packageInfo)
coEvery { legacyPlugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
} returns VersionHeader(0.toByte(), packageInfo.packageName)
every { crypto.decryptSegment(inputStream) } returns encrypted
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState)
}
@Test
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
restore.initializeStateV1(token, name, packageInfo)
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, 1) } returns 1
every { crypto.newDecryptingStreamV1(inputStream, ad) } returns decryptedInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(1, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encryptedBytes, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState)
}
@Test
fun `aborting full restore closes stream, resets state`() = runBlocking {
restore.initializeStateV1(token, name, packageInfo)
initInputStream()
readAndEncryptInputStream(encrypted)
restore.getNextFullRestoreDataChunk(fileDescriptor)
every { inputStream.close() } just Runs
assertEquals(TRANSPORT_OK, restore.abortFullRestore())
assertFalse(restore.hasState)
}
private fun initInputStream() {
coEvery { backend.load(handle) } returns inputStream
every { headerReader.readVersion(inputStream, 1) } returns 1
every { crypto.newDecryptingStreamV1(inputStream, ad) } returns decryptedInputStream
}
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
val slot = CapturingSlot<ByteArray>()
every { decryptedInputStream.read(capture(slot)) } answers {
encryptedBytes.copyInto(slot.captured)
encryptedBytes.size
}
every { decryptedInputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
}
}

View file

@ -67,7 +67,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
metadataReader = metadataReader, metadataReader = metadataReader,
) )
private val restorableBackup = RestorableBackup(metadata) private val restorableBackup = RestorableBackup(metadata, repoId, snapshot)
private val inputStream = mockk<InputStream>() private val inputStream = mockk<InputStream>()
private val safStorage: SafProperties = mockk() private val safStorage: SafProperties = mockk()
private val packageInfo2 = PackageInfo().apply { packageName = "org.example2" } private val packageInfo2 = PackageInfo().apply { packageName = "org.example2" }
@ -80,8 +80,10 @@ internal class RestoreCoordinatorTest : TransportTest() {
private val storageName = getRandomString() private val storageName = getRandomString()
init { init {
metadata.packageMetadataMap[packageInfo2.packageName] = metadata.packageMetadataMap[packageInfo2.packageName] = PackageMetadata(
PackageMetadata(backupType = BackupType.FULL) backupType = BackupType.FULL,
chunkIds = listOf(apkChunkId),
)
mockkStatic("com.stevesoltys.seedvault.backend.BackendExtKt") mockkStatic("com.stevesoltys.seedvault.backend.BackendExtKt")
every { backendManager.backend } returns backend every { backendManager.backend } returns backend
@ -200,7 +202,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.beforeStartRestore(restorableBackup) restore.beforeStartRestore(restorableBackup)
assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray)) assertEquals(TRANSPORT_OK, restore.startRestore(token, packageInfoArray))
every { full.hasState() } returns false every { full.hasState } returns false
restore.finishRestore() restore.finishRestore()
restore.beforeStartRestore(restorableBackup) restore.beforeStartRestore(restorableBackup)
@ -306,7 +308,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
coEvery { kv.hasDataForPackage(token, packageInfo) } returns false coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
coEvery { full.hasDataForPackage(token, packageInfo) } returns true coEvery { full.hasDataForPackage(token, packageInfo) } returns true
every { full.initializeState(0x00, token, "", packageInfo) } just Runs every { full.initializeStateV0(token, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM) val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -319,8 +321,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.beforeStartRestore(restorableBackup) restore.beforeStartRestore(restorableBackup)
restore.startRestore(token, packageInfoArray2) restore.startRestore(token, packageInfoArray2)
every { crypto.getNameForPackage(metadata.salt, packageInfo2.packageName) } returns name2 every { full.initializeState(VERSION, packageInfo2, listOf(apkBlobHandle)) } just Runs
every { full.initializeState(VERSION, token, name2, packageInfo2) } just Runs
val expected = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM) val expected = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -339,8 +340,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
every { crypto.getNameForPackage(metadata.salt, packageInfo2.packageName) } returns name2 every { full.initializeState(VERSION, packageInfo2, listOf(apkBlobHandle)) } just Runs
every { full.initializeState(VERSION, token, name2, packageInfo2) } just Runs
val expected2 = val expected2 =
RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM) RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
@ -364,7 +364,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false
coEvery { full.hasDataForPackage(token, packageInfo2) } returns true coEvery { full.hasDataForPackage(token, packageInfo2) } returns true
every { full.initializeState(0.toByte(), token, "", packageInfo2) } just Runs every { full.initializeStateV0(token, packageInfo2) } just Runs
val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM) val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
assertEquals(expected2, restore.nextRestorePackage()) assertEquals(expected2, restore.nextRestorePackage())
@ -430,7 +430,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
fun `finishRestore() delegates to Full if it has state`() { fun `finishRestore() delegates to Full if it has state`() {
val hasState = Random.nextBoolean() val hasState = Random.nextBoolean()
every { full.hasState() } returns hasState every { full.hasState } returns hasState
if (hasState) { if (hasState) {
every { full.finishRestore() } just Runs every { full.finishRestore() } just Runs
} }

View file

@ -69,7 +69,7 @@ internal class RestoreV0IntegrationTest : TransportTest() {
dbManager = dbManager, dbManager = dbManager,
) )
private val fullRestore = private val fullRestore =
FullRestore(backendManager, legacyPlugin, outputFactory, headerReader, cryptoImpl) FullRestore(backendManager, loader, legacyPlugin, outputFactory, headerReader, cryptoImpl)
private val restore = RestoreCoordinator( private val restore = RestoreCoordinator(
context = context, context = context,
crypto = crypto, crypto = crypto,