Fully implement SnapshotManager

which manages interactions with snapshots, such as loading, saving and removing them.
It also keeps a reference to the latestSnapshot that holds important re-usable data.
This commit is contained in:
Torsten Grote 2024-09-12 17:41:15 -03:00
parent 952cdec55d
commit bfa17fa5ec
No known key found for this signature in database
GPG key ID: 3E5F77D92CF891FF
6 changed files with 269 additions and 58 deletions

View file

@ -12,12 +12,18 @@ import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.proto.Snapshot import com.stevesoltys.seedvault.proto.Snapshot
import com.stevesoltys.seedvault.transport.restore.Loader import com.stevesoltys.seedvault.transport.restore.Loader
import io.github.oshai.kotlinlogging.KotlinLogging import io.github.oshai.kotlinlogging.KotlinLogging
import okio.Buffer
import okio.buffer
import okio.sink
import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
/**
* Manages interactions with snapshots, such as loading, saving and removing them.
* Also keeps a reference to the [latestSnapshot] that holds important re-usable data.
*/
internal class SnapshotManager( internal class SnapshotManager(
private val snapshotFolder: File,
private val crypto: Crypto, private val crypto: Crypto,
private val loader: Loader, private val loader: Loader,
private val backendManager: BackendManager, private val backendManager: BackendManager,
@ -32,12 +38,20 @@ internal class SnapshotManager(
var latestSnapshot: Snapshot? = null var latestSnapshot: Snapshot? = null
private set private set
/**
* Call this before starting a backup run with the [handles] of snapshots
* currently available on the backend.
*/
suspend fun onSnapshotsLoaded(handles: List<AppBackupFileType.Snapshot>): List<Snapshot> { suspend fun onSnapshotsLoaded(handles: List<AppBackupFileType.Snapshot>): List<Snapshot> {
return handles.map { snapshotHandle -> return handles.mapNotNull { snapshotHandle ->
// TODO set up local snapshot cache, so we don't need to download those all the time val snapshot = try {
// TODO is it a fatal error when one snapshot is corrupted or couldn't get loaded? loadSnapshot(snapshotHandle)
val snapshot = loader.loadFile(snapshotHandle).use { inputStream -> } catch (e: Exception) {
Snapshot.parseFrom(inputStream) // This isn't ideal, but the show must go on and we take the snapshots we can get.
// After the first load, a snapshot will get cached, so we are not hitting backend.
// TODO use a re-trying backend for snapshot loading
log.error(e) { "Error loading snapshot: $snapshotHandle" }
return@mapNotNull null
} }
// update latest snapshot if this one is more recent // update latest snapshot if this one is more recent
if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot
@ -45,24 +59,61 @@ internal class SnapshotManager(
} }
} }
/**
* Saves the given [snapshot] to the backend and local cache.
*
* @throws IOException or others if saving fails.
*/
@Throws(IOException::class)
suspend fun saveSnapshot(snapshot: Snapshot) { suspend fun saveSnapshot(snapshot: Snapshot) {
val buffer = Buffer() val byteStream = ByteArrayOutputStream()
val bufferStream = buffer.outputStream() byteStream.write(VERSION.toInt())
bufferStream.write(VERSION.toInt()) crypto.newEncryptingStream(byteStream, crypto.getAdForVersion()).use { cryptoStream ->
crypto.newEncryptingStream(bufferStream, crypto.getAdForVersion()).use { cryptoStream ->
ZstdOutputStream(cryptoStream).use { zstdOutputStream -> ZstdOutputStream(cryptoStream).use { zstdOutputStream ->
snapshot.writeTo(zstdOutputStream) snapshot.writeTo(zstdOutputStream)
} }
} }
val sha256ByteString = buffer.sha256() val bytes = byteStream.toByteArray()
val handle = AppBackupFileType.Snapshot(crypto.repoId, sha256ByteString.hex()) val sha256 = crypto.sha256(bytes).toHexString()
// TODO exception handling val snapshotHandle = AppBackupFileType.Snapshot(crypto.repoId, sha256)
backendManager.backend.save(handle).use { outputStream -> backendManager.backend.save(snapshotHandle).use { outputStream ->
outputStream.sink().buffer().apply { outputStream.write(bytes)
writeAll(buffer) }
flush() // needs flushing // save to local cache while at it
try {
if (!snapshotFolder.isDirectory) snapshotFolder.mkdirs()
File(snapshotFolder, snapshotHandle.name).outputStream().use { outputStream ->
outputStream.write(bytes)
} }
} catch (e: Exception) { // we'll let this one pass
log.error(e) { "Error saving snapshot ${snapshotHandle.hash} to cache: " }
} }
} }
/**
* Removes the snapshot referenced by the given [snapshotHandle] from the backend
* and local cache.
*/
@Throws(IOException::class)
suspend fun removeSnapshot(snapshotHandle: AppBackupFileType.Snapshot) {
backendManager.backend.remove(snapshotHandle)
// remove from cache as well
File(snapshotFolder, snapshotHandle.name).delete()
}
/**
* Loads and parses the snapshot referenced by the given [snapshotHandle].
* If a locally cached version exists, the backend will not be hit.
*/
@Throws(IOException::class)
suspend fun loadSnapshot(snapshotHandle: AppBackupFileType.Snapshot): Snapshot {
val file = File(snapshotFolder, snapshotHandle.name)
val inputStream = if (file.isFile) {
loader.loadFile(file, snapshotHandle.hash)
} else {
loader.loadFile(snapshotHandle, file)
}
return inputStream.use { Snapshot.parseFrom(it) }
}
} }

View file

@ -48,21 +48,25 @@ internal class AppBackupManager(
blobCache.populateCache(blobInfos, snapshots) blobCache.populateCache(blobInfos, snapshots)
} }
suspend fun afterBackupFinished(success: Boolean) { suspend fun afterBackupFinished(success: Boolean): Boolean {
log.info { "After backup finished. Success: $success" } log.info { "After backup finished. Success: $success" }
// free up memory by clearing blobs cache // free up memory by clearing blobs cache
blobCache.clear() blobCache.clear()
var result = false
try { try {
if (success) { if (success) {
val snapshot = val snapshot =
snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator") snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator")
keepTrying { keepTrying { // saving this is so important, we even keep trying
snapshotManager.saveSnapshot(snapshot) snapshotManager.saveSnapshot(snapshot)
} }
settingsManager.token = snapshot.token settingsManager.token = snapshot.token
// after snapshot was written, we can clear local cache as its info is in snapshot // after snapshot was written, we can clear local cache as its info is in snapshot
blobCache.clearLocalCache() blobCache.clearLocalCache()
} }
result = true
} catch (e: Exception) {
log.error(e) { "Error finishing backup" }
} finally { } finally {
snapshotCreator = null snapshotCreator = null
} }

View file

@ -8,13 +8,17 @@ package com.stevesoltys.seedvault.transport.backup
import com.stevesoltys.seedvault.transport.SnapshotManager import com.stevesoltys.seedvault.transport.SnapshotManager
import org.koin.android.ext.koin.androidContext import org.koin.android.ext.koin.androidContext
import org.koin.dsl.module import org.koin.dsl.module
import java.io.File
val backupModule = module { val backupModule = module {
single { BackupInitializer(get()) } single { BackupInitializer(get()) }
single { BackupReceiver(get(), get(), get()) } single { BackupReceiver(get(), get(), get()) }
single { BlobCache(androidContext()) } single { BlobCache(androidContext()) }
single { BlobCreator(get(), get()) } single { BlobCreator(get(), get()) }
single { SnapshotManager(get(), get(), get()) } single {
val snapshotFolder = File(androidContext().filesDir, "snapshots")
SnapshotManager(snapshotFolder, get(), get(), get())
}
single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) } single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) }
single { InputFactory() } single { InputFactory() }
single { single {

View file

@ -5,15 +5,18 @@
package com.stevesoltys.seedvault.transport.restore package com.stevesoltys.seedvault.transport.restore
import com.android.internal.R.attr.handle
import com.github.luben.zstd.ZstdInputStream import com.github.luben.zstd.ZstdInputStream
import com.stevesoltys.seedvault.backend.BackendManager import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.File
import java.io.InputStream import java.io.InputStream
import java.io.SequenceInputStream import java.io.SequenceInputStream
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
@ -24,34 +27,30 @@ internal class Loader(
private val backendManager: BackendManager, private val backendManager: BackendManager,
) { ) {
private val log = KotlinLogging.logger {}
/**
* Downloads the given [fileHandle], decrypts and decompresses its content
* and returns the content as a decrypted and decompressed stream.
*
* Attention: The responsibility with closing the returned stream lies with the caller.
*
* @param cacheFile if non-null, the ciphertext of the loaded file will be cached there
* for later loading with [loadFile].
*/
suspend fun loadFile(fileHandle: AppBackupFileType, cacheFile: File? = null): InputStream {
val expectedHash = when (fileHandle) {
is AppBackupFileType.Snapshot -> fileHandle.hash
is AppBackupFileType.Blob -> fileHandle.name
}
return loadFromStream(backendManager.backend.load(fileHandle), expectedHash, cacheFile)
}
/** /**
* The responsibility with closing the returned stream lies with the caller. * The responsibility with closing the returned stream lies with the caller.
*/ */
suspend fun loadFile(handle: AppBackupFileType): InputStream { fun loadFile(file: File, expectedHash: String): InputStream {
// We load the entire ciphertext into memory, return loadFromStream(file.inputStream(), expectedHash)
// so we can check the SHA-256 hash before decrypting and parsing the data.
val cipherText = backendManager.backend.load(handle).use { inputStream ->
inputStream.readAllBytes()
}
// check SHA-256 hash first thing
val sha256 = crypto.sha256(cipherText).toHexString()
val expectedHash = when (handle) {
is AppBackupFileType.Snapshot -> handle.hash
is AppBackupFileType.Blob -> handle.name
}
if (sha256 != expectedHash) {
throw GeneralSecurityException("File had wrong SHA-256 hash: $handle")
}
// check that we can handle the version of that snapshot
val version = cipherText[0]
if (version <= 1) throw GeneralSecurityException("Unexpected version: $version")
if (version > VERSION) throw UnsupportedVersionException(version)
// get associated data for version, used for authenticated decryption
val ad = crypto.getAdForVersion(version)
// skip first version byte when creating cipherText stream
val inputStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
// decrypt and decompress cipherText stream and parse snapshot
return ZstdInputStream(crypto.newDecryptingStream(inputStream, ad))
} }
suspend fun loadFiles(handles: List<AppBackupFileType>): InputStream { suspend fun loadFiles(handles: List<AppBackupFileType>): InputStream {
@ -68,4 +67,38 @@ internal class Loader(
} }
return SequenceInputStream(enumeration) return SequenceInputStream(enumeration)
} }
private fun loadFromStream(
inputStream: InputStream,
expectedHash: String,
cacheFile: File? = null,
): InputStream {
// We load the entire ciphertext into memory,
// so we can check the SHA-256 hash before decrypting and parsing the data.
val cipherText = inputStream.use { it.readAllBytes() }
// check SHA-256 hash first thing
val sha256 = crypto.sha256(cipherText).toHexString()
if (sha256 != expectedHash) {
throw GeneralSecurityException("File had wrong SHA-256 hash: $handle")
}
// check that we can handle the version of that snapshot
val version = cipherText[0]
if (version <= 1) throw GeneralSecurityException("Unexpected version: $version")
if (version > VERSION) throw UnsupportedVersionException(version)
// cache ciperText in cacheFile, if existing
try {
cacheFile?.outputStream()?.use { outputStream ->
outputStream.write(cipherText)
}
} catch (e: Exception) {
log.error(e) { "Error writing cache file $cacheFile: " }
}
// get associated data for version, used for authenticated decryption
val ad = crypto.getAdForVersion(version)
// skip first version byte when creating cipherText stream
val byteStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
// decrypt and decompress cipherText stream and parse snapshot
return ZstdInputStream(crypto.newDecryptingStream(byteStream, ad))
}
} }

View file

@ -11,6 +11,7 @@ import android.app.backup.IBackupObserver
import android.content.Context import android.content.Context
import android.content.pm.ApplicationInfo.FLAG_SYSTEM import android.content.pm.ApplicationInfo.FLAG_SYSTEM
import android.content.pm.PackageManager.NameNotFoundException import android.content.pm.PackageManager.NameNotFoundException
import android.os.Looper
import android.util.Log import android.util.Log
import android.util.Log.INFO import android.util.Log.INFO
import android.util.Log.isLoggable import android.util.Log.isLoggable
@ -136,7 +137,7 @@ internal class NotificationBackupObserver(
if (isLoggable(TAG, INFO)) { if (isLoggable(TAG, INFO)) {
Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status") Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status")
} }
val success = status == 0 var success = status == 0
val size = if (success) metadataManager.getPackagesBackupSize() else 0L val size = if (success) metadataManager.getPackagesBackupSize() else 0L
val total = try { val total = try {
packageService.allUserPackages.size packageService.allUserPackages.size
@ -144,11 +145,10 @@ internal class NotificationBackupObserver(
Log.e(TAG, "Error getting number of all user packages: ", e) Log.e(TAG, "Error getting number of all user packages: ", e)
requestedPackages requestedPackages
} }
// TODO handle exceptions
runBlocking { runBlocking {
// TODO check if UI thread check(!Looper.getMainLooper().isCurrentThread)
Log.d("TAG", "Finalizing backup...") Log.d(TAG, "Finalizing backup...")
appBackupManager.afterBackupFinished(success) success = appBackupManager.afterBackupFinished(success)
} }
nm.onBackupFinished(success, numPackagesToReport, total, size) nm.onBackupFinished(success, numPackagesToReport, total, size)
} }

View file

@ -6,9 +6,13 @@
package com.stevesoltys.seedvault.transport package com.stevesoltys.seedvault.transport
import com.stevesoltys.seedvault.backend.BackendManager import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.proto.snapshot
import com.stevesoltys.seedvault.transport.restore.Loader import com.stevesoltys.seedvault.transport.restore.Loader
import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every import io.mockk.every
import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.slot import io.mockk.slot
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
@ -16,13 +20,19 @@ import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.backends.Backend import org.calyxos.seedvault.core.backends.Backend
import org.calyxos.seedvault.core.toByteArrayFromHex import org.calyxos.seedvault.core.toByteArrayFromHex
import org.calyxos.seedvault.core.toHexString import org.calyxos.seedvault.core.toHexString
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.io.TempDir
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.nio.file.Path
import java.security.MessageDigest import java.security.MessageDigest
import kotlin.random.Random import kotlin.random.Random
@ -30,17 +40,104 @@ internal class SnapshotManagerTest : TransportTest() {
private val backendManager: BackendManager = mockk() private val backendManager: BackendManager = mockk()
private val backend: Backend = mockk() private val backend: Backend = mockk()
private val loader: Loader = mockk()
private val loader = Loader(crypto, backendManager) // need a real loader private val messageDigest = MessageDigest.getInstance("SHA-256")
private val snapshotManager = SnapshotManager(crypto, loader, backendManager)
private val ad = Random.nextBytes(1) private val ad = Random.nextBytes(1)
private val passThroughOutputStream = slot<OutputStream>() private val passThroughOutputStream = slot<OutputStream>()
private val passThroughInputStream = slot<InputStream>() private val passThroughInputStream = slot<InputStream>()
private val snapshotHandle = slot<AppBackupFileType.Snapshot>() private val snapshotHandle = slot<AppBackupFileType.Snapshot>()
init {
every { backendManager.backend } returns backend
}
@Test @Test
fun `test saving and loading`() = runBlocking { fun `test onSnapshotsLoaded sets latestSnapshot`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData1 = snapshot { token = 20 }.toByteArray()
val snapshotData2 = snapshot { token = 10 }.toByteArray()
val inputStream1 = ByteArrayInputStream(snapshotData1)
val inputStream2 = ByteArrayInputStream(snapshotData2)
val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1)
val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2)
coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream1
coEvery { loader.loadFile(snapshotHandle2, any()) } returns inputStream2
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2))
// snapshot with largest token is latest
assertEquals(20, snapshotManager.latestSnapshot?.token)
}
@Test
fun `saveSnapshot saves to local cache`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
val outputStream = ByteArrayOutputStream()
every { crypto.getAdForVersion() } returns ad
every { crypto.newEncryptingStream(capture(passThroughOutputStream), ad) } answers {
passThroughOutputStream.captured // not really encrypting here
}
every { crypto.sha256(any()) } returns chunkId1.toByteArrayFromHex()
every { crypto.repoId } returns repoId
coEvery { backend.save(snapshotHandle) } returns outputStream
snapshotManager.saveSnapshot(snapshot)
val snapshotFile = File(tmpDir.toString(), snapshotHandle.name)
assertTrue(snapshotFile.isFile)
assertTrue(outputStream.size() > 0)
val cachedBytes = snapshotFile.inputStream().use { it.readAllBytes() }
assertArrayEquals(outputStream.toByteArray(), cachedBytes)
}
@Test
fun `snapshot loads from cache without backend`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData = snapshot { token = 1337 }.toByteArray()
val inputStream = ByteArrayInputStream(snapshotData)
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
// create cached file
val file = File(tmpDir.toString(), snapshotHandle.name)
file.outputStream().use { it.write(snapshotData) }
coEvery { loader.loadFile(file, snapshotHandle.hash) } returns inputStream
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle))
coVerify(exactly = 0) { // did not load from backend
loader.loadFile(snapshotHandle, any())
}
}
@Test
fun `failing to load a snapshot isn't fatal`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData = snapshot { token = 42 }.toByteArray()
val inputStream = ByteArrayInputStream(snapshotData)
val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1)
val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2)
coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream
coEvery { loader.loadFile(snapshotHandle2, any()) } throws IOException()
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2))
// still one snapshot survived and we didn't crash
assertEquals(42, snapshotManager.latestSnapshot?.token)
}
@Test
fun `test saving and loading`(@TempDir tmpDir: Path) = runBlocking {
val loader = Loader(crypto, backendManager) // need a real loader
val snapshotManager = getSnapshotManager(File(tmpDir.toString()), loader)
val bytes = slot<ByteArray>()
val outputStream = ByteArrayOutputStream() val outputStream = ByteArrayOutputStream()
every { crypto.getAdForVersion() } returns ad every { crypto.getAdForVersion() } returns ad
@ -48,13 +145,14 @@ internal class SnapshotManagerTest : TransportTest() {
passThroughOutputStream.captured // not really encrypting here passThroughOutputStream.captured // not really encrypting here
} }
every { crypto.repoId } returns repoId every { crypto.repoId } returns repoId
every { backendManager.backend } returns backend every { crypto.sha256(capture(bytes)) } answers {
messageDigest.digest(bytes.captured)
}
coEvery { backend.save(capture(snapshotHandle)) } returns outputStream coEvery { backend.save(capture(snapshotHandle)) } returns outputStream
snapshotManager.saveSnapshot(snapshot) snapshotManager.saveSnapshot(snapshot)
// check that file content hash matches snapshot hash // check that file content hash matches snapshot hash
val messageDigest = MessageDigest.getInstance("SHA-256")
assertEquals( assertEquals(
messageDigest.digest(outputStream.toByteArray()).toHexString(), messageDigest.digest(outputStream.toByteArray()).toHexString(),
snapshotHandle.captured.hash, snapshotHandle.captured.hash,
@ -75,4 +173,25 @@ internal class SnapshotManagerTest : TransportTest() {
assertEquals(snapshot, snapshots[0]) assertEquals(snapshot, snapshots[0])
} }
} }
@Test
fun `remove snapshot removes from backend and cache`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
val file = File(tmpDir.toString(), snapshotHandle.name)
file.createNewFile()
assertTrue(file.isFile)
coEvery { backend.remove(snapshotHandle) } just Runs
snapshotManager.removeSnapshot(snapshotHandle)
assertFalse(file.exists())
coVerify { backend.remove(snapshotHandle) }
}
private fun getSnapshotManager(tmpFolder: File, loader: Loader = this.loader): SnapshotManager {
return SnapshotManager(tmpFolder, crypto, loader, backendManager)
}
} }