diff --git a/app/src/main/java/com/stevesoltys/seedvault/crypto/Crypto.kt b/app/src/main/java/com/stevesoltys/seedvault/crypto/Crypto.kt index 0ab1e79f..b97177ed 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/crypto/Crypto.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/crypto/Crypto.kt @@ -140,6 +140,7 @@ internal class CryptoImpl( deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray()) } + @Throws(IOException::class, GeneralSecurityException::class) override fun newEncryptingStream( outputStream: OutputStream, associatedData: ByteArray @@ -147,6 +148,7 @@ internal class CryptoImpl( return StreamCrypto.newEncryptingStream(key, outputStream, associatedData) } + @Throws(IOException::class, GeneralSecurityException::class) override fun newDecryptingStream( inputStream: InputStream, associatedData: ByteArray diff --git a/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt b/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt index 3f8aabed..085ac59d 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt @@ -1,6 +1,7 @@ package com.stevesoltys.seedvault.header import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH +import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_FULL import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV import java.nio.ByteBuffer @@ -40,6 +41,15 @@ internal fun getADForKV(version: Byte, packageName: String): ByteArray { .array() } +internal fun getADForFull(version: Byte, packageName: String): ByteArray { + val packageNameBytes = packageName.toByteArray() + return ByteBuffer.allocate(2 + packageNameBytes.size) + .put(version) + .put(TYPE_BACKUP_FULL) + .put(packageNameBytes) + .array() +} + internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt() internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int = diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt index b7ba16e0..ba9c22e5 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/FullBackup.kt @@ -10,7 +10,9 @@ import android.os.ParcelFileDescriptor import android.util.Log import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.header.HeaderWriter +import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VersionHeader +import com.stevesoltys.seedvault.header.getADForFull import com.stevesoltys.seedvault.settings.SettingsManager import libcore.io.IoUtils.closeQuietly import java.io.EOFException @@ -24,6 +26,9 @@ private class FullBackupState( val inputStream: InputStream, var outputStreamInit: (suspend () -> OutputStream)? ) { + /** + * This is an encrypted stream that can be written to directly. + */ var outputStream: OutputStream? = null val packageName: String = packageInfo.packageName var size: Long = 0 @@ -120,14 +125,8 @@ internal class FullBackup( // store version header val state = this.state ?: throw AssertionError() val header = VersionHeader(packageName = state.packageName) - try { - headerWriter.writeVersion(outputStream, header) - crypto.encryptHeader(outputStream, header) - } catch (e: IOException) { - Log.e(TAG, "Error writing backup header", e) - throw(e) - } - outputStream + headerWriter.writeVersion(outputStream, header) + crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName)) } // this lambda is only called before we actually write backup data the first time return TRANSPORT_OK } @@ -159,11 +158,11 @@ internal class FullBackup( }() state.outputStreamInit = null // the stream init lambda is not needed beyond that point - // read backup data, encrypt it and write it to output stream + // read backup data and write it to encrypted output stream val payload = ByteArray(numBytes) val read = state.inputStream.read(payload, 0, numBytes) if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.") - crypto.encryptSegment(outputStream, payload) + outputStream.write(payload) TRANSPORT_OK } catch (e: IOException) { Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e) diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/FullRestore.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/FullRestore.kt index 5f90d7db..4cf3bb2f 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/FullRestore.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/FullRestore.kt @@ -9,16 +9,21 @@ import android.os.ParcelFileDescriptor import android.util.Log import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.header.HeaderReader +import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH import com.stevesoltys.seedvault.header.UnsupportedVersionException +import com.stevesoltys.seedvault.header.getADForFull import libcore.io.IoUtils.closeQuietly import java.io.EOFException import java.io.IOException import java.io.InputStream +import java.io.OutputStream +import java.security.GeneralSecurityException private class FullRestoreState( val token: Long, val packageInfo: PackageInfo ) { + var version: Byte? = null var inputStream: InputStream? = null } @@ -81,7 +86,7 @@ internal class FullRestore( * Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition * that aborts all further restore operations on the current dataset. */ - suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int { + suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int = socket.use { pfd -> val state = this.state ?: throw IllegalStateException("no state") val packageName = state.packageInfo.packageName @@ -90,33 +95,48 @@ internal class FullRestore( try { val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo) val version = headerReader.readVersion(inputStream) - crypto.decryptHeader(inputStream, version, packageName) - state.inputStream = inputStream + state.version = version + if (version == 0.toByte()) { + crypto.decryptHeader(inputStream, version, packageName) + state.inputStream = inputStream + } else { + val ad = getADForFull(version, packageName) + state.inputStream = crypto.newDecryptingStream(inputStream, ad) + } } catch (e: IOException) { Log.w(TAG, "Error getting input stream for $packageName", e) return TRANSPORT_PACKAGE_REJECTED } catch (e: SecurityException) { Log.e(TAG, "Security Exception while getting input stream for $packageName", e) return TRANSPORT_ERROR + } catch (e: GeneralSecurityException) { + Log.e(TAG, "Security Exception while getting input stream for $packageName", e) + return TRANSPORT_ERROR } catch (e: UnsupportedVersionException) { Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e) return TRANSPORT_PACKAGE_REJECTED } } - return readInputStream(socket) + return outputFactory.getOutputStream(pfd).use { outputStream -> + try { + copyInputStream(outputStream) + } catch (e: IOException) { + Log.w(TAG, "Error copying stream for package $packageName.", e) + return TRANSPORT_PACKAGE_REJECTED + } + } } - private fun readInputStream(socket: ParcelFileDescriptor): Int = socket.use { fileDescriptor -> + @Throws(IOException::class) + private fun copyInputStream(outputStream: OutputStream): Int { val state = this.state ?: throw IllegalStateException("no state") - val packageName = state.packageInfo.packageName val inputStream = state.inputStream ?: throw IllegalStateException("no stream") - val outputStream = outputFactory.getOutputStream(fileDescriptor) + val version = state.version ?: throw IllegalStateException("no version") - try { + if (version == 0.toByte()) { // read segment from input stream and decrypt it val decrypted = try { - // TODO handle IOException crypto.decryptSegment(inputStream) } catch (e: EOFException) { Log.i(TAG, " EOF") @@ -129,12 +149,17 @@ internal class FullRestore( outputStream.write(decrypted) // return number of written bytes return decrypted.size - } catch (e: IOException) { - Log.w(TAG, "Error processing stream for package $packageName.", e) - closeQuietly(inputStream) - return TRANSPORT_PACKAGE_REJECTED - } finally { - closeQuietly(outputStream) + } else { + val buffer = ByteArray(MAX_SEGMENT_LENGTH) + val bytesRead = inputStream.read(buffer) + if (bytesRead == -1) { + Log.i(TAG, " EOF") + // close input stream here as we won't need it anymore + closeQuietly(inputStream) + return NO_MORE_DATA + } + outputStream.write(buffer, 0, bytesRead) + return bytesRead } } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt index 7bc3320c..8f526b4d 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/CoordinatorIntegrationTest.kt @@ -337,8 +337,7 @@ internal class CoordinatorIntegrationTest : TransportTest() { every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream // restore data - assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor)) - assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertEquals(appData.size, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) restore.finishRestore() diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt index c5c49d3f..b50edbfb 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/FullBackupTest.kt @@ -4,6 +4,8 @@ import android.app.backup.BackupTransport.TRANSPORT_ERROR import android.app.backup.BackupTransport.TRANSPORT_OK import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED +import com.stevesoltys.seedvault.header.VERSION +import com.stevesoltys.seedvault.header.getADForFull import io.mockk.Runs import io.mockk.coEvery import io.mockk.every @@ -25,8 +27,8 @@ internal class FullBackupTest : BackupTest() { private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto) private val bytes = ByteArray(23).apply { Random.nextBytes(this) } - private val closeBytes = ByteArray(42).apply { Random.nextBytes(this) } private val inputStream = mockk() + private val ad = getADForFull(VERSION, packageInfo.packageName) @Test fun `has no initial state`() { @@ -129,6 +131,7 @@ internal class FullBackupTest : BackupTest() { expectInitializeOutputStream() every { settingsManager.isQuotaUnlimited() } returns false every { plugin.getQuota() } returns quota + every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream every { inputStream.read(any(), any(), bytes.size) } throws IOException() expectClearState() @@ -183,8 +186,9 @@ internal class FullBackupTest : BackupTest() { expectInitializeOutputStream() every { settingsManager.isQuotaUnlimited() } returns false every { plugin.getQuota() } returns quota + every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream every { inputStream.read(any(), any(), bytes.size) } returns bytes.size - every { crypto.encryptSegment(outputStream, any()) } throws IOException() + every { encryptedOutputStream.write(any()) } throws IOException() expectClearState() assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) @@ -256,8 +260,7 @@ internal class FullBackupTest : BackupTest() { expectInitializeOutputStream() val numBytes = 42 expectSendData(numBytes) - every { outputStream.write(closeBytes) } just Runs - every { outputStream.flush() } throws IOException() + every { encryptedOutputStream.flush() } throws IOException() assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertTrue(backup.hasState()) @@ -314,18 +317,18 @@ internal class FullBackupTest : BackupTest() { private fun expectInitializeOutputStream() { coEvery { plugin.getOutputStream(packageInfo) } returns outputStream every { headerWriter.writeVersion(outputStream, header) } just Runs - every { crypto.encryptHeader(outputStream, header) } just Runs } private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) { every { plugin.getQuota() } returns quota every { inputStream.read(any(), any(), numBytes) } returns readBytes - every { crypto.encryptSegment(outputStream, any()) } just Runs + every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream + every { encryptedOutputStream.write(any()) } just Runs } private fun expectClearState() { - every { outputStream.write(closeBytes) } just Runs - every { outputStream.flush() } just Runs + every { encryptedOutputStream.flush() } just Runs + every { encryptedOutputStream.close() } just Runs every { outputStream.close() } just Runs every { inputStream.close() } just Runs every { data.close() } just Runs diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/FullRestoreTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/FullRestoreTest.kt index 3e3c22e8..668cea66 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/FullRestoreTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/FullRestoreTest.kt @@ -6,9 +6,12 @@ import android.app.backup.BackupTransport.TRANSPORT_OK import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED import com.stevesoltys.seedvault.coAssertThrows import com.stevesoltys.seedvault.getRandomByteArray +import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VersionHeader +import com.stevesoltys.seedvault.header.getADForFull +import io.mockk.CapturingSlot import io.mockk.Runs import io.mockk.coEvery import io.mockk.every @@ -20,9 +23,10 @@ import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test +import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream -import java.io.EOFException import java.io.IOException +import java.security.GeneralSecurityException import kotlin.random.Random @Suppress("BlockingMethodInNonBlockingContext") @@ -33,7 +37,7 @@ internal class FullRestoreTest : RestoreTest() { private val encrypted = getRandomByteArray() private val outputStream = ByteArrayOutputStream() - private val versionHeader = VersionHeader(VERSION, packageInfo.packageName) + private val ad = getADForFull(VERSION, packageInfo.packageName) @Test fun `has no initial state`() { @@ -67,6 +71,7 @@ internal class FullRestoreTest : RestoreTest() { restore.initializeState(token, packageInfo) coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException() + every { fileDescriptor.close() } just Runs assertEquals( TRANSPORT_PACKAGE_REJECTED, @@ -80,6 +85,7 @@ internal class FullRestoreTest : RestoreTest() { coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream every { headerReader.readVersion(inputStream) } throws IOException() + every { fileDescriptor.close() } just Runs assertEquals( TRANSPORT_PACKAGE_REJECTED, @@ -95,6 +101,7 @@ internal class FullRestoreTest : RestoreTest() { every { headerReader.readVersion(inputStream) } throws UnsupportedVersionException(unsupportedVersion) + every { fileDescriptor.close() } just Runs assertEquals( TRANSPORT_PACKAGE_REJECTED, @@ -103,18 +110,13 @@ internal class FullRestoreTest : RestoreTest() { } @Test - fun `decrypting version header when getting first chunk throws`() = runBlocking { + fun `getting decrypted stream when getting first chunk throws`() = runBlocking { restore.initializeState(token, packageInfo) coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName - ) - } throws IOException() + every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() + every { fileDescriptor.close() } just Runs assertEquals( TRANSPORT_PACKAGE_REJECTED, @@ -123,54 +125,20 @@ internal class FullRestoreTest : RestoreTest() { } @Test - fun `decrypting version header when getting first chunk throws security exception`() = + fun `getting decrypted stream when getting first chunk throws general security exception`() = runBlocking { restore.initializeState(token, packageInfo) coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName - ) - } throws SecurityException() + every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException() + every { fileDescriptor.close() } just Runs assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor)) } @Test - fun `decrypting segment throws IOException`() = runBlocking { - restore.initializeState(token, packageInfo) - - initInputStream() - every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream - every { crypto.decryptSegment(inputStream) } throws IOException() - every { inputStream.close() } just Runs - every { fileDescriptor.close() } just Runs - - assertEquals( - TRANSPORT_PACKAGE_REJECTED, - restore.getNextFullRestoreDataChunk(fileDescriptor) - ) - } - - @Test - fun `decrypting segment throws EOFException`() = runBlocking { - restore.initializeState(token, packageInfo) - - initInputStream() - every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream - every { crypto.decryptSegment(inputStream) } throws EOFException() - every { inputStream.close() } just Runs - every { fileDescriptor.close() } just Runs - - assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) - } - - @Test - fun `full chunk gets encrypted`() = runBlocking { + fun `full chunk gets decrypted`() = runBlocking { restore.initializeState(token, packageInfo) initInputStream() @@ -183,6 +151,50 @@ internal class FullRestoreTest : RestoreTest() { assertFalse(restore.hasState()) } + @Test + fun `full chunk gets decrypted from version 0`() = runBlocking { + restore.initializeState(token, packageInfo) + + coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream + every { headerReader.readVersion(inputStream) } returns 0.toByte() + every { + crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName) + } returns VersionHeader(0.toByte(), packageInfo.packageName) + every { crypto.decryptSegment(inputStream) } returns encrypted + + every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream + every { fileDescriptor.close() } just Runs + every { inputStream.close() } just Runs + + assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertArrayEquals(encrypted, outputStream.toByteArray()) + restore.finishRestore() + assertFalse(restore.hasState()) + } + + @Test + fun `three full chunk get decrypted and then return no more data`() = runBlocking { + val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1) + val decryptedInputStream = ByteArrayInputStream(encryptedBytes) + restore.initializeState(token, packageInfo) + + coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream + every { headerReader.readVersion(inputStream) } returns VERSION + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream + every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream + every { fileDescriptor.close() } just Runs + every { inputStream.close() } just Runs + + assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertEquals(1, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) + assertArrayEquals(encryptedBytes, outputStream.toByteArray()) + restore.finishRestore() + assertFalse(restore.hasState()) + } + @Test fun `aborting full restore closes stream, resets state`() = runBlocking { restore.initializeState(token, packageInfo) @@ -201,18 +213,17 @@ internal class FullRestoreTest : RestoreTest() { private fun initInputStream() { coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName - ) - } returns versionHeader + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream } private fun readAndEncryptInputStream(encryptedBytes: ByteArray) { every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream - every { crypto.decryptSegment(inputStream) } returns encryptedBytes + val slot = CapturingSlot() + every { decryptedInputStream.read(capture(slot)) } answers { + encryptedBytes.copyInto(slot.captured) + encryptedBytes.size + } + every { decryptedInputStream.close() } just Runs every { fileDescriptor.close() } just Runs }