Check version of backup files against expected version from metadata

and throw security exception if it does not match
This commit is contained in:
Torsten Grote 2021-09-14 17:37:37 +02:00 committed by Chirayu Desai
parent 5523e57fe7
commit 2932af463c
9 changed files with 118 additions and 63 deletions

View file

@ -5,10 +5,11 @@ import java.io.EOFException
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.security.GeneralSecurityException
internal interface HeaderReader { internal interface HeaderReader {
@Throws(IOException::class, UnsupportedVersionException::class) @Throws(IOException::class, UnsupportedVersionException::class)
fun readVersion(inputStream: InputStream): Byte fun readVersion(inputStream: InputStream, expectedVersion: Byte): Byte
@Deprecated("") @Deprecated("")
@Throws(SecurityException::class) @Throws(SecurityException::class)
@ -21,11 +22,14 @@ internal interface HeaderReader {
internal class HeaderReaderImpl : HeaderReader { internal class HeaderReaderImpl : HeaderReader {
@Throws(IOException::class, UnsupportedVersionException::class) @Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
override fun readVersion(inputStream: InputStream): Byte { override fun readVersion(inputStream: InputStream, expectedVersion: Byte): Byte {
val version = inputStream.read().toByte() val version = inputStream.read().toByte()
if (version < 0) throw IOException() if (version < 0) throw IOException()
if (version > VERSION) throw UnsupportedVersionException(version) if (version > VERSION) throw UnsupportedVersionException(version)
if (expectedVersion != version) throw GeneralSecurityException(
"Expected version ${expectedVersion.toInt()}, but got ${version.toInt()}"
)
return version return version
} }

View file

@ -20,10 +20,10 @@ import java.io.OutputStream
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
private class FullRestoreState( private class FullRestoreState(
val version: Byte,
val token: Long, val token: Long,
val packageInfo: PackageInfo val packageInfo: PackageInfo
) { ) {
var version: Byte? = null
var inputStream: InputStream? = null var inputStream: InputStream? = null
} }
@ -55,8 +55,8 @@ internal class FullRestore(
* It is possible that the system decides to not restore the package. * It is possible that the system decides to not restore the package.
* Then a new state will be initialized right away without calling other methods. * Then a new state will be initialized right away without calling other methods.
*/ */
fun initializeState(token: Long, packageInfo: PackageInfo) { fun initializeState(version: Byte, token: Long, packageInfo: PackageInfo) {
state = FullRestoreState(token, packageInfo) state = FullRestoreState(version, token, packageInfo)
} }
/** /**
@ -94,8 +94,7 @@ internal class FullRestore(
Log.i(TAG, "First Chunk, initializing package input stream.") Log.i(TAG, "First Chunk, initializing package input stream.")
try { try {
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo) val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
val version = headerReader.readVersion(inputStream) val version = headerReader.readVersion(inputStream, state.version)
state.version = version
if (version == 0.toByte()) { if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName) crypto.decryptHeader(inputStream, version, packageName)
state.inputStream = inputStream state.inputStream = inputStream
@ -132,9 +131,8 @@ internal class FullRestore(
private fun copyInputStream(outputStream: OutputStream): Int { private fun copyInputStream(outputStream: OutputStream): Int {
val state = this.state ?: throw IllegalStateException("no state") val state = this.state ?: throw IllegalStateException("no state")
val inputStream = state.inputStream ?: throw IllegalStateException("no stream") val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
val version = state.version ?: throw IllegalStateException("no version")
if (version == 0.toByte()) { if (state.version == 0.toByte()) {
// read segment from input stream and decrypt it // read segment from input stream and decrypt it
val decrypted = try { val decrypted = try {
crypto.decryptSegment(inputStream) crypto.decryptSegment(inputStream)

View file

@ -17,10 +17,12 @@ import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.header.getADForKV
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
import java.io.IOException import java.io.IOException
import java.security.GeneralSecurityException
import java.util.ArrayList import java.util.ArrayList
import javax.crypto.AEADBadTagException import javax.crypto.AEADBadTagException
private class KVRestoreState( private class KVRestoreState(
val version: Byte,
val token: Long, val token: Long,
val packageInfo: PackageInfo, val packageInfo: PackageInfo,
/** /**
@ -57,8 +59,13 @@ internal class KVRestore(
* *
* @param pmPackageInfo single optional [PackageInfo] to optimize restore of @pm@ * @param pmPackageInfo single optional [PackageInfo] to optimize restore of @pm@
*/ */
fun initializeState(token: Long, packageInfo: PackageInfo, pmPackageInfo: PackageInfo? = null) { fun initializeState(
state = KVRestoreState(token, packageInfo, pmPackageInfo) version: Byte,
token: Long,
packageInfo: PackageInfo,
pmPackageInfo: PackageInfo? = null
) {
state = KVRestoreState(version, token, packageInfo, pmPackageInfo)
} }
/** /**
@ -98,6 +105,9 @@ internal class KVRestore(
} catch (e: SecurityException) { } catch (e: SecurityException) {
Log.e(TAG, "Security exception while reading backup records", e) Log.e(TAG, "Security exception while reading backup records", e)
TRANSPORT_ERROR TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "General security exception while reading backup records", e)
TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) { } catch (e: UnsupportedVersionException) {
Log.e(TAG, "Unsupported version in backup: ${e.version}", e) Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
TRANSPORT_ERROR TRANSPORT_ERROR
@ -140,14 +150,14 @@ internal class KVRestore(
/** /**
* Read the encrypted value for the given key and write it to the given [BackupDataOutput]. * Read the encrypted value for the given key and write it to the given [BackupDataOutput].
*/ */
@Throws(IOException::class, UnsupportedVersionException::class, SecurityException::class) @Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
private suspend fun readAndWriteValue( private suspend fun readAndWriteValue(
state: KVRestoreState, state: KVRestoreState,
dKey: DecodedKey, dKey: DecodedKey,
out: BackupDataOutput out: BackupDataOutput
) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key) ) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
.use { inputStream -> .use { inputStream ->
val version = headerReader.readVersion(inputStream) val version = headerReader.readVersion(inputStream, state.version)
val packageName = state.packageInfo.packageName val packageName = state.packageInfo.packageName
val value = if (version == 0.toByte()) { val value = if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName, dKey.key) crypto.decryptHeader(inputStream, version, packageName, dKey.key)

View file

@ -200,6 +200,7 @@ internal class RestoreCoordinator(
val state = this.state ?: throw IllegalStateException("no state") val state = this.state ?: throw IllegalStateException("no state")
if (!state.packages.hasNext()) return NO_MORE_PACKAGES if (!state.packages.hasNext()) return NO_MORE_PACKAGES
val version = state.backupMetadata.version
val packageInfo = state.packages.next() val packageInfo = state.packages.next()
val packageName = packageInfo.packageName val packageName = packageInfo.packageName
@ -208,13 +209,13 @@ internal class RestoreCoordinator(
// check key/value data first and if available, don't even check for full data // check key/value data first and if available, don't even check for full data
kv.hasDataForPackage(state.token, packageInfo) -> { kv.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found K/V data for $packageName.") Log.i(TAG, "Found K/V data for $packageName.")
kv.initializeState(state.token, packageInfo, state.pmPackageInfo) kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_KEY_VALUE TYPE_KEY_VALUE
} }
full.hasDataForPackage(state.token, packageInfo) -> { full.hasDataForPackage(state.token, packageInfo) -> {
Log.i(TAG, "Found full backup data for $packageName.") Log.i(TAG, "Found full backup data for $packageName.")
full.initializeState(state.token, packageInfo) full.initializeState(version, state.token, packageInfo)
state.currentPackage = packageName state.currentPackage = packageName
TYPE_FULL_STREAM TYPE_FULL_STREAM
} }

View file

@ -26,7 +26,7 @@ internal class HeaderReaderTest {
val input = byteArrayOf(VERSION) val input = byteArrayOf(VERSION)
val inputStream = ByteArrayInputStream(input) val inputStream = ByteArrayInputStream(input)
assertEquals(VERSION, reader.readVersion(inputStream)) assertEquals(VERSION, reader.readVersion(inputStream, VERSION))
} }
@Test @Test
@ -34,7 +34,7 @@ internal class HeaderReaderTest {
val input = ByteArray(0) val input = ByteArray(0)
val inputStream = ByteArrayInputStream(input) val inputStream = ByteArrayInputStream(input)
assertThrows(IOException::class.javaObjectType) { assertThrows(IOException::class.javaObjectType) {
reader.readVersion(inputStream) reader.readVersion(inputStream, VERSION)
} }
} }
@ -43,7 +43,7 @@ internal class HeaderReaderTest {
val input = byteArrayOf((VERSION + 1).toByte()) val input = byteArrayOf((VERSION + 1).toByte())
val inputStream = ByteArrayInputStream(input) val inputStream = ByteArrayInputStream(input)
assertThrows(UnsupportedVersionException::class.javaObjectType) { assertThrows(UnsupportedVersionException::class.javaObjectType) {
reader.readVersion(inputStream) reader.readVersion(inputStream, VERSION)
} }
} }
@ -52,7 +52,7 @@ internal class HeaderReaderTest {
val input = byteArrayOf((-1).toByte()) val input = byteArrayOf((-1).toByte())
val inputStream = ByteArrayInputStream(input) val inputStream = ByteArrayInputStream(input)
assertThrows(IOException::class.javaObjectType) { assertThrows(IOException::class.javaObjectType) {
reader.readVersion(inputStream) reader.readVersion(inputStream, VERSION)
} }
} }
@ -61,7 +61,16 @@ internal class HeaderReaderTest {
val input = byteArrayOf(Byte.MAX_VALUE) val input = byteArrayOf(Byte.MAX_VALUE)
val inputStream = ByteArrayInputStream(input) val inputStream = ByteArrayInputStream(input)
assertThrows(UnsupportedVersionException::class.javaObjectType) { assertThrows(UnsupportedVersionException::class.javaObjectType) {
reader.readVersion(inputStream) reader.readVersion(inputStream, Byte.MAX_VALUE)
}
}
@Test
fun `unexpected version throws exception`() {
val input = byteArrayOf(VERSION + 1)
val inputStream = ByteArrayInputStream(input)
assertThrows(UnsupportedVersionException::class.javaObjectType) {
reader.readVersion(inputStream, VERSION)
} }
} }

View file

@ -54,7 +54,7 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `initializing state leaves a state`() { fun `initializing state leaves a state`() {
assertFalse(restore.hasState()) assertFalse(restore.hasState())
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
assertTrue(restore.hasState()) assertTrue(restore.hasState())
} }
@ -68,7 +68,7 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `getting InputStream for package when getting first chunk throws`() = runBlocking { fun `getting InputStream for package when getting first chunk throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException() coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
@ -81,10 +81,10 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `reading version header when getting first chunk throws`() = runBlocking { fun `reading version header when getting first chunk throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } throws IOException() every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
assertEquals( assertEquals(
@ -95,11 +95,11 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `reading unsupported version when getting first chunk`() = runBlocking { fun `reading unsupported version when getting first chunk`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { every {
headerReader.readVersion(inputStream) headerReader.readVersion(inputStream, VERSION)
} throws UnsupportedVersionException(unsupportedVersion) } throws UnsupportedVersionException(unsupportedVersion)
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
@ -111,10 +111,10 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `getting decrypted stream when getting first chunk throws`() = runBlocking { fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
@ -127,10 +127,10 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `getting decrypted stream when getting first chunk throws general security exception`() = fun `getting decrypted stream when getting first chunk throws general security exception`() =
runBlocking { runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException() every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
@ -139,7 +139,7 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `full chunk gets decrypted`() = runBlocking { fun `full chunk gets decrypted`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
initInputStream() initInputStream()
readAndEncryptInputStream(encrypted) readAndEncryptInputStream(encrypted)
@ -153,10 +153,10 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `full chunk gets decrypted from version 0`() = runBlocking { fun `full chunk gets decrypted from version 0`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(0.toByte(), token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns 0.toByte() every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
every { every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName) crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
} returns VersionHeader(0.toByte(), packageInfo.packageName) } returns VersionHeader(0.toByte(), packageInfo.packageName)
@ -172,14 +172,30 @@ internal class FullRestoreTest : RestoreTest() {
assertFalse(restore.hasState()) assertFalse(restore.hasState())
} }
@Test
fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every {
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
} throws GeneralSecurityException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
restore.abortFullRestore()
assertFalse(restore.hasState())
}
@Test @Test
fun `three full chunk get decrypted and then return no more data`() = runBlocking { fun `three full chunk get decrypted and then return no more data`() = runBlocking {
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1) val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
val decryptedInputStream = ByteArrayInputStream(encryptedBytes) val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
@ -197,7 +213,7 @@ internal class FullRestoreTest : RestoreTest() {
@Test @Test
fun `aborting full restore closes stream, resets state`() = runBlocking { fun `aborting full restore closes stream, resets state`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
initInputStream() initInputStream()
readAndEncryptInputStream(encrypted) readAndEncryptInputStream(encrypted)
@ -212,7 +228,7 @@ internal class FullRestoreTest : RestoreTest() {
private fun initInputStream() { private fun initInputStream() {
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
} }

View file

@ -22,6 +22,7 @@ import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.security.GeneralSecurityException
import kotlin.random.Random import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
@ -60,7 +61,7 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `listing records throws`() = runBlocking { fun `listing records throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
coEvery { plugin.listRecords(token, packageInfo) } throws IOException() coEvery { plugin.listRecords(token, packageInfo) } throws IOException()
@ -69,12 +70,12 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `reading VersionHeader with unsupported version throws`() = runBlocking { fun `reading VersionHeader with unsupported version throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { every {
headerReader.readVersion(inputStream) headerReader.readVersion(inputStream, VERSION)
} throws UnsupportedVersionException(unsupportedVersion) } throws UnsupportedVersionException(unsupportedVersion)
streamsGetClosed() streamsGetClosed()
@ -84,11 +85,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `error reading VersionHeader throws`() = runBlocking { fun `error reading VersionHeader throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } throws IOException() every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
streamsGetClosed() streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
@ -97,11 +98,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `decrypting stream throws`() = runBlocking { fun `decrypting stream throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException() every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
streamsGetClosed() streamsGetClosed()
@ -111,11 +112,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `decrypting stream throws security exception`() = runBlocking { fun `decrypting stream throws security exception`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException() every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException()
streamsGetClosed() streamsGetClosed()
@ -125,11 +126,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `writing header throws`() = runBlocking { fun `writing header throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data every { decryptedInputStream.readBytes() } returns data
every { output.writeEntityHeader(key, data.size) } throws IOException() every { output.writeEntityHeader(key, data.size) } throws IOException()
@ -141,11 +142,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `writing value throws`() = runBlocking { fun `writing value throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data every { decryptedInputStream.readBytes() } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
@ -158,11 +159,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `writing value succeeds`() = runBlocking { fun `writing value succeeds`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data every { decryptedInputStream.readBytes() } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
@ -175,11 +176,11 @@ internal class KVRestoreTest : RestoreTest() {
@Test @Test
fun `writing value uses old v0 code`() = runBlocking { fun `writing value uses old v0 code`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(0.toByte(), token, packageInfo)
getRecordsAndOutput() getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns 0.toByte() every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
every { every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key) crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key)
} returns VersionHeader(VERSION, packageInfo.packageName, key) } returns VersionHeader(VERSION, packageInfo.packageName, key)
@ -192,24 +193,39 @@ internal class KVRestoreTest : RestoreTest() {
verifyStreamWasClosed() verifyStreamWasClosed()
} }
@Test
fun `unexpected version aborts with error`() = runBlocking {
restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
getRecordsAndOutput()
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every {
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
} throws GeneralSecurityException()
streamsGetClosed()
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
verifyStreamWasClosed()
}
@Test @Test
fun `writing two values succeeds`() = runBlocking { fun `writing two values succeeds`() = runBlocking {
val data2 = getRandomByteArray() val data2 = getRandomByteArray()
val inputStream2 = mockk<InputStream>() val inputStream2 = mockk<InputStream>()
val decryptedInputStream2 = mockk<InputStream>() val decryptedInputStream2 = mockk<InputStream>()
restore.initializeState(token, packageInfo) restore.initializeState(VERSION, token, packageInfo)
getRecordsAndOutput(listOf(key64, key264)) getRecordsAndOutput(listOf(key64, key264))
// first key/value // first key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { decryptedInputStream.readBytes() } returns data every { decryptedInputStream.readBytes() } returns data
every { output.writeEntityHeader(key, data.size) } returns 42 every { output.writeEntityHeader(key, data.size) } returns 42
every { output.writeEntityData(data, data.size) } returns data.size every { output.writeEntityData(data, data.size) } returns data.size
// second key/value // second key/value
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2 coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2
every { headerReader.readVersion(inputStream2) } returns VERSION every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION
every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2 every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2
every { decryptedInputStream2.readBytes() } returns data2 every { decryptedInputStream2.readBytes() } returns data2
every { output.writeEntityHeader(key2, data2.size) } returns 42 every { output.writeEntityHeader(key2, data2.size) } returns 42

View file

@ -10,6 +10,7 @@ import android.content.pm.PackageInfo
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import com.stevesoltys.seedvault.coAssertThrows import com.stevesoltys.seedvault.coAssertThrows
import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.metadata.EncryptedBackupMetadata import com.stevesoltys.seedvault.metadata.EncryptedBackupMetadata
import com.stevesoltys.seedvault.metadata.MetadataReader import com.stevesoltys.seedvault.metadata.MetadataReader
import com.stevesoltys.seedvault.metadata.PackageMetadata import com.stevesoltys.seedvault.metadata.PackageMetadata
@ -213,7 +214,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.startRestore(token, packageInfoArray) restore.startRestore(token, packageInfoArray)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(token, packageInfo) } just Runs every { kv.initializeState(VERSION, token, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -226,7 +227,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
coEvery { kv.hasDataForPackage(token, packageInfo) } returns false coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
coEvery { full.hasDataForPackage(token, packageInfo) } returns true coEvery { full.hasDataForPackage(token, packageInfo) } returns true
every { full.initializeState(token, packageInfo) } just Runs every { full.initializeState(VERSION, token, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM) val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
@ -249,14 +250,14 @@ internal class RestoreCoordinatorTest : TransportTest() {
restore.startRestore(token, packageInfoArray2) restore.startRestore(token, packageInfoArray2)
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
every { kv.initializeState(token, packageInfo) } just Runs every { kv.initializeState(VERSION, token, packageInfo) } just Runs
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE) val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
assertEquals(expected, restore.nextRestorePackage()) assertEquals(expected, restore.nextRestorePackage())
coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false
coEvery { full.hasDataForPackage(token, packageInfo2) } returns true coEvery { full.hasDataForPackage(token, packageInfo2) } returns true
every { full.initializeState(token, packageInfo2) } just Runs every { full.initializeState(VERSION, token, packageInfo2) } just Runs
val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM) val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
assertEquals(expected2, restore.nextRestorePackage()) assertEquals(expected2, restore.nextRestorePackage())

View file

@ -61,7 +61,7 @@ internal class RestoreV0IntegrationTest : TransportTest() {
kvRestore, kvRestore,
fullRestore, fullRestore,
metadataReader metadataReader
).apply { beforeStartRestore(metadata) } ).apply { beforeStartRestore(metadata.copy(version = 0x00)) }
private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true) private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true)
private val appData = ("562AB665C3543120FC794D7CDA3AC18E5959235A4D" + private val appData = ("562AB665C3543120FC794D7CDA3AC18E5959235A4D" +