Don't get or write to full backup output stream before we are not sure there will be data to write

This commit is contained in:
Torsten Grote 2020-01-09 10:44:13 -03:00
parent 7b27242625
commit 9f01d09962
No known key found for this signature in database
GPG key ID: 3E5F77D92CF891FF
2 changed files with 98 additions and 70 deletions

View file

@ -17,7 +17,8 @@ private class FullBackupState(
internal val packageInfo: PackageInfo, internal val packageInfo: PackageInfo,
internal val inputFileDescriptor: ParcelFileDescriptor, internal val inputFileDescriptor: ParcelFileDescriptor,
internal val inputStream: InputStream, internal val inputStream: InputStream,
internal val outputStream: OutputStream) { internal var outputStreamInit: (() -> OutputStream)?) {
internal var outputStream: OutputStream? = null
internal val packageName: String = packageInfo.packageName internal val packageName: String = packageInfo.packageName
internal var size: Long = 0 internal var size: Long = 0
} }
@ -88,45 +89,32 @@ internal class FullBackup(
if (state != null) throw AssertionError() if (state != null) throw AssertionError()
Log.i(TAG, "Perform full backup for ${targetPackage.packageName}.") Log.i(TAG, "Perform full backup for ${targetPackage.packageName}.")
// create new state
val inputStream = inputFactory.getInputStream(socket)
state = FullBackupState(targetPackage, socket, inputStream) {
Log.d(TAG, "Initializing OutputStream for ${targetPackage.packageName}.")
// get OutputStream to write backup data into // get OutputStream to write backup data into
val outputStream = try { val outputStream = try {
plugin.getOutputStream(targetPackage) plugin.getOutputStream(targetPackage)
} catch (e: IOException) { } catch (e: IOException) {
Log.e(TAG, "Error getting OutputStream for full backup of ${targetPackage.packageName}", e) Log.e(TAG, "Error getting OutputStream for full backup of ${targetPackage.packageName}", e)
return backupError(TRANSPORT_ERROR) throw(e)
} }
// create new state
val inputStream = inputFactory.getInputStream(socket)
state = FullBackupState(targetPackage, socket, inputStream, outputStream)
// TODO store this is clearable lamdba and only run it when we actually get data
// this is to avoid unnecessary disk I/O
// store version header // store version header
val state = this.state ?: throw AssertionError() val state = this.state ?: throw AssertionError()
val header = VersionHeader(packageName = state.packageName) val header = VersionHeader(packageName = state.packageName)
try { try {
headerWriter.writeVersion(state.outputStream, header) headerWriter.writeVersion(outputStream, header)
crypto.encryptHeader(state.outputStream, header) crypto.encryptHeader(outputStream, header)
} catch (e: IOException) { } catch (e: IOException) {
Log.e(TAG, "Error writing backup header", e) Log.e(TAG, "Error writing backup header", e)
return backupError(TRANSPORT_ERROR) throw(e)
} }
outputStream
} // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK return TRANSPORT_OK
} }
/**
* Method to reset state,
* because [finishBackup] is not called
* when we don't return [TRANSPORT_OK] from [performFullBackup].
*/
private fun backupError(result: Int): Int {
Log.i(TAG, "Resetting state because of full backup error.")
state = null
return result
}
fun sendBackupData(numBytes: Int): Int { fun sendBackupData(numBytes: Int): Int {
val state = this.state val state = this.state
?: throw AssertionError("Attempted sendBackupData before performFullBackup") ?: throw AssertionError("Attempted sendBackupData before performFullBackup")
@ -142,8 +130,18 @@ internal class FullBackup(
Log.i(TAG, "Send full backup data of $numBytes bytes.") Log.i(TAG, "Send full backup data of $numBytes bytes.")
return try { return try {
// get output stream or initialize it, if it does not yet exist
check((state.outputStream != null) xor (state.outputStreamInit != null)) { "No OutputStream xor no StreamGetter" }
val outputStream = state.outputStream ?: {
val stream = state.outputStreamInit!!.invoke() // not-null due to check above
state.outputStream = stream
stream
}.invoke()
state.outputStreamInit = null // the stream init lambda is not needed beyond that point
// read backup data, encrypt it and write it to output stream
val payload = IOUtils.readFully(state.inputStream, numBytes) val payload = IOUtils.readFully(state.inputStream, numBytes)
crypto.encryptSegment(state.outputStream, payload) crypto.encryptSegment(outputStream, payload)
TRANSPORT_OK TRANSPORT_OK
} catch (e: IOException) { } catch (e: IOException) {
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e) Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)
@ -151,6 +149,7 @@ internal class FullBackup(
} }
} }
@Throws(IOException::class)
fun clearBackupData(packageInfo: PackageInfo) { fun clearBackupData(packageInfo: PackageInfo) {
plugin.removeDataOfPackage(packageInfo) plugin.removeDataOfPackage(packageInfo)
} }
@ -158,12 +157,12 @@ internal class FullBackup(
fun cancelFullBackup() { fun cancelFullBackup() {
Log.i(TAG, "Cancel full backup") Log.i(TAG, "Cancel full backup")
val state = this.state ?: throw AssertionError("No state when canceling") val state = this.state ?: throw AssertionError("No state when canceling")
clearState()
try { try {
plugin.removeDataOfPackage(state.packageInfo) plugin.removeDataOfPackage(state.packageInfo)
} catch (e: IOException) { } catch (e: IOException) {
Log.w(TAG, "Error cancelling full backup for ${state.packageName}", e) Log.w(TAG, "Error cancelling full backup for ${state.packageName}", e)
} }
clearState()
// TODO roll back to the previous known-good archive // TODO roll back to the previous known-good archive
} }
@ -175,7 +174,7 @@ internal class FullBackup(
private fun clearState(): Int { private fun clearState(): Int {
val state = this.state ?: throw AssertionError("Trying to clear empty state.") val state = this.state ?: throw AssertionError("Trying to clear empty state.")
return try { return try {
state.outputStream.flush() state.outputStream?.flush()
closeQuietly(state.outputStream) closeQuietly(state.outputStream)
closeQuietly(state.inputStream) closeQuietly(state.inputStream)
closeQuietly(state.inputFileDescriptor) closeQuietly(state.inputFileDescriptor)

View file

@ -56,27 +56,9 @@ internal class FullBackupTest : BackupTest() {
assertEquals(TRANSPORT_OK, backup.checkFullBackupSize(quota)) assertEquals(TRANSPORT_OK, backup.checkFullBackupSize(quota))
} }
@Test
fun `performFullBackup throws exception when getting outputStream`() {
every { plugin.getOutputStream(packageInfo) } throws IOException()
assertEquals(TRANSPORT_ERROR, backup.performFullBackup(packageInfo, data))
assertFalse(backup.hasState())
}
@Test
fun `performFullBackup throws exception when writing header`() {
every { plugin.getOutputStream(packageInfo) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream
every { headerWriter.writeVersion(outputStream, header) } throws IOException()
assertEquals(TRANSPORT_ERROR, backup.performFullBackup(packageInfo, data))
assertFalse(backup.hasState())
}
@Test @Test
fun `performFullBackup runs ok`() { fun `performFullBackup runs ok`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
@ -87,7 +69,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `sendBackupData first call over quota`() { fun `sendBackupData first call over quota`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
val numBytes = (quota + 1).toInt() val numBytes = (quota + 1).toInt()
expectSendData(numBytes) expectSendData(numBytes)
expectClearState() expectClearState()
@ -102,7 +85,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `sendBackupData second call over quota`() { fun `sendBackupData second call over quota`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
val numBytes1 = quota.toInt() val numBytes1 = quota.toInt()
expectSendData(numBytes1) expectSendData(numBytes1)
val numBytes2 = 1 val numBytes2 = 1
@ -121,7 +105,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `sendBackupData throws exception when reading from InputStream`() { fun `sendBackupData throws exception when reading from InputStream`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
every { inputStream.read(any(), any(), bytes.size) } throws IOException() every { inputStream.read(any(), any(), bytes.size) } throws IOException()
expectClearState() expectClearState()
@ -134,9 +119,44 @@ internal class FullBackupTest : BackupTest() {
assertFalse(backup.hasState()) assertFalse(backup.hasState())
} }
@Test
fun `sendBackupData throws exception when getting outputStream`() {
every { inputFactory.getInputStream(data) } returns inputStream
every { plugin.getQuota() } returns quota
every { plugin.getOutputStream(packageInfo) } throws IOException()
expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
}
@Test
fun `sendBackupData throws exception when writing header`() {
every { inputFactory.getInputStream(data) } returns inputStream
every { plugin.getQuota() } returns quota
every { plugin.getOutputStream(packageInfo) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream
every { headerWriter.writeVersion(outputStream, header) } throws IOException()
expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_ERROR, backup.sendBackupData(bytes.size))
assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.finishBackup())
assertFalse(backup.hasState())
}
@Test @Test
fun `sendBackupData throws exception when writing encrypted data to OutputStream`() { fun `sendBackupData throws exception when writing encrypted data to OutputStream`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
every { crypto.encryptSegment(outputStream, any()) } throws IOException() every { crypto.encryptSegment(outputStream, any()) } throws IOException()
@ -152,7 +172,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `sendBackupData runs ok`() { fun `sendBackupData runs ok`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
val numBytes1 = (quota / 2).toInt() val numBytes1 = (quota / 2).toInt()
expectSendData(numBytes1) expectSendData(numBytes1)
val numBytes2 = (quota / 2).toInt() val numBytes2 = (quota / 2).toInt()
@ -178,7 +199,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `cancel full backup runs ok`() { fun `cancel full backup runs ok`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
expectClearState() expectClearState()
every { plugin.removeDataOfPackage(packageInfo) } just Runs every { plugin.removeDataOfPackage(packageInfo) } just Runs
@ -190,7 +212,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `cancel full backup ignores exception when calling plugin`() { fun `cancel full backup ignores exception when calling plugin`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
expectClearState() expectClearState()
every { plugin.removeDataOfPackage(packageInfo) } throws IOException() every { plugin.removeDataOfPackage(packageInfo) } throws IOException()
@ -202,19 +225,24 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `clearState throws exception when flushing OutputStream`() { fun `clearState throws exception when flushing OutputStream`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
val numBytes = 42
expectSendData(numBytes)
every { outputStream.write(closeBytes) } just Runs every { outputStream.write(closeBytes) } just Runs
every { outputStream.flush() } throws IOException() every { outputStream.flush() } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
assertTrue(backup.hasState()) assertTrue(backup.hasState())
assertEquals(TRANSPORT_OK, backup.sendBackupData(numBytes))
assertEquals(TRANSPORT_ERROR, backup.finishBackup()) assertEquals(TRANSPORT_ERROR, backup.finishBackup())
assertFalse(backup.hasState()) assertFalse(backup.hasState())
} }
@Test @Test
fun `clearState ignores exception when closing OutputStream`() { fun `clearState ignores exception when closing OutputStream`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { outputStream.flush() } just Runs every { outputStream.flush() } just Runs
every { outputStream.close() } throws IOException() every { outputStream.close() } throws IOException()
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
@ -228,7 +256,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `clearState ignores exception when closing InputStream`() { fun `clearState ignores exception when closing InputStream`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { outputStream.flush() } just Runs every { outputStream.flush() } just Runs
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { inputStream.close() } throws IOException() every { inputStream.close() } throws IOException()
@ -242,7 +271,8 @@ internal class FullBackupTest : BackupTest() {
@Test @Test
fun `clearState ignores exception when closing ParcelFileDescriptor`() { fun `clearState ignores exception when closing ParcelFileDescriptor`() {
expectPerformFullBackup() every { inputFactory.getInputStream(data) } returns inputStream
expectInitializeOutputStream()
every { outputStream.flush() } just Runs every { outputStream.flush() } just Runs
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
@ -254,9 +284,8 @@ internal class FullBackupTest : BackupTest() {
assertFalse(backup.hasState()) assertFalse(backup.hasState())
} }
private fun expectPerformFullBackup() { private fun expectInitializeOutputStream() {
every { plugin.getOutputStream(packageInfo) } returns outputStream every { plugin.getOutputStream(packageInfo) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream
every { headerWriter.writeVersion(outputStream, header) } just Runs every { headerWriter.writeVersion(outputStream, header) } just Runs
every { crypto.encryptHeader(outputStream, header) } just Runs every { crypto.encryptHeader(outputStream, header) } just Runs
} }