Fully implement SnapshotManager

which manages interactions with snapshots, such as loading, saving and removing them.
It also keeps a reference to the latestSnapshot that holds important re-usable data.
This commit is contained in:
Torsten Grote 2024-09-12 17:41:15 -03:00
parent 952cdec55d
commit bfa17fa5ec
No known key found for this signature in database
GPG key ID: 3E5F77D92CF891FF
6 changed files with 269 additions and 58 deletions

View file

@ -12,12 +12,18 @@ import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.proto.Snapshot
import com.stevesoltys.seedvault.transport.restore.Loader
import io.github.oshai.kotlinlogging.KotlinLogging
import okio.Buffer
import okio.buffer
import okio.sink
import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
/**
* Manages interactions with snapshots, such as loading, saving and removing them.
* Also keeps a reference to the [latestSnapshot] that holds important re-usable data.
*/
internal class SnapshotManager(
private val snapshotFolder: File,
private val crypto: Crypto,
private val loader: Loader,
private val backendManager: BackendManager,
@ -32,12 +38,20 @@ internal class SnapshotManager(
var latestSnapshot: Snapshot? = null
private set
/**
* Call this before starting a backup run with the [handles] of snapshots
* currently available on the backend.
*/
suspend fun onSnapshotsLoaded(handles: List<AppBackupFileType.Snapshot>): List<Snapshot> {
return handles.map { snapshotHandle ->
// TODO set up local snapshot cache, so we don't need to download those all the time
// TODO is it a fatal error when one snapshot is corrupted or couldn't get loaded?
val snapshot = loader.loadFile(snapshotHandle).use { inputStream ->
Snapshot.parseFrom(inputStream)
return handles.mapNotNull { snapshotHandle ->
val snapshot = try {
loadSnapshot(snapshotHandle)
} catch (e: Exception) {
// This isn't ideal, but the show must go on and we take the snapshots we can get.
// After the first load, a snapshot will get cached, so we are not hitting backend.
// TODO use a re-trying backend for snapshot loading
log.error(e) { "Error loading snapshot: $snapshotHandle" }
return@mapNotNull null
}
// update latest snapshot if this one is more recent
if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot
@ -45,24 +59,61 @@ internal class SnapshotManager(
}
}
/**
* Saves the given [snapshot] to the backend and local cache.
*
* @throws IOException or others if saving fails.
*/
@Throws(IOException::class)
suspend fun saveSnapshot(snapshot: Snapshot) {
val buffer = Buffer()
val bufferStream = buffer.outputStream()
bufferStream.write(VERSION.toInt())
crypto.newEncryptingStream(bufferStream, crypto.getAdForVersion()).use { cryptoStream ->
val byteStream = ByteArrayOutputStream()
byteStream.write(VERSION.toInt())
crypto.newEncryptingStream(byteStream, crypto.getAdForVersion()).use { cryptoStream ->
ZstdOutputStream(cryptoStream).use { zstdOutputStream ->
snapshot.writeTo(zstdOutputStream)
}
}
val sha256ByteString = buffer.sha256()
val handle = AppBackupFileType.Snapshot(crypto.repoId, sha256ByteString.hex())
// TODO exception handling
backendManager.backend.save(handle).use { outputStream ->
outputStream.sink().buffer().apply {
writeAll(buffer)
flush() // needs flushing
val bytes = byteStream.toByteArray()
val sha256 = crypto.sha256(bytes).toHexString()
val snapshotHandle = AppBackupFileType.Snapshot(crypto.repoId, sha256)
backendManager.backend.save(snapshotHandle).use { outputStream ->
outputStream.write(bytes)
}
// save to local cache while at it
try {
if (!snapshotFolder.isDirectory) snapshotFolder.mkdirs()
File(snapshotFolder, snapshotHandle.name).outputStream().use { outputStream ->
outputStream.write(bytes)
}
} catch (e: Exception) { // we'll let this one pass
log.error(e) { "Error saving snapshot ${snapshotHandle.hash} to cache: " }
}
}
/**
* Removes the snapshot referenced by the given [snapshotHandle] from the backend
* and local cache.
*/
@Throws(IOException::class)
suspend fun removeSnapshot(snapshotHandle: AppBackupFileType.Snapshot) {
backendManager.backend.remove(snapshotHandle)
// remove from cache as well
File(snapshotFolder, snapshotHandle.name).delete()
}
/**
* Loads and parses the snapshot referenced by the given [snapshotHandle].
* If a locally cached version exists, the backend will not be hit.
*/
@Throws(IOException::class)
suspend fun loadSnapshot(snapshotHandle: AppBackupFileType.Snapshot): Snapshot {
val file = File(snapshotFolder, snapshotHandle.name)
val inputStream = if (file.isFile) {
loader.loadFile(file, snapshotHandle.hash)
} else {
loader.loadFile(snapshotHandle, file)
}
return inputStream.use { Snapshot.parseFrom(it) }
}
}

View file

@ -48,21 +48,25 @@ internal class AppBackupManager(
blobCache.populateCache(blobInfos, snapshots)
}
suspend fun afterBackupFinished(success: Boolean) {
suspend fun afterBackupFinished(success: Boolean): Boolean {
log.info { "After backup finished. Success: $success" }
// free up memory by clearing blobs cache
blobCache.clear()
var result = false
try {
if (success) {
val snapshot =
snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator")
keepTrying {
keepTrying { // saving this is so important, we even keep trying
snapshotManager.saveSnapshot(snapshot)
}
settingsManager.token = snapshot.token
// after snapshot was written, we can clear local cache as its info is in snapshot
blobCache.clearLocalCache()
}
result = true
} catch (e: Exception) {
log.error(e) { "Error finishing backup" }
} finally {
snapshotCreator = null
}

View file

@ -8,13 +8,17 @@ package com.stevesoltys.seedvault.transport.backup
import com.stevesoltys.seedvault.transport.SnapshotManager
import org.koin.android.ext.koin.androidContext
import org.koin.dsl.module
import java.io.File
val backupModule = module {
single { BackupInitializer(get()) }
single { BackupReceiver(get(), get(), get()) }
single { BlobCache(androidContext()) }
single { BlobCreator(get(), get()) }
single { SnapshotManager(get(), get(), get()) }
single {
val snapshotFolder = File(androidContext().filesDir, "snapshots")
SnapshotManager(snapshotFolder, get(), get(), get())
}
single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) }
single { InputFactory() }
single {

View file

@ -5,15 +5,18 @@
package com.stevesoltys.seedvault.transport.restore
import com.android.internal.R.attr.handle
import com.github.luben.zstd.ZstdInputStream
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayInputStream
import java.io.File
import java.io.InputStream
import java.io.SequenceInputStream
import java.security.GeneralSecurityException
@ -24,34 +27,30 @@ internal class Loader(
private val backendManager: BackendManager,
) {
private val log = KotlinLogging.logger {}
/**
* Downloads the given [fileHandle], decrypts and decompresses its content
* and returns the content as a decrypted and decompressed stream.
*
* Attention: The responsibility with closing the returned stream lies with the caller.
*
* @param cacheFile if non-null, the ciphertext of the loaded file will be cached there
* for later loading with [loadFile].
*/
suspend fun loadFile(fileHandle: AppBackupFileType, cacheFile: File? = null): InputStream {
val expectedHash = when (fileHandle) {
is AppBackupFileType.Snapshot -> fileHandle.hash
is AppBackupFileType.Blob -> fileHandle.name
}
return loadFromStream(backendManager.backend.load(fileHandle), expectedHash, cacheFile)
}
/**
* The responsibility with closing the returned stream lies with the caller.
*/
suspend fun loadFile(handle: AppBackupFileType): InputStream {
// We load the entire ciphertext into memory,
// so we can check the SHA-256 hash before decrypting and parsing the data.
val cipherText = backendManager.backend.load(handle).use { inputStream ->
inputStream.readAllBytes()
}
// check SHA-256 hash first thing
val sha256 = crypto.sha256(cipherText).toHexString()
val expectedHash = when (handle) {
is AppBackupFileType.Snapshot -> handle.hash
is AppBackupFileType.Blob -> handle.name
}
if (sha256 != expectedHash) {
throw GeneralSecurityException("File had wrong SHA-256 hash: $handle")
}
// check that we can handle the version of that snapshot
val version = cipherText[0]
if (version <= 1) throw GeneralSecurityException("Unexpected version: $version")
if (version > VERSION) throw UnsupportedVersionException(version)
// get associated data for version, used for authenticated decryption
val ad = crypto.getAdForVersion(version)
// skip first version byte when creating cipherText stream
val inputStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
// decrypt and decompress cipherText stream and parse snapshot
return ZstdInputStream(crypto.newDecryptingStream(inputStream, ad))
fun loadFile(file: File, expectedHash: String): InputStream {
return loadFromStream(file.inputStream(), expectedHash)
}
suspend fun loadFiles(handles: List<AppBackupFileType>): InputStream {
@ -68,4 +67,38 @@ internal class Loader(
}
return SequenceInputStream(enumeration)
}
private fun loadFromStream(
inputStream: InputStream,
expectedHash: String,
cacheFile: File? = null,
): InputStream {
// We load the entire ciphertext into memory,
// so we can check the SHA-256 hash before decrypting and parsing the data.
val cipherText = inputStream.use { it.readAllBytes() }
// check SHA-256 hash first thing
val sha256 = crypto.sha256(cipherText).toHexString()
if (sha256 != expectedHash) {
throw GeneralSecurityException("File had wrong SHA-256 hash: $handle")
}
// check that we can handle the version of that snapshot
val version = cipherText[0]
if (version <= 1) throw GeneralSecurityException("Unexpected version: $version")
if (version > VERSION) throw UnsupportedVersionException(version)
// cache ciperText in cacheFile, if existing
try {
cacheFile?.outputStream()?.use { outputStream ->
outputStream.write(cipherText)
}
} catch (e: Exception) {
log.error(e) { "Error writing cache file $cacheFile: " }
}
// get associated data for version, used for authenticated decryption
val ad = crypto.getAdForVersion(version)
// skip first version byte when creating cipherText stream
val byteStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
// decrypt and decompress cipherText stream and parse snapshot
return ZstdInputStream(crypto.newDecryptingStream(byteStream, ad))
}
}

View file

@ -11,6 +11,7 @@ import android.app.backup.IBackupObserver
import android.content.Context
import android.content.pm.ApplicationInfo.FLAG_SYSTEM
import android.content.pm.PackageManager.NameNotFoundException
import android.os.Looper
import android.util.Log
import android.util.Log.INFO
import android.util.Log.isLoggable
@ -136,7 +137,7 @@ internal class NotificationBackupObserver(
if (isLoggable(TAG, INFO)) {
Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status")
}
val success = status == 0
var success = status == 0
val size = if (success) metadataManager.getPackagesBackupSize() else 0L
val total = try {
packageService.allUserPackages.size
@ -144,11 +145,10 @@ internal class NotificationBackupObserver(
Log.e(TAG, "Error getting number of all user packages: ", e)
requestedPackages
}
// TODO handle exceptions
runBlocking {
// TODO check if UI thread
Log.d("TAG", "Finalizing backup...")
appBackupManager.afterBackupFinished(success)
check(!Looper.getMainLooper().isCurrentThread)
Log.d(TAG, "Finalizing backup...")
success = appBackupManager.afterBackupFinished(success)
}
nm.onBackupFinished(success, numPackagesToReport, total, size)
}

View file

@ -6,9 +6,13 @@
package com.stevesoltys.seedvault.transport
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.proto.snapshot
import com.stevesoltys.seedvault.transport.restore.Loader
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.slot
import kotlinx.coroutines.runBlocking
@ -16,13 +20,19 @@ import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.backends.Backend
import org.calyxos.seedvault.core.toByteArrayFromHex
import org.calyxos.seedvault.core.toHexString
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.io.TempDir
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.file.Path
import java.security.MessageDigest
import kotlin.random.Random
@ -30,17 +40,104 @@ internal class SnapshotManagerTest : TransportTest() {
private val backendManager: BackendManager = mockk()
private val backend: Backend = mockk()
private val loader: Loader = mockk()
private val loader = Loader(crypto, backendManager) // need a real loader
private val snapshotManager = SnapshotManager(crypto, loader, backendManager)
private val messageDigest = MessageDigest.getInstance("SHA-256")
private val ad = Random.nextBytes(1)
private val passThroughOutputStream = slot<OutputStream>()
private val passThroughInputStream = slot<InputStream>()
private val snapshotHandle = slot<AppBackupFileType.Snapshot>()
init {
every { backendManager.backend } returns backend
}
@Test
fun `test saving and loading`() = runBlocking {
fun `test onSnapshotsLoaded sets latestSnapshot`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData1 = snapshot { token = 20 }.toByteArray()
val snapshotData2 = snapshot { token = 10 }.toByteArray()
val inputStream1 = ByteArrayInputStream(snapshotData1)
val inputStream2 = ByteArrayInputStream(snapshotData2)
val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1)
val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2)
coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream1
coEvery { loader.loadFile(snapshotHandle2, any()) } returns inputStream2
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2))
// snapshot with largest token is latest
assertEquals(20, snapshotManager.latestSnapshot?.token)
}
@Test
fun `saveSnapshot saves to local cache`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
val outputStream = ByteArrayOutputStream()
every { crypto.getAdForVersion() } returns ad
every { crypto.newEncryptingStream(capture(passThroughOutputStream), ad) } answers {
passThroughOutputStream.captured // not really encrypting here
}
every { crypto.sha256(any()) } returns chunkId1.toByteArrayFromHex()
every { crypto.repoId } returns repoId
coEvery { backend.save(snapshotHandle) } returns outputStream
snapshotManager.saveSnapshot(snapshot)
val snapshotFile = File(tmpDir.toString(), snapshotHandle.name)
assertTrue(snapshotFile.isFile)
assertTrue(outputStream.size() > 0)
val cachedBytes = snapshotFile.inputStream().use { it.readAllBytes() }
assertArrayEquals(outputStream.toByteArray(), cachedBytes)
}
@Test
fun `snapshot loads from cache without backend`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData = snapshot { token = 1337 }.toByteArray()
val inputStream = ByteArrayInputStream(snapshotData)
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
// create cached file
val file = File(tmpDir.toString(), snapshotHandle.name)
file.outputStream().use { it.write(snapshotData) }
coEvery { loader.loadFile(file, snapshotHandle.hash) } returns inputStream
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle))
coVerify(exactly = 0) { // did not load from backend
loader.loadFile(snapshotHandle, any())
}
}
@Test
fun `failing to load a snapshot isn't fatal`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotData = snapshot { token = 42 }.toByteArray()
val inputStream = ByteArrayInputStream(snapshotData)
val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1)
val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2)
coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream
coEvery { loader.loadFile(snapshotHandle2, any()) } throws IOException()
snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2))
// still one snapshot survived and we didn't crash
assertEquals(42, snapshotManager.latestSnapshot?.token)
}
@Test
fun `test saving and loading`(@TempDir tmpDir: Path) = runBlocking {
val loader = Loader(crypto, backendManager) // need a real loader
val snapshotManager = getSnapshotManager(File(tmpDir.toString()), loader)
val bytes = slot<ByteArray>()
val outputStream = ByteArrayOutputStream()
every { crypto.getAdForVersion() } returns ad
@ -48,13 +145,14 @@ internal class SnapshotManagerTest : TransportTest() {
passThroughOutputStream.captured // not really encrypting here
}
every { crypto.repoId } returns repoId
every { backendManager.backend } returns backend
every { crypto.sha256(capture(bytes)) } answers {
messageDigest.digest(bytes.captured)
}
coEvery { backend.save(capture(snapshotHandle)) } returns outputStream
snapshotManager.saveSnapshot(snapshot)
// check that file content hash matches snapshot hash
val messageDigest = MessageDigest.getInstance("SHA-256")
assertEquals(
messageDigest.digest(outputStream.toByteArray()).toHexString(),
snapshotHandle.captured.hash,
@ -75,4 +173,25 @@ internal class SnapshotManagerTest : TransportTest() {
assertEquals(snapshot, snapshots[0])
}
}
@Test
fun `remove snapshot removes from backend and cache`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1)
val file = File(tmpDir.toString(), snapshotHandle.name)
file.createNewFile()
assertTrue(file.isFile)
coEvery { backend.remove(snapshotHandle) } just Runs
snapshotManager.removeSnapshot(snapshotHandle)
assertFalse(file.exists())
coVerify { backend.remove(snapshotHandle) }
}
private fun getSnapshotManager(tmpFolder: File, loader: Loader = this.loader): SnapshotManager {
return SnapshotManager(tmpFolder, crypto, loader, backendManager)
}
}