Check version of backup files against expected version from metadata
and throw security exception if it does not match
This commit is contained in:
parent
5523e57fe7
commit
2932af463c
9 changed files with 118 additions and 63 deletions
|
@ -5,10 +5,11 @@ import java.io.EOFException
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
|
import java.security.GeneralSecurityException
|
||||||
|
|
||||||
internal interface HeaderReader {
|
internal interface HeaderReader {
|
||||||
@Throws(IOException::class, UnsupportedVersionException::class)
|
@Throws(IOException::class, UnsupportedVersionException::class)
|
||||||
fun readVersion(inputStream: InputStream): Byte
|
fun readVersion(inputStream: InputStream, expectedVersion: Byte): Byte
|
||||||
|
|
||||||
@Deprecated("")
|
@Deprecated("")
|
||||||
@Throws(SecurityException::class)
|
@Throws(SecurityException::class)
|
||||||
|
@ -21,11 +22,14 @@ internal interface HeaderReader {
|
||||||
|
|
||||||
internal class HeaderReaderImpl : HeaderReader {
|
internal class HeaderReaderImpl : HeaderReader {
|
||||||
|
|
||||||
@Throws(IOException::class, UnsupportedVersionException::class)
|
@Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
|
||||||
override fun readVersion(inputStream: InputStream): Byte {
|
override fun readVersion(inputStream: InputStream, expectedVersion: Byte): Byte {
|
||||||
val version = inputStream.read().toByte()
|
val version = inputStream.read().toByte()
|
||||||
if (version < 0) throw IOException()
|
if (version < 0) throw IOException()
|
||||||
if (version > VERSION) throw UnsupportedVersionException(version)
|
if (version > VERSION) throw UnsupportedVersionException(version)
|
||||||
|
if (expectedVersion != version) throw GeneralSecurityException(
|
||||||
|
"Expected version ${expectedVersion.toInt()}, but got ${version.toInt()}"
|
||||||
|
)
|
||||||
return version
|
return version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,10 @@ import java.io.OutputStream
|
||||||
import java.security.GeneralSecurityException
|
import java.security.GeneralSecurityException
|
||||||
|
|
||||||
private class FullRestoreState(
|
private class FullRestoreState(
|
||||||
|
val version: Byte,
|
||||||
val token: Long,
|
val token: Long,
|
||||||
val packageInfo: PackageInfo
|
val packageInfo: PackageInfo
|
||||||
) {
|
) {
|
||||||
var version: Byte? = null
|
|
||||||
var inputStream: InputStream? = null
|
var inputStream: InputStream? = null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,8 +55,8 @@ internal class FullRestore(
|
||||||
* It is possible that the system decides to not restore the package.
|
* It is possible that the system decides to not restore the package.
|
||||||
* Then a new state will be initialized right away without calling other methods.
|
* Then a new state will be initialized right away without calling other methods.
|
||||||
*/
|
*/
|
||||||
fun initializeState(token: Long, packageInfo: PackageInfo) {
|
fun initializeState(version: Byte, token: Long, packageInfo: PackageInfo) {
|
||||||
state = FullRestoreState(token, packageInfo)
|
state = FullRestoreState(version, token, packageInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -94,8 +94,7 @@ internal class FullRestore(
|
||||||
Log.i(TAG, "First Chunk, initializing package input stream.")
|
Log.i(TAG, "First Chunk, initializing package input stream.")
|
||||||
try {
|
try {
|
||||||
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
|
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
|
||||||
val version = headerReader.readVersion(inputStream)
|
val version = headerReader.readVersion(inputStream, state.version)
|
||||||
state.version = version
|
|
||||||
if (version == 0.toByte()) {
|
if (version == 0.toByte()) {
|
||||||
crypto.decryptHeader(inputStream, version, packageName)
|
crypto.decryptHeader(inputStream, version, packageName)
|
||||||
state.inputStream = inputStream
|
state.inputStream = inputStream
|
||||||
|
@ -132,9 +131,8 @@ internal class FullRestore(
|
||||||
private fun copyInputStream(outputStream: OutputStream): Int {
|
private fun copyInputStream(outputStream: OutputStream): Int {
|
||||||
val state = this.state ?: throw IllegalStateException("no state")
|
val state = this.state ?: throw IllegalStateException("no state")
|
||||||
val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
|
val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
|
||||||
val version = state.version ?: throw IllegalStateException("no version")
|
|
||||||
|
|
||||||
if (version == 0.toByte()) {
|
if (state.version == 0.toByte()) {
|
||||||
// read segment from input stream and decrypt it
|
// read segment from input stream and decrypt it
|
||||||
val decrypted = try {
|
val decrypted = try {
|
||||||
crypto.decryptSegment(inputStream)
|
crypto.decryptSegment(inputStream)
|
||||||
|
|
|
@ -17,10 +17,12 @@ import com.stevesoltys.seedvault.header.VERSION
|
||||||
import com.stevesoltys.seedvault.header.getADForKV
|
import com.stevesoltys.seedvault.header.getADForKV
|
||||||
import libcore.io.IoUtils.closeQuietly
|
import libcore.io.IoUtils.closeQuietly
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
|
import java.security.GeneralSecurityException
|
||||||
import java.util.ArrayList
|
import java.util.ArrayList
|
||||||
import javax.crypto.AEADBadTagException
|
import javax.crypto.AEADBadTagException
|
||||||
|
|
||||||
private class KVRestoreState(
|
private class KVRestoreState(
|
||||||
|
val version: Byte,
|
||||||
val token: Long,
|
val token: Long,
|
||||||
val packageInfo: PackageInfo,
|
val packageInfo: PackageInfo,
|
||||||
/**
|
/**
|
||||||
|
@ -57,8 +59,13 @@ internal class KVRestore(
|
||||||
*
|
*
|
||||||
* @param pmPackageInfo single optional [PackageInfo] to optimize restore of @pm@
|
* @param pmPackageInfo single optional [PackageInfo] to optimize restore of @pm@
|
||||||
*/
|
*/
|
||||||
fun initializeState(token: Long, packageInfo: PackageInfo, pmPackageInfo: PackageInfo? = null) {
|
fun initializeState(
|
||||||
state = KVRestoreState(token, packageInfo, pmPackageInfo)
|
version: Byte,
|
||||||
|
token: Long,
|
||||||
|
packageInfo: PackageInfo,
|
||||||
|
pmPackageInfo: PackageInfo? = null
|
||||||
|
) {
|
||||||
|
state = KVRestoreState(version, token, packageInfo, pmPackageInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -98,6 +105,9 @@ internal class KVRestore(
|
||||||
} catch (e: SecurityException) {
|
} catch (e: SecurityException) {
|
||||||
Log.e(TAG, "Security exception while reading backup records", e)
|
Log.e(TAG, "Security exception while reading backup records", e)
|
||||||
TRANSPORT_ERROR
|
TRANSPORT_ERROR
|
||||||
|
} catch (e: GeneralSecurityException) {
|
||||||
|
Log.e(TAG, "General security exception while reading backup records", e)
|
||||||
|
TRANSPORT_ERROR
|
||||||
} catch (e: UnsupportedVersionException) {
|
} catch (e: UnsupportedVersionException) {
|
||||||
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
|
Log.e(TAG, "Unsupported version in backup: ${e.version}", e)
|
||||||
TRANSPORT_ERROR
|
TRANSPORT_ERROR
|
||||||
|
@ -140,14 +150,14 @@ internal class KVRestore(
|
||||||
/**
|
/**
|
||||||
* Read the encrypted value for the given key and write it to the given [BackupDataOutput].
|
* Read the encrypted value for the given key and write it to the given [BackupDataOutput].
|
||||||
*/
|
*/
|
||||||
@Throws(IOException::class, UnsupportedVersionException::class, SecurityException::class)
|
@Throws(IOException::class, UnsupportedVersionException::class, GeneralSecurityException::class)
|
||||||
private suspend fun readAndWriteValue(
|
private suspend fun readAndWriteValue(
|
||||||
state: KVRestoreState,
|
state: KVRestoreState,
|
||||||
dKey: DecodedKey,
|
dKey: DecodedKey,
|
||||||
out: BackupDataOutput
|
out: BackupDataOutput
|
||||||
) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
|
) = plugin.getInputStreamForRecord(state.token, state.packageInfo, dKey.base64Key)
|
||||||
.use { inputStream ->
|
.use { inputStream ->
|
||||||
val version = headerReader.readVersion(inputStream)
|
val version = headerReader.readVersion(inputStream, state.version)
|
||||||
val packageName = state.packageInfo.packageName
|
val packageName = state.packageInfo.packageName
|
||||||
val value = if (version == 0.toByte()) {
|
val value = if (version == 0.toByte()) {
|
||||||
crypto.decryptHeader(inputStream, version, packageName, dKey.key)
|
crypto.decryptHeader(inputStream, version, packageName, dKey.key)
|
||||||
|
|
|
@ -200,6 +200,7 @@ internal class RestoreCoordinator(
|
||||||
val state = this.state ?: throw IllegalStateException("no state")
|
val state = this.state ?: throw IllegalStateException("no state")
|
||||||
|
|
||||||
if (!state.packages.hasNext()) return NO_MORE_PACKAGES
|
if (!state.packages.hasNext()) return NO_MORE_PACKAGES
|
||||||
|
val version = state.backupMetadata.version
|
||||||
val packageInfo = state.packages.next()
|
val packageInfo = state.packages.next()
|
||||||
val packageName = packageInfo.packageName
|
val packageName = packageInfo.packageName
|
||||||
|
|
||||||
|
@ -208,13 +209,13 @@ internal class RestoreCoordinator(
|
||||||
// check key/value data first and if available, don't even check for full data
|
// check key/value data first and if available, don't even check for full data
|
||||||
kv.hasDataForPackage(state.token, packageInfo) -> {
|
kv.hasDataForPackage(state.token, packageInfo) -> {
|
||||||
Log.i(TAG, "Found K/V data for $packageName.")
|
Log.i(TAG, "Found K/V data for $packageName.")
|
||||||
kv.initializeState(state.token, packageInfo, state.pmPackageInfo)
|
kv.initializeState(version, state.token, packageInfo, state.pmPackageInfo)
|
||||||
state.currentPackage = packageName
|
state.currentPackage = packageName
|
||||||
TYPE_KEY_VALUE
|
TYPE_KEY_VALUE
|
||||||
}
|
}
|
||||||
full.hasDataForPackage(state.token, packageInfo) -> {
|
full.hasDataForPackage(state.token, packageInfo) -> {
|
||||||
Log.i(TAG, "Found full backup data for $packageName.")
|
Log.i(TAG, "Found full backup data for $packageName.")
|
||||||
full.initializeState(state.token, packageInfo)
|
full.initializeState(version, state.token, packageInfo)
|
||||||
state.currentPackage = packageName
|
state.currentPackage = packageName
|
||||||
TYPE_FULL_STREAM
|
TYPE_FULL_STREAM
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ internal class HeaderReaderTest {
|
||||||
val input = byteArrayOf(VERSION)
|
val input = byteArrayOf(VERSION)
|
||||||
val inputStream = ByteArrayInputStream(input)
|
val inputStream = ByteArrayInputStream(input)
|
||||||
|
|
||||||
assertEquals(VERSION, reader.readVersion(inputStream))
|
assertEquals(VERSION, reader.readVersion(inputStream, VERSION))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -34,7 +34,7 @@ internal class HeaderReaderTest {
|
||||||
val input = ByteArray(0)
|
val input = ByteArray(0)
|
||||||
val inputStream = ByteArrayInputStream(input)
|
val inputStream = ByteArrayInputStream(input)
|
||||||
assertThrows(IOException::class.javaObjectType) {
|
assertThrows(IOException::class.javaObjectType) {
|
||||||
reader.readVersion(inputStream)
|
reader.readVersion(inputStream, VERSION)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ internal class HeaderReaderTest {
|
||||||
val input = byteArrayOf((VERSION + 1).toByte())
|
val input = byteArrayOf((VERSION + 1).toByte())
|
||||||
val inputStream = ByteArrayInputStream(input)
|
val inputStream = ByteArrayInputStream(input)
|
||||||
assertThrows(UnsupportedVersionException::class.javaObjectType) {
|
assertThrows(UnsupportedVersionException::class.javaObjectType) {
|
||||||
reader.readVersion(inputStream)
|
reader.readVersion(inputStream, VERSION)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ internal class HeaderReaderTest {
|
||||||
val input = byteArrayOf((-1).toByte())
|
val input = byteArrayOf((-1).toByte())
|
||||||
val inputStream = ByteArrayInputStream(input)
|
val inputStream = ByteArrayInputStream(input)
|
||||||
assertThrows(IOException::class.javaObjectType) {
|
assertThrows(IOException::class.javaObjectType) {
|
||||||
reader.readVersion(inputStream)
|
reader.readVersion(inputStream, VERSION)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +61,16 @@ internal class HeaderReaderTest {
|
||||||
val input = byteArrayOf(Byte.MAX_VALUE)
|
val input = byteArrayOf(Byte.MAX_VALUE)
|
||||||
val inputStream = ByteArrayInputStream(input)
|
val inputStream = ByteArrayInputStream(input)
|
||||||
assertThrows(UnsupportedVersionException::class.javaObjectType) {
|
assertThrows(UnsupportedVersionException::class.javaObjectType) {
|
||||||
reader.readVersion(inputStream)
|
reader.readVersion(inputStream, Byte.MAX_VALUE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `unexpected version throws exception`() {
|
||||||
|
val input = byteArrayOf(VERSION + 1)
|
||||||
|
val inputStream = ByteArrayInputStream(input)
|
||||||
|
assertThrows(UnsupportedVersionException::class.javaObjectType) {
|
||||||
|
reader.readVersion(inputStream, VERSION)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
@Test
|
@Test
|
||||||
fun `initializing state leaves a state`() {
|
fun `initializing state leaves a state`() {
|
||||||
assertFalse(restore.hasState())
|
assertFalse(restore.hasState())
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
assertTrue(restore.hasState())
|
assertTrue(restore.hasState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `getting InputStream for package when getting first chunk throws`() = runBlocking {
|
fun `getting InputStream for package when getting first chunk throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
@ -81,10 +81,10 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `reading version header when getting first chunk throws`() = runBlocking {
|
fun `reading version header when getting first chunk throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } throws IOException()
|
every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
|
@ -95,11 +95,11 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `reading unsupported version when getting first chunk`() = runBlocking {
|
fun `reading unsupported version when getting first chunk`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every {
|
every {
|
||||||
headerReader.readVersion(inputStream)
|
headerReader.readVersion(inputStream, VERSION)
|
||||||
} throws UnsupportedVersionException(unsupportedVersion)
|
} throws UnsupportedVersionException(unsupportedVersion)
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
|
@ -111,10 +111,10 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
|
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
|
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
|
@ -127,10 +127,10 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
@Test
|
@Test
|
||||||
fun `getting decrypted stream when getting first chunk throws general security exception`() =
|
fun `getting decrypted stream when getting first chunk throws general security exception`() =
|
||||||
runBlocking {
|
runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
|
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `full chunk gets decrypted`() = runBlocking {
|
fun `full chunk gets decrypted`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
initInputStream()
|
initInputStream()
|
||||||
readAndEncryptInputStream(encrypted)
|
readAndEncryptInputStream(encrypted)
|
||||||
|
@ -153,10 +153,10 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `full chunk gets decrypted from version 0`() = runBlocking {
|
fun `full chunk gets decrypted from version 0`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(0.toByte(), token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns 0.toByte()
|
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
|
||||||
every {
|
every {
|
||||||
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
|
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
|
||||||
} returns VersionHeader(0.toByte(), packageInfo.packageName)
|
} returns VersionHeader(0.toByte(), packageInfo.packageName)
|
||||||
|
@ -172,14 +172,30 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
assertFalse(restore.hasState())
|
assertFalse(restore.hasState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `unexpected version aborts with error`() = runBlocking {
|
||||||
|
restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
|
||||||
|
|
||||||
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
|
every {
|
||||||
|
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
|
||||||
|
} throws GeneralSecurityException()
|
||||||
|
every { inputStream.close() } just Runs
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
|
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
restore.abortFullRestore()
|
||||||
|
assertFalse(restore.hasState())
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
|
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
|
||||||
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
|
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
|
||||||
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
|
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
|
@ -197,7 +213,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `aborting full restore closes stream, resets state`() = runBlocking {
|
fun `aborting full restore closes stream, resets state`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
initInputStream()
|
initInputStream()
|
||||||
readAndEncryptInputStream(encrypted)
|
readAndEncryptInputStream(encrypted)
|
||||||
|
@ -212,7 +228,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
private fun initInputStream() {
|
private fun initInputStream() {
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.junit.jupiter.api.Assertions.assertEquals
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
import java.security.GeneralSecurityException
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@Suppress("BlockingMethodInNonBlockingContext")
|
@Suppress("BlockingMethodInNonBlockingContext")
|
||||||
|
@ -60,7 +61,7 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `listing records throws`() = runBlocking {
|
fun `listing records throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.listRecords(token, packageInfo) } throws IOException()
|
coEvery { plugin.listRecords(token, packageInfo) } throws IOException()
|
||||||
|
|
||||||
|
@ -69,12 +70,12 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `reading VersionHeader with unsupported version throws`() = runBlocking {
|
fun `reading VersionHeader with unsupported version throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every {
|
every {
|
||||||
headerReader.readVersion(inputStream)
|
headerReader.readVersion(inputStream, VERSION)
|
||||||
} throws UnsupportedVersionException(unsupportedVersion)
|
} throws UnsupportedVersionException(unsupportedVersion)
|
||||||
streamsGetClosed()
|
streamsGetClosed()
|
||||||
|
|
||||||
|
@ -84,11 +85,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `error reading VersionHeader throws`() = runBlocking {
|
fun `error reading VersionHeader throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } throws IOException()
|
every { headerReader.readVersion(inputStream, VERSION) } throws IOException()
|
||||||
streamsGetClosed()
|
streamsGetClosed()
|
||||||
|
|
||||||
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
|
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
|
||||||
|
@ -97,11 +98,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `decrypting stream throws`() = runBlocking {
|
fun `decrypting stream throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
|
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
|
||||||
streamsGetClosed()
|
streamsGetClosed()
|
||||||
|
|
||||||
|
@ -111,11 +112,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `decrypting stream throws security exception`() = runBlocking {
|
fun `decrypting stream throws security exception`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException()
|
every { crypto.newDecryptingStream(inputStream, ad) } throws SecurityException()
|
||||||
streamsGetClosed()
|
streamsGetClosed()
|
||||||
|
|
||||||
|
@ -125,11 +126,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `writing header throws`() = runBlocking {
|
fun `writing header throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
every { decryptedInputStream.readBytes() } returns data
|
every { decryptedInputStream.readBytes() } returns data
|
||||||
every { output.writeEntityHeader(key, data.size) } throws IOException()
|
every { output.writeEntityHeader(key, data.size) } throws IOException()
|
||||||
|
@ -141,11 +142,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `writing value throws`() = runBlocking {
|
fun `writing value throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
every { decryptedInputStream.readBytes() } returns data
|
every { decryptedInputStream.readBytes() } returns data
|
||||||
every { output.writeEntityHeader(key, data.size) } returns 42
|
every { output.writeEntityHeader(key, data.size) } returns 42
|
||||||
|
@ -158,11 +159,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `writing value succeeds`() = runBlocking {
|
fun `writing value succeeds`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
every { decryptedInputStream.readBytes() } returns data
|
every { decryptedInputStream.readBytes() } returns data
|
||||||
every { output.writeEntityHeader(key, data.size) } returns 42
|
every { output.writeEntityHeader(key, data.size) } returns 42
|
||||||
|
@ -175,11 +176,11 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `writing value uses old v0 code`() = runBlocking {
|
fun `writing value uses old v0 code`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(0.toByte(), token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput()
|
getRecordsAndOutput()
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns 0.toByte()
|
every { headerReader.readVersion(inputStream, 0.toByte()) } returns 0.toByte()
|
||||||
every {
|
every {
|
||||||
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key)
|
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName, key)
|
||||||
} returns VersionHeader(VERSION, packageInfo.packageName, key)
|
} returns VersionHeader(VERSION, packageInfo.packageName, key)
|
||||||
|
@ -192,24 +193,39 @@ internal class KVRestoreTest : RestoreTest() {
|
||||||
verifyStreamWasClosed()
|
verifyStreamWasClosed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `unexpected version aborts with error`() = runBlocking {
|
||||||
|
restore.initializeState(Byte.MAX_VALUE, token, packageInfo)
|
||||||
|
|
||||||
|
getRecordsAndOutput()
|
||||||
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
|
every {
|
||||||
|
headerReader.readVersion(inputStream, Byte.MAX_VALUE)
|
||||||
|
} throws GeneralSecurityException()
|
||||||
|
streamsGetClosed()
|
||||||
|
|
||||||
|
assertEquals(TRANSPORT_ERROR, restore.getRestoreData(fileDescriptor))
|
||||||
|
verifyStreamWasClosed()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `writing two values succeeds`() = runBlocking {
|
fun `writing two values succeeds`() = runBlocking {
|
||||||
val data2 = getRandomByteArray()
|
val data2 = getRandomByteArray()
|
||||||
val inputStream2 = mockk<InputStream>()
|
val inputStream2 = mockk<InputStream>()
|
||||||
val decryptedInputStream2 = mockk<InputStream>()
|
val decryptedInputStream2 = mockk<InputStream>()
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(VERSION, token, packageInfo)
|
||||||
|
|
||||||
getRecordsAndOutput(listOf(key64, key264))
|
getRecordsAndOutput(listOf(key64, key264))
|
||||||
// first key/value
|
// first key/value
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key64) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
every { decryptedInputStream.readBytes() } returns data
|
every { decryptedInputStream.readBytes() } returns data
|
||||||
every { output.writeEntityHeader(key, data.size) } returns 42
|
every { output.writeEntityHeader(key, data.size) } returns 42
|
||||||
every { output.writeEntityData(data, data.size) } returns data.size
|
every { output.writeEntityData(data, data.size) } returns data.size
|
||||||
// second key/value
|
// second key/value
|
||||||
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2
|
coEvery { plugin.getInputStreamForRecord(token, packageInfo, key264) } returns inputStream2
|
||||||
every { headerReader.readVersion(inputStream2) } returns VERSION
|
every { headerReader.readVersion(inputStream2, VERSION) } returns VERSION
|
||||||
every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2
|
every { crypto.newDecryptingStream(inputStream2, ad) } returns decryptedInputStream2
|
||||||
every { decryptedInputStream2.readBytes() } returns data2
|
every { decryptedInputStream2.readBytes() } returns data2
|
||||||
every { output.writeEntityHeader(key2, data2.size) } returns 42
|
every { output.writeEntityHeader(key2, data2.size) } returns 42
|
||||||
|
|
|
@ -10,6 +10,7 @@ import android.content.pm.PackageInfo
|
||||||
import android.os.ParcelFileDescriptor
|
import android.os.ParcelFileDescriptor
|
||||||
import com.stevesoltys.seedvault.coAssertThrows
|
import com.stevesoltys.seedvault.coAssertThrows
|
||||||
import com.stevesoltys.seedvault.getRandomString
|
import com.stevesoltys.seedvault.getRandomString
|
||||||
|
import com.stevesoltys.seedvault.header.VERSION
|
||||||
import com.stevesoltys.seedvault.metadata.EncryptedBackupMetadata
|
import com.stevesoltys.seedvault.metadata.EncryptedBackupMetadata
|
||||||
import com.stevesoltys.seedvault.metadata.MetadataReader
|
import com.stevesoltys.seedvault.metadata.MetadataReader
|
||||||
import com.stevesoltys.seedvault.metadata.PackageMetadata
|
import com.stevesoltys.seedvault.metadata.PackageMetadata
|
||||||
|
@ -213,7 +214,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
|
||||||
restore.startRestore(token, packageInfoArray)
|
restore.startRestore(token, packageInfoArray)
|
||||||
|
|
||||||
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
|
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
|
||||||
every { kv.initializeState(token, packageInfo) } just Runs
|
every { kv.initializeState(VERSION, token, packageInfo) } just Runs
|
||||||
|
|
||||||
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
|
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
|
||||||
assertEquals(expected, restore.nextRestorePackage())
|
assertEquals(expected, restore.nextRestorePackage())
|
||||||
|
@ -226,7 +227,7 @@ internal class RestoreCoordinatorTest : TransportTest() {
|
||||||
|
|
||||||
coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
|
coEvery { kv.hasDataForPackage(token, packageInfo) } returns false
|
||||||
coEvery { full.hasDataForPackage(token, packageInfo) } returns true
|
coEvery { full.hasDataForPackage(token, packageInfo) } returns true
|
||||||
every { full.initializeState(token, packageInfo) } just Runs
|
every { full.initializeState(VERSION, token, packageInfo) } just Runs
|
||||||
|
|
||||||
val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM)
|
val expected = RestoreDescription(packageInfo.packageName, TYPE_FULL_STREAM)
|
||||||
assertEquals(expected, restore.nextRestorePackage())
|
assertEquals(expected, restore.nextRestorePackage())
|
||||||
|
@ -249,14 +250,14 @@ internal class RestoreCoordinatorTest : TransportTest() {
|
||||||
restore.startRestore(token, packageInfoArray2)
|
restore.startRestore(token, packageInfoArray2)
|
||||||
|
|
||||||
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
|
coEvery { kv.hasDataForPackage(token, packageInfo) } returns true
|
||||||
every { kv.initializeState(token, packageInfo) } just Runs
|
every { kv.initializeState(VERSION, token, packageInfo) } just Runs
|
||||||
|
|
||||||
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
|
val expected = RestoreDescription(packageInfo.packageName, TYPE_KEY_VALUE)
|
||||||
assertEquals(expected, restore.nextRestorePackage())
|
assertEquals(expected, restore.nextRestorePackage())
|
||||||
|
|
||||||
coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false
|
coEvery { kv.hasDataForPackage(token, packageInfo2) } returns false
|
||||||
coEvery { full.hasDataForPackage(token, packageInfo2) } returns true
|
coEvery { full.hasDataForPackage(token, packageInfo2) } returns true
|
||||||
every { full.initializeState(token, packageInfo2) } just Runs
|
every { full.initializeState(VERSION, token, packageInfo2) } just Runs
|
||||||
|
|
||||||
val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
|
val expected2 = RestoreDescription(packageInfo2.packageName, TYPE_FULL_STREAM)
|
||||||
assertEquals(expected2, restore.nextRestorePackage())
|
assertEquals(expected2, restore.nextRestorePackage())
|
||||||
|
|
|
@ -61,7 +61,7 @@ internal class RestoreV0IntegrationTest : TransportTest() {
|
||||||
kvRestore,
|
kvRestore,
|
||||||
fullRestore,
|
fullRestore,
|
||||||
metadataReader
|
metadataReader
|
||||||
).apply { beforeStartRestore(metadata) }
|
).apply { beforeStartRestore(metadata.copy(version = 0x00)) }
|
||||||
|
|
||||||
private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true)
|
private val fileDescriptor = mockk<ParcelFileDescriptor>(relaxed = true)
|
||||||
private val appData = ("562AB665C3543120FC794D7CDA3AC18E5959235A4D" +
|
private val appData = ("562AB665C3543120FC794D7CDA3AC18E5959235A4D" +
|
||||||
|
|
Loading…
Reference in a new issue