diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt index e413667d..9ee20534 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt @@ -12,12 +12,18 @@ import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.proto.Snapshot import com.stevesoltys.seedvault.transport.restore.Loader import io.github.oshai.kotlinlogging.KotlinLogging -import okio.Buffer -import okio.buffer -import okio.sink import org.calyxos.seedvault.core.backends.AppBackupFileType +import org.calyxos.seedvault.core.toHexString +import java.io.ByteArrayOutputStream +import java.io.File +import java.io.IOException +/** + * Manages interactions with snapshots, such as loading, saving and removing them. + * Also keeps a reference to the [latestSnapshot] that holds important re-usable data. + */ internal class SnapshotManager( + private val snapshotFolder: File, private val crypto: Crypto, private val loader: Loader, private val backendManager: BackendManager, @@ -32,12 +38,20 @@ internal class SnapshotManager( var latestSnapshot: Snapshot? = null private set + /** + * Call this before starting a backup run with the [handles] of snapshots + * currently available on the backend. + */ suspend fun onSnapshotsLoaded(handles: List): List { - return handles.map { snapshotHandle -> - // TODO set up local snapshot cache, so we don't need to download those all the time - // TODO is it a fatal error when one snapshot is corrupted or couldn't get loaded? - val snapshot = loader.loadFile(snapshotHandle).use { inputStream -> - Snapshot.parseFrom(inputStream) + return handles.mapNotNull { snapshotHandle -> + val snapshot = try { + loadSnapshot(snapshotHandle) + } catch (e: Exception) { + // This isn't ideal, but the show must go on and we take the snapshots we can get. + // After the first load, a snapshot will get cached, so we are not hitting backend. + // TODO use a re-trying backend for snapshot loading + log.error(e) { "Error loading snapshot: $snapshotHandle" } + return@mapNotNull null } // update latest snapshot if this one is more recent if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot @@ -45,24 +59,61 @@ internal class SnapshotManager( } } + /** + * Saves the given [snapshot] to the backend and local cache. + * + * @throws IOException or others if saving fails. + */ + @Throws(IOException::class) suspend fun saveSnapshot(snapshot: Snapshot) { - val buffer = Buffer() - val bufferStream = buffer.outputStream() - bufferStream.write(VERSION.toInt()) - crypto.newEncryptingStream(bufferStream, crypto.getAdForVersion()).use { cryptoStream -> + val byteStream = ByteArrayOutputStream() + byteStream.write(VERSION.toInt()) + crypto.newEncryptingStream(byteStream, crypto.getAdForVersion()).use { cryptoStream -> ZstdOutputStream(cryptoStream).use { zstdOutputStream -> snapshot.writeTo(zstdOutputStream) } } - val sha256ByteString = buffer.sha256() - val handle = AppBackupFileType.Snapshot(crypto.repoId, sha256ByteString.hex()) - // TODO exception handling - backendManager.backend.save(handle).use { outputStream -> - outputStream.sink().buffer().apply { - writeAll(buffer) - flush() // needs flushing + val bytes = byteStream.toByteArray() + val sha256 = crypto.sha256(bytes).toHexString() + val snapshotHandle = AppBackupFileType.Snapshot(crypto.repoId, sha256) + backendManager.backend.save(snapshotHandle).use { outputStream -> + outputStream.write(bytes) + } + // save to local cache while at it + try { + if (!snapshotFolder.isDirectory) snapshotFolder.mkdirs() + File(snapshotFolder, snapshotHandle.name).outputStream().use { outputStream -> + outputStream.write(bytes) } + } catch (e: Exception) { // we'll let this one pass + log.error(e) { "Error saving snapshot ${snapshotHandle.hash} to cache: " } } } + /** + * Removes the snapshot referenced by the given [snapshotHandle] from the backend + * and local cache. + */ + @Throws(IOException::class) + suspend fun removeSnapshot(snapshotHandle: AppBackupFileType.Snapshot) { + backendManager.backend.remove(snapshotHandle) + // remove from cache as well + File(snapshotFolder, snapshotHandle.name).delete() + } + + /** + * Loads and parses the snapshot referenced by the given [snapshotHandle]. + * If a locally cached version exists, the backend will not be hit. + */ + @Throws(IOException::class) + suspend fun loadSnapshot(snapshotHandle: AppBackupFileType.Snapshot): Snapshot { + val file = File(snapshotFolder, snapshotHandle.name) + val inputStream = if (file.isFile) { + loader.loadFile(file, snapshotHandle.hash) + } else { + loader.loadFile(snapshotHandle, file) + } + return inputStream.use { Snapshot.parseFrom(it) } + } + } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt index bc47d6fa..1c89d2a0 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt @@ -48,21 +48,25 @@ internal class AppBackupManager( blobCache.populateCache(blobInfos, snapshots) } - suspend fun afterBackupFinished(success: Boolean) { + suspend fun afterBackupFinished(success: Boolean): Boolean { log.info { "After backup finished. Success: $success" } // free up memory by clearing blobs cache blobCache.clear() + var result = false try { if (success) { val snapshot = snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator") - keepTrying { + keepTrying { // saving this is so important, we even keep trying snapshotManager.saveSnapshot(snapshot) } settingsManager.token = snapshot.token // after snapshot was written, we can clear local cache as its info is in snapshot blobCache.clearLocalCache() } + result = true + } catch (e: Exception) { + log.error(e) { "Error finishing backup" } } finally { snapshotCreator = null } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt index 4bb15123..cbde0a22 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt @@ -8,13 +8,17 @@ package com.stevesoltys.seedvault.transport.backup import com.stevesoltys.seedvault.transport.SnapshotManager import org.koin.android.ext.koin.androidContext import org.koin.dsl.module +import java.io.File val backupModule = module { single { BackupInitializer(get()) } single { BackupReceiver(get(), get(), get()) } single { BlobCache(androidContext()) } single { BlobCreator(get(), get()) } - single { SnapshotManager(get(), get(), get()) } + single { + val snapshotFolder = File(androidContext().filesDir, "snapshots") + SnapshotManager(snapshotFolder, get(), get(), get()) + } single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) } single { InputFactory() } single { diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt index 66a3246e..c165070a 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt @@ -5,15 +5,18 @@ package com.stevesoltys.seedvault.transport.restore +import com.android.internal.R.attr.handle import com.github.luben.zstd.ZstdInputStream import com.stevesoltys.seedvault.backend.BackendManager import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION +import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.runBlocking import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.toHexString import java.io.ByteArrayInputStream +import java.io.File import java.io.InputStream import java.io.SequenceInputStream import java.security.GeneralSecurityException @@ -24,34 +27,30 @@ internal class Loader( private val backendManager: BackendManager, ) { + private val log = KotlinLogging.logger {} + + /** + * Downloads the given [fileHandle], decrypts and decompresses its content + * and returns the content as a decrypted and decompressed stream. + * + * Attention: The responsibility with closing the returned stream lies with the caller. + * + * @param cacheFile if non-null, the ciphertext of the loaded file will be cached there + * for later loading with [loadFile]. + */ + suspend fun loadFile(fileHandle: AppBackupFileType, cacheFile: File? = null): InputStream { + val expectedHash = when (fileHandle) { + is AppBackupFileType.Snapshot -> fileHandle.hash + is AppBackupFileType.Blob -> fileHandle.name + } + return loadFromStream(backendManager.backend.load(fileHandle), expectedHash, cacheFile) + } + /** * The responsibility with closing the returned stream lies with the caller. */ - suspend fun loadFile(handle: AppBackupFileType): InputStream { - // We load the entire ciphertext into memory, - // so we can check the SHA-256 hash before decrypting and parsing the data. - val cipherText = backendManager.backend.load(handle).use { inputStream -> - inputStream.readAllBytes() - } - // check SHA-256 hash first thing - val sha256 = crypto.sha256(cipherText).toHexString() - val expectedHash = when (handle) { - is AppBackupFileType.Snapshot -> handle.hash - is AppBackupFileType.Blob -> handle.name - } - if (sha256 != expectedHash) { - throw GeneralSecurityException("File had wrong SHA-256 hash: $handle") - } - // check that we can handle the version of that snapshot - val version = cipherText[0] - if (version <= 1) throw GeneralSecurityException("Unexpected version: $version") - if (version > VERSION) throw UnsupportedVersionException(version) - // get associated data for version, used for authenticated decryption - val ad = crypto.getAdForVersion(version) - // skip first version byte when creating cipherText stream - val inputStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1) - // decrypt and decompress cipherText stream and parse snapshot - return ZstdInputStream(crypto.newDecryptingStream(inputStream, ad)) + fun loadFile(file: File, expectedHash: String): InputStream { + return loadFromStream(file.inputStream(), expectedHash) } suspend fun loadFiles(handles: List): InputStream { @@ -68,4 +67,38 @@ internal class Loader( } return SequenceInputStream(enumeration) } + + private fun loadFromStream( + inputStream: InputStream, + expectedHash: String, + cacheFile: File? = null, + ): InputStream { + // We load the entire ciphertext into memory, + // so we can check the SHA-256 hash before decrypting and parsing the data. + val cipherText = inputStream.use { it.readAllBytes() } + // check SHA-256 hash first thing + val sha256 = crypto.sha256(cipherText).toHexString() + if (sha256 != expectedHash) { + throw GeneralSecurityException("File had wrong SHA-256 hash: $handle") + } + // check that we can handle the version of that snapshot + val version = cipherText[0] + if (version <= 1) throw GeneralSecurityException("Unexpected version: $version") + if (version > VERSION) throw UnsupportedVersionException(version) + // cache ciperText in cacheFile, if existing + try { + cacheFile?.outputStream()?.use { outputStream -> + outputStream.write(cipherText) + } + } catch (e: Exception) { + log.error(e) { "Error writing cache file $cacheFile: " } + } + // get associated data for version, used for authenticated decryption + val ad = crypto.getAdForVersion(version) + // skip first version byte when creating cipherText stream + val byteStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1) + // decrypt and decompress cipherText stream and parse snapshot + return ZstdInputStream(crypto.newDecryptingStream(byteStream, ad)) + } + } diff --git a/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt b/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt index 0b8dc599..f2648531 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt @@ -11,6 +11,7 @@ import android.app.backup.IBackupObserver import android.content.Context import android.content.pm.ApplicationInfo.FLAG_SYSTEM import android.content.pm.PackageManager.NameNotFoundException +import android.os.Looper import android.util.Log import android.util.Log.INFO import android.util.Log.isLoggable @@ -136,7 +137,7 @@ internal class NotificationBackupObserver( if (isLoggable(TAG, INFO)) { Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status") } - val success = status == 0 + var success = status == 0 val size = if (success) metadataManager.getPackagesBackupSize() else 0L val total = try { packageService.allUserPackages.size @@ -144,11 +145,10 @@ internal class NotificationBackupObserver( Log.e(TAG, "Error getting number of all user packages: ", e) requestedPackages } - // TODO handle exceptions runBlocking { - // TODO check if UI thread - Log.d("TAG", "Finalizing backup...") - appBackupManager.afterBackupFinished(success) + check(!Looper.getMainLooper().isCurrentThread) + Log.d(TAG, "Finalizing backup...") + success = appBackupManager.afterBackupFinished(success) } nm.onBackupFinished(success, numPackagesToReport, total, size) } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt index a4f4d54e..1c441928 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt @@ -6,9 +6,13 @@ package com.stevesoltys.seedvault.transport import com.stevesoltys.seedvault.backend.BackendManager +import com.stevesoltys.seedvault.proto.snapshot import com.stevesoltys.seedvault.transport.restore.Loader +import io.mockk.Runs import io.mockk.coEvery +import io.mockk.coVerify import io.mockk.every +import io.mockk.just import io.mockk.mockk import io.mockk.slot import kotlinx.coroutines.runBlocking @@ -16,13 +20,19 @@ import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.backends.Backend import org.calyxos.seedvault.core.toByteArrayFromHex import org.calyxos.seedvault.core.toHexString +import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream +import java.io.File +import java.io.IOException import java.io.InputStream import java.io.OutputStream +import java.nio.file.Path import java.security.MessageDigest import kotlin.random.Random @@ -30,17 +40,104 @@ internal class SnapshotManagerTest : TransportTest() { private val backendManager: BackendManager = mockk() private val backend: Backend = mockk() + private val loader: Loader = mockk() - private val loader = Loader(crypto, backendManager) // need a real loader - private val snapshotManager = SnapshotManager(crypto, loader, backendManager) - + private val messageDigest = MessageDigest.getInstance("SHA-256") private val ad = Random.nextBytes(1) private val passThroughOutputStream = slot() private val passThroughInputStream = slot() private val snapshotHandle = slot() + init { + every { backendManager.backend } returns backend + } + @Test - fun `test saving and loading`() = runBlocking { + fun `test onSnapshotsLoaded sets latestSnapshot`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + val snapshotData1 = snapshot { token = 20 }.toByteArray() + val snapshotData2 = snapshot { token = 10 }.toByteArray() + val inputStream1 = ByteArrayInputStream(snapshotData1) + val inputStream2 = ByteArrayInputStream(snapshotData2) + + val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1) + val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2) + + coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream1 + coEvery { loader.loadFile(snapshotHandle2, any()) } returns inputStream2 + snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2)) + + // snapshot with largest token is latest + assertEquals(20, snapshotManager.latestSnapshot?.token) + } + + @Test + fun `saveSnapshot saves to local cache`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1) + val outputStream = ByteArrayOutputStream() + + every { crypto.getAdForVersion() } returns ad + every { crypto.newEncryptingStream(capture(passThroughOutputStream), ad) } answers { + passThroughOutputStream.captured // not really encrypting here + } + every { crypto.sha256(any()) } returns chunkId1.toByteArrayFromHex() + every { crypto.repoId } returns repoId + coEvery { backend.save(snapshotHandle) } returns outputStream + + snapshotManager.saveSnapshot(snapshot) + + val snapshotFile = File(tmpDir.toString(), snapshotHandle.name) + assertTrue(snapshotFile.isFile) + assertTrue(outputStream.size() > 0) + val cachedBytes = snapshotFile.inputStream().use { it.readAllBytes() } + assertArrayEquals(outputStream.toByteArray(), cachedBytes) + } + + @Test + fun `snapshot loads from cache without backend`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + val snapshotData = snapshot { token = 1337 }.toByteArray() + val inputStream = ByteArrayInputStream(snapshotData) + val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1) + + // create cached file + val file = File(tmpDir.toString(), snapshotHandle.name) + file.outputStream().use { it.write(snapshotData) } + + coEvery { loader.loadFile(file, snapshotHandle.hash) } returns inputStream + + snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle)) + + coVerify(exactly = 0) { // did not load from backend + loader.loadFile(snapshotHandle, any()) + } + } + + @Test + fun `failing to load a snapshot isn't fatal`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + + val snapshotData = snapshot { token = 42 }.toByteArray() + val inputStream = ByteArrayInputStream(snapshotData) + + val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1) + val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2) + + coEvery { loader.loadFile(snapshotHandle1, any()) } returns inputStream + coEvery { loader.loadFile(snapshotHandle2, any()) } throws IOException() + snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2)) + + // still one snapshot survived and we didn't crash + assertEquals(42, snapshotManager.latestSnapshot?.token) + } + + @Test + fun `test saving and loading`(@TempDir tmpDir: Path) = runBlocking { + val loader = Loader(crypto, backendManager) // need a real loader + val snapshotManager = getSnapshotManager(File(tmpDir.toString()), loader) + + val bytes = slot() val outputStream = ByteArrayOutputStream() every { crypto.getAdForVersion() } returns ad @@ -48,13 +145,14 @@ internal class SnapshotManagerTest : TransportTest() { passThroughOutputStream.captured // not really encrypting here } every { crypto.repoId } returns repoId - every { backendManager.backend } returns backend + every { crypto.sha256(capture(bytes)) } answers { + messageDigest.digest(bytes.captured) + } coEvery { backend.save(capture(snapshotHandle)) } returns outputStream snapshotManager.saveSnapshot(snapshot) // check that file content hash matches snapshot hash - val messageDigest = MessageDigest.getInstance("SHA-256") assertEquals( messageDigest.digest(outputStream.toByteArray()).toHexString(), snapshotHandle.captured.hash, @@ -75,4 +173,25 @@ internal class SnapshotManagerTest : TransportTest() { assertEquals(snapshot, snapshots[0]) } } + + @Test + fun `remove snapshot removes from backend and cache`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + + val snapshotHandle = AppBackupFileType.Snapshot(repoId, chunkId1) + val file = File(tmpDir.toString(), snapshotHandle.name) + file.createNewFile() + assertTrue(file.isFile) + + coEvery { backend.remove(snapshotHandle) } just Runs + + snapshotManager.removeSnapshot(snapshotHandle) + + assertFalse(file.exists()) + coVerify { backend.remove(snapshotHandle) } + } + + private fun getSnapshotManager(tmpFolder: File, loader: Loader = this.loader): SnapshotManager { + return SnapshotManager(tmpFolder, crypto, loader, backendManager) + } }