diff --git a/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt b/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt index f6a6dc2f..3f8aabed 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/header/Header.kt @@ -1,6 +1,8 @@ package com.stevesoltys.seedvault.header import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH +import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV +import java.nio.ByteBuffer internal const val VERSION: Byte = 1 internal const val MAX_PACKAGE_LENGTH_SIZE = 255 @@ -29,6 +31,15 @@ data class VersionHeader( } } +internal fun getADForKV(version: Byte, packageName: String): ByteArray { + val packageNameBytes = packageName.toByteArray() + return ByteBuffer.allocate(2 + packageNameBytes.size) + .put(version) + .put(TYPE_BACKUP_KV) + .put(packageNameBytes) + .array() +} + internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt() internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int = diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVBackup.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVBackup.kt index 27455aea..0ba0fb1c 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVBackup.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/KVBackup.kt @@ -13,10 +13,11 @@ import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.encodeBase64 import com.stevesoltys.seedvault.header.HeaderWriter +import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VersionHeader +import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager -import libcore.io.IoUtils.closeQuietly import java.io.IOException class KVBackupState(internal val packageInfo: PackageInfo) @@ -166,18 +167,17 @@ internal class KVBackup( Log.e(TAG, "Deleting record with base64Key ${op.base64Key}") plugin.deleteRecord(packageInfo, op.base64Key) } else { - val outputStream = plugin.getOutputStreamForRecord(packageInfo, op.base64Key) - try { + plugin.getOutputStreamForRecord(packageInfo, op.base64Key).use { outputStream -> val header = VersionHeader( packageName = packageInfo.packageName, key = op.key ) headerWriter.writeVersion(outputStream, header) - crypto.encryptHeader(outputStream, header) - crypto.encryptMultipleSegments(outputStream, op.value) - outputStream.flush() - } finally { - closeQuietly(outputStream) + val ad = getADForKV(VERSION, packageInfo.packageName) + crypto.newEncryptingStream(outputStream, ad).use { encryptedStream -> + encryptedStream.write(op.value) + encryptedStream.flush() + } } } } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt index 4c34a67d..9ec0f306 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/KVRestore.kt @@ -13,6 +13,8 @@ import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.decodeBase64 import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.UnsupportedVersionException +import com.stevesoltys.seedvault.header.VERSION +import com.stevesoltys.seedvault.header.getADForKV import libcore.io.IoUtils.closeQuietly import java.io.IOException import java.util.ArrayList @@ -146,8 +148,16 @@ internal class KVRestore( ) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key) .use { inputStream -> val version = headerReader.readVersion(inputStream) - crypto.decryptHeader(inputStream, version, state.packageInfo.packageName, dKey.key) - val value = crypto.decryptMultipleSegments(inputStream) + val packageName = state.packageInfo.packageName + val value = if (version == 0.toByte()) { + crypto.decryptHeader(inputStream, version, packageName, dKey.key) + crypto.decryptMultipleSegments(inputStream) + } else { + val ad = getADForKV(VERSION, packageName) + crypto.newDecryptingStream(inputStream, ad).use { decryptedStream -> + decryptedStream.readBytes() + } + } val size = value.size Log.v(TAG, " ... key=${dKey.key} size=$size") diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/BackupTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/BackupTest.kt index d00d43b0..b603b238 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/BackupTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/BackupTest.kt @@ -13,6 +13,7 @@ internal abstract class BackupTest : TransportTest() { protected val headerWriter = mockk() protected val data = mockk() protected val outputStream = mockk() + protected val encryptedOutputStream = mockk() protected val header = VersionHeader(packageName = packageInfo.packageName) protected val quota = 42L diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/KVBackupTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/KVBackupTest.kt index f7670825..820c0913 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/backup/KVBackupTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/backup/KVBackupTest.kt @@ -12,7 +12,9 @@ import com.stevesoltys.seedvault.Utf8 import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE import com.stevesoltys.seedvault.header.VersionHeader +import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager +import io.mockk.CapturingSlot import io.mockk.Runs import io.mockk.coEvery import io.mockk.every @@ -47,7 +49,7 @@ internal class KVBackupTest : BackupTest() { private val key = getRandomString(MAX_KEY_LENGTH_SIZE) private val key64 = Base64.getEncoder().encodeToString(key.toByteArray(Utf8)) - private val value = ByteArray(23).apply { Random.nextBytes(this) } + private val dataValue = Random.nextBytes(23) private val versionHeader = VersionHeader(packageName = packageInfo.packageName, key = key) @Test @@ -73,26 +75,26 @@ internal class KVBackupTest : BackupTest() { every { dataInput.readNextHeader() } returnsMany listOf(true, true, false) every { dataInput.key } returnsMany listOf("key1", "key2") // we don't care about values, so just use the same one always - every { dataInput.dataSize } returns value.size - every { dataInput.readEntityData(any(), 0, value.size) } returns value.size + every { dataInput.dataSize } returns dataValue.size + every { dataInput.readEntityData(any(), 0, dataValue.size) } returns dataValue.size // store first record and show notification for it every { notificationManager.onPmKvBackup("key1", 1, 2) } just Runs coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5MQ") } returns outputStream val versionHeader1 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key1") every { headerWriter.writeVersion(outputStream, versionHeader1) } just Runs - every { crypto.encryptHeader(outputStream, versionHeader1) } just Runs // store second record and show notification for it every { notificationManager.onPmKvBackup("key2", 2, 2) } just Runs coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5Mg") } returns outputStream val versionHeader2 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key2") every { headerWriter.writeVersion(outputStream, versionHeader2) } just Runs - every { crypto.encryptHeader(outputStream, versionHeader2) } just Runs // encrypt to and close output stream - every { crypto.encryptMultipleSegments(outputStream, any()) } just Runs - every { outputStream.write(value) } just Runs + every { crypto.newEncryptingStream(outputStream, any()) } returns encryptedOutputStream + every { encryptedOutputStream.write(any()) } just Runs + every { encryptedOutputStream.flush() } just Runs + every { encryptedOutputStream.close() } just Runs every { outputStream.flush() } just Runs every { outputStream.close() } just Runs @@ -190,8 +192,8 @@ internal class KVBackupTest : BackupTest() { createBackupDataInput() every { dataInput.readNextHeader() } returns true every { dataInput.key } returns key - every { dataInput.dataSize } returns value.size - every { dataInput.readEntityData(any(), 0, value.size) } throws IOException() + every { dataInput.dataSize } returns dataValue.size + every { dataInput.readEntityData(any(), 0, dataValue.size) } throws IOException() every { plugin.packageFinished(packageInfo) } just Runs assertEquals(TRANSPORT_ERROR, backup.performBackup(packageInfo, data, 0)) @@ -230,10 +232,7 @@ internal class KVBackupTest : BackupTest() { initPlugin(false) getDataInput(listOf(true)) writeHeaderAndEncrypt() - coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream - every { headerWriter.writeVersion(outputStream, versionHeader) } just Runs - every { crypto.encryptMultipleSegments(outputStream, any()) } throws IOException() - every { outputStream.close() } just Runs + every { encryptedOutputStream.write(dataValue) } throws IOException() every { plugin.packageFinished(packageInfo) } just Runs assertEquals(TRANSPORT_ERROR, backup.performBackup(packageInfo, data, 0)) @@ -247,8 +246,9 @@ internal class KVBackupTest : BackupTest() { initPlugin(false) getDataInput(listOf(true)) writeHeaderAndEncrypt() - every { outputStream.write(value) } just Runs - every { outputStream.flush() } throws IOException() + every { encryptedOutputStream.write(dataValue) } just Runs + every { encryptedOutputStream.flush() } throws IOException() + every { encryptedOutputStream.close() } just Runs every { outputStream.close() } just Runs every { plugin.packageFinished(packageInfo) } just Runs @@ -263,9 +263,10 @@ internal class KVBackupTest : BackupTest() { initPlugin(false) getDataInput(listOf(true, false)) writeHeaderAndEncrypt() - every { outputStream.write(value) } just Runs - every { outputStream.flush() } just Runs - every { outputStream.close() } throws IOException() + every { encryptedOutputStream.write(dataValue) } just Runs + every { encryptedOutputStream.flush() } just Runs + every { encryptedOutputStream.close() } just Runs + every { outputStream.close() } just Runs every { plugin.packageFinished(packageInfo) } just Runs assertEquals(TRANSPORT_OK, backup.performBackup(packageInfo, data, 0)) @@ -278,8 +279,9 @@ internal class KVBackupTest : BackupTest() { initPlugin(hasDataForPackage) getDataInput(listOf(true, false)) writeHeaderAndEncrypt() - every { outputStream.write(value) } just Runs - every { outputStream.flush() } just Runs + every { encryptedOutputStream.write(dataValue) } just Runs + every { encryptedOutputStream.flush() } just Runs + every { encryptedOutputStream.close() } just Runs every { outputStream.close() } just Runs every { plugin.packageFinished(packageInfo) } just Runs } @@ -296,15 +298,19 @@ internal class KVBackupTest : BackupTest() { createBackupDataInput() every { dataInput.readNextHeader() } returnsMany returnValues every { dataInput.key } returns key - every { dataInput.dataSize } returns value.size - every { dataInput.readEntityData(any(), 0, value.size) } returns value.size + every { dataInput.dataSize } returns dataValue.size + val slot = CapturingSlot() + every { dataInput.readEntityData(capture(slot), 0, dataValue.size) } answers { + dataValue.copyInto(slot.captured) + dataValue.size + } } private fun writeHeaderAndEncrypt() { coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream every { headerWriter.writeVersion(outputStream, versionHeader) } just Runs - every { crypto.encryptHeader(outputStream, versionHeader) } just Runs - every { crypto.encryptMultipleSegments(outputStream, any()) } just Runs + val ad = getADForKV(versionHeader.version, packageInfo.packageName) + every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream } } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt index 393d19e8..975977a6 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/KVRestoreTest.kt @@ -9,11 +9,13 @@ import com.stevesoltys.seedvault.getRandomByteArray import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VersionHeader +import com.stevesoltys.seedvault.header.getADForKV import io.mockk.Runs import io.mockk.coEvery import io.mockk.every import io.mockk.just import io.mockk.mockk +import io.mockk.mockkStatic import io.mockk.verifyAll import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions.assertEquals @@ -28,13 +30,17 @@ internal class KVRestoreTest : RestoreTest() { private val plugin = mockk() private val output = mockk() private val restore = KVRestore(plugin, outputFactory, headerReader, crypto) + private val ad = getADForKV(VERSION, packageInfo.packageName) private val key = "Restore Key" private val key64 = key.encodeBase64() - private val versionHeader = VersionHeader(VERSION, packageInfo.packageName, key) private val key2 = "Restore Key2" private val key264 = key2.encodeBase64() - private val versionHeader2 = VersionHeader(VERSION, packageInfo.packageName, key2) + + init { + // for InputStream#readBytes() + mockkStatic("kotlin.io.ByteStreamsKt") + } @Test fun `hasDataForPackage() delegates to plugin`() = runBlocking { @@ -90,21 +96,13 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `decrypting segment throws`() = runBlocking { + fun `decrypting stream throws`() = runBlocking { restore.initializeState(token, packageInfo) getRecordsAndOutput() coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } returns versionHeader - every { crypto.decryptMultipleSegments(inputStream) } throws IOException() + every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() streamsGetClosed() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) @@ -112,41 +110,13 @@ internal class KVRestoreTest : RestoreTest() { } @Test - fun `decrypting header throws`() = runBlocking { + fun `decrypting stream throws security exception`() = runBlocking { restore.initializeState(token, packageInfo) getRecordsAndOutput() coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } throws IOException() - streamsGetClosed() - - assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) - verifyStreamWasClosed() - } - - @Test - fun `decrypting header throws security exception`() = runBlocking { - restore.initializeState(token, packageInfo) - - getRecordsAndOutput() - coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream - every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } throws SecurityException() + every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException() streamsGetClosed() assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) @@ -160,15 +130,8 @@ internal class KVRestoreTest : RestoreTest() { getRecordsAndOutput() coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } returns versionHeader - every { crypto.decryptMultipleSegments(inputStream) } returns data + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream + every { decryptedInputStream.readBytes() } returns data every { output.writeEntityHeader(key, data.size) } throws IOException() streamsGetClosed() @@ -183,15 +146,8 @@ internal class KVRestoreTest : RestoreTest() { getRecordsAndOutput() coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } returns versionHeader - every { crypto.decryptMultipleSegments(inputStream) } returns data + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream + every { decryptedInputStream.readBytes() } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } throws IOException() streamsGetClosed() @@ -207,14 +163,26 @@ internal class KVRestoreTest : RestoreTest() { getRecordsAndOutput() coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream + every { decryptedInputStream.readBytes() } returns data + every { output.writeEntityHeader(key, data.size) } returns 42 + every { output.writeEntityData(data, data.size) } returns data.size + streamsGetClosed() + + assertEquals(TRANSPORT_OK, restore.getRestoreData(fileDescriptor)) + verifyStreamWasClosed() + } + + @Test + fun `writing value uses old v0 code`() = runBlocking { + restore.initializeState(token, packageInfo) + + getRecordsAndOutput() + coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream + every { headerReader.readVersion(inputStream) } returns 0.toByte() every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } returns versionHeader + crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key) + } returns VersionHeader(VERSION, packageInfo.packageName, key) every { crypto.decryptMultipleSegments(inputStream) } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } returns data.size @@ -228,37 +196,25 @@ internal class KVRestoreTest : RestoreTest() { fun `writing two values succeeds`() = runBlocking { val data2 = getRandomByteArray() val inputStream2 = mockk() + val decryptedInputStream2 = mockk() restore.initializeState(token, packageInfo) getRecordsAndOutput(listOf(key64, key264)) // first key/value coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream every { headerReader.readVersion(inputStream) } returns VERSION - every { - crypto.decryptHeader( - inputStream, - VERSION, - packageInfo.packageName, - key - ) - } returns versionHeader - every { crypto.decryptMultipleSegments(inputStream) } returns data + every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream + every { decryptedInputStream.readBytes() } returns data every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityData(data, data.size) } returns data.size // second key/value coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2 every { headerReader.readVersion(inputStream2) } returns VERSION - every { - crypto.decryptHeader( - inputStream2, - VERSION, - packageInfo.packageName, - key2 - ) - } returns versionHeader2 - every { crypto.decryptMultipleSegments(inputStream2) } returns data2 + every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2 + every { decryptedInputStream2.readBytes() } returns data2 every { output.writeEntityHeader(key2, data2.size) } returns 42 every { output.writeEntityData(data2, data2.size) } returns data2.size + every { decryptedInputStream2.close() } just Runs every { inputStream2.close() } just Runs streamsGetClosed() @@ -271,6 +227,7 @@ internal class KVRestoreTest : RestoreTest() { } private fun streamsGetClosed() { + every { decryptedInputStream.close() } just Runs every { inputStream.close() } just Runs every { fileDescriptor.close() } just Runs } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreTest.kt index fe31e370..e720bc79 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/restore/RestoreTest.kt @@ -16,6 +16,7 @@ internal abstract class RestoreTest : TransportTest() { protected val data = getRandomByteArray() protected val inputStream = mockk() + protected val decryptedInputStream = mockk() protected val unsupportedVersion = (VERSION + 1).toByte()