diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt index 762a46b3..ed6b1940 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt @@ -17,7 +17,8 @@ private class FullBackupState( internal val packageInfo: PackageInfo, internal val inputFileDescriptor: ParcelFileDescriptor, internal val inputStream: InputStream, - internal val outputStream: OutputStream) { + internal var outputStreamInit: (() -> OutputStream)?) { + internal var outputStream: OutputStream? = null internal val packageName: String = packageInfo.packageName internal var size: Long = 0 } @@ -88,45 +89,32 @@ internal class FullBackup( if (state != null) throw AssertionError() Log.i(TAG, "Perform full backup for ${targetPackage.packageName}.") - // get OutputStream to write backup data into - val outputStream = try { - plugin.getOutputStream(targetPackage) - } catch (e: IOException) { - Log.e(TAG, "Error getting OutputStream for full backup of ${targetPackage.packageName}", e) - return backupError(TRANSPORT_ERROR) - } - // create new state val inputStream = inputFactory.getInputStream(socket) - state = FullBackupState(targetPackage, socket, inputStream, outputStream) - - // TODO store this is clearable lamdba and only run it when we actually get data - // this is to avoid unnecessary disk I/O - - // store version header - val state = this.state ?: throw AssertionError() - val header = VersionHeader(packageName = state.packageName) - try { - headerWriter.writeVersion(state.outputStream, header) - crypto.encryptHeader(state.outputStream, header) - } catch (e: IOException) { - Log.e(TAG, "Error writing backup header", e) - return backupError(TRANSPORT_ERROR) - } + state = FullBackupState(targetPackage, socket, inputStream) { + Log.d(TAG, "Initializing OutputStream for ${targetPackage.packageName}.") + // get OutputStream to write backup data into + val outputStream = try { + plugin.getOutputStream(targetPackage) + } catch (e: IOException) { + Log.e(TAG, "Error getting OutputStream for full backup of ${targetPackage.packageName}", e) + throw(e) + } + // store version header + val state = this.state ?: throw AssertionError() + val header = VersionHeader(packageName = state.packageName) + try { + headerWriter.writeVersion(outputStream, header) + crypto.encryptHeader(outputStream, header) + } catch (e: IOException) { + Log.e(TAG, "Error writing backup header", e) + throw(e) + } + outputStream + } // this lambda is only called before we actually write backup data the first time return TRANSPORT_OK } - /** - * Method to reset state, - * because [finishBackup] is not called - * when we don't return [TRANSPORT_OK] from [performFullBackup]. - */ - private fun backupError(result: Int): Int { - Log.i(TAG, "Resetting state because of full backup error.") - state = null - return result - } - fun sendBackupData(numBytes: Int): Int { val state = this.state ?: throw AssertionError("Attempted sendBackupData before performFullBackup") @@ -142,8 +130,18 @@ internal class FullBackup( Log.i(TAG, "Send full backup data of $numBytes bytes.") return try { + // get output stream or initialize it, if it does not yet exist + check((state.outputStream != null) xor (state.outputStreamInit != null)) { "No OutputStream xor no StreamGetter" } + val outputStream = state.outputStream ?: { + val stream = state.outputStreamInit!!.invoke() // not-null due to check above + state.outputStream = stream + stream + }.invoke() + state.outputStreamInit = null // the stream init lambda is not needed beyond that point + + // read backup data, encrypt it and write it to output stream val payload = IOUtils.readFully(state.inputStream, numBytes) - crypto.encryptSegment(state.outputStream, payload) + crypto.encryptSegment(outputStream, payload) TRANSPORT_OK } catch (e: IOException) { Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e) @@ -151,6 +149,7 @@ internal class FullBackup( } } + @Throws(IOException::class) fun clearBackupData(packageInfo: PackageInfo) { plugin.removeDataOfPackage(packageInfo) } @@ -158,12 +157,12 @@ internal class FullBackup( fun cancelFullBackup() { Log.i(TAG, "Cancel full backup") val state = this.state ?: throw AssertionError("No state when canceling") - clearState() try { plugin.removeDataOfPackage(state.packageInfo) } catch (e: IOException) { Log.w(TAG, "Error cancelling full backup for ${state.packageName}", e) } + clearState() // TODO roll back to the previous known-good archive } @@ -175,7 +174,7 @@ internal class FullBackup( private fun clearState(): Int { val state = this.state ?: throw AssertionError("Trying to clear empty state.") return try { - state.outputStream.flush() + state.outputStream?.flush() closeQuietly(state.outputStream) closeQuietly(state.inputStream) closeQuietly(state.inputFileDescriptor) diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt index df7ca21a..80cf7f51 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt @@ -56,27 +56,9 @@ internal class FullBackupTest : BackupTest() { assertEquals(TRANSPORT_OK, backup.checkFullBackupSize(quota)) } - @Test - fun `performFullBackup throws exception when getting outputStream`() { - every { plugin.getOutputStream(packageInfo) } throws IOException() - - assertEquals(TRANSPORT_ERROR, backup.performFullBackup(packageInfo, data)) - assertFalse(backup.hasState()) - } - - @Test - fun `performFullBackup throws exception when writing header`() { - every { plugin.getOutputStream(packageInfo) } returns outputStream - every { inputFactory.getInputStream(data) } returns inputStream - every { headerWriter.writeVersion(outputStream, header) } throws IOException() - - assertEquals(TRANSPORT_ERROR, backup.performFullBackup(packageInfo, data)) - assertFalse(backup.hasState()) - } - @Test fun `performFullBackup runs ok`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream expectClearState() assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) @@ -87,7 +69,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `sendBackupData first call over quota`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() val numBytes = (quota + 1).toInt() expectSendData(numBytes) expectClearState() @@ -102,7 +85,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `sendBackupData second call over quota`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() val numBytes1 = quota.toInt() expectSendData(numBytes1) val numBytes2 = 1 @@ -121,7 +105,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `sendBackupData throws exception when reading from InputStream`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() every { plugin.getQuota() } returns quota every { inputStream.read(any(), any(), bytes.size) } throws IOException() expectClearState() @@ -134,9 +119,44 @@ internal class FullBackupTest : BackupTest() { assertFalse(backup.hasState()) } + @Test + fun `sendBackupData throws exception when getting outputStream`() { + every { inputFactory.getInputStream(data) } returns inputStream + + every { plugin.getQuota() } returns quota + every { plugin.getOutputStream(packageInfo) } throws IOException() + expectClearState() + + assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) + assertTrue(backup.hasState()) + assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size)) + assertTrue(backup.hasState()) + assertEquals(TRANSPORT_OK, backup.finishBackup()) + assertFalse(backup.hasState()) + } + + @Test + fun `sendBackupData throws exception when writing header`() { + every { inputFactory.getInputStream(data) } returns inputStream + + every { plugin.getQuota() } returns quota + every { plugin.getOutputStream(packageInfo) } returns outputStream + every { inputFactory.getInputStream(data) } returns inputStream + every { headerWriter.writeVersion(outputStream, header) } throws IOException() + expectClearState() + + assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) + assertTrue(backup.hasState()) + assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size)) + assertTrue(backup.hasState()) + assertEquals(TRANSPORT_OK, backup.finishBackup()) + assertFalse(backup.hasState()) + } + @Test fun `sendBackupData throws exception when writing encrypted data to OutputStream`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() every { plugin.getQuota() } returns quota every { inputStream.read(any(), any(), bytes.size) } returns bytes.size every { crypto.encryptSegment(outputStream, any()) } throws IOException() @@ -152,7 +172,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `sendBackupData runs ok`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() val numBytes1 = (quota / 2).toInt() expectSendData(numBytes1) val numBytes2 = (quota / 2).toInt() @@ -178,7 +199,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `cancel full backup runs ok`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() expectClearState() every { plugin.removeDataOfPackage(packageInfo) } just Runs @@ -190,7 +212,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `cancel full backup ignores exception when calling plugin`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() expectClearState() every { plugin.removeDataOfPackage(packageInfo) } throws IOException() @@ -202,19 +225,24 @@ internal class FullBackupTest : BackupTest() { @Test fun `clearState throws exception when flushing OutputStream`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() + val numBytes = 42 + expectSendData(numBytes) every { outputStream.write(closeBytes) } just Runs every { outputStream.flush() } throws IOException() assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertTrue(backup.hasState()) + assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes)) assertEquals(TRANSPORT_ERROR, backup.finishBackup()) assertFalse(backup.hasState()) } @Test fun `clearState ignores exception when closing OutputStream`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() every { outputStream.flush() } just Runs every { outputStream.close() } throws IOException() every { inputStream.close() } just Runs @@ -228,7 +256,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `clearState ignores exception when closing InputStream`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() every { outputStream.flush() } just Runs every { outputStream.close() } just Runs every { inputStream.close() } throws IOException() @@ -242,7 +271,8 @@ internal class FullBackupTest : BackupTest() { @Test fun `clearState ignores exception when closing ParcelFileDescriptor`() { - expectPerformFullBackup() + every { inputFactory.getInputStream(data) } returns inputStream + expectInitializeOutputStream() every { outputStream.flush() } just Runs every { outputStream.close() } just Runs every { inputStream.close() } just Runs @@ -254,9 +284,8 @@ internal class FullBackupTest : BackupTest() { assertFalse(backup.hasState()) } - private fun expectPerformFullBackup() { + private fun expectInitializeOutputStream() { every { plugin.getOutputStream(packageInfo) } returns outputStream - every { inputFactory.getInputStream(data) } returns inputStream every { headerWriter.writeVersion(outputStream, header) } just Runs every { crypto.encryptHeader(outputStream, header) } just Runs }