Do full backups with new version 1 with new crypto

Restoring still supports version 0 with old crypto
This commit is contained in:
Torsten Grote 2021-09-09 15:55:13 +02:00 committed by Chirayu Desai
parent 0c3ea7679b
commit f4dc776ed3
7 changed files with 140 additions and 91 deletions

View file

@ -140,6 +140,7 @@ internal class CryptoImpl(
deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray()) deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray())
} }
@Throws(IOException::class, GeneralSecurityException::class)
override fun newEncryptingStream( override fun newEncryptingStream(
outputStream: OutputStream, outputStream: OutputStream,
associatedData: ByteArray associatedData: ByteArray
@ -147,6 +148,7 @@ internal class CryptoImpl(
return StreamCrypto.newEncryptingStream(key, outputStream, associatedData) return StreamCrypto.newEncryptingStream(key, outputStream, associatedData)
} }
@Throws(IOException::class, GeneralSecurityException::class)
override fun newDecryptingStream( override fun newDecryptingStream(
inputStream: InputStream, inputStream: InputStream,
associatedData: ByteArray associatedData: ByteArray

View file

@ -1,6 +1,7 @@
package com.stevesoltys.seedvault.header package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_FULL
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV
import java.nio.ByteBuffer import java.nio.ByteBuffer
@ -40,6 +41,15 @@ internal fun getADForKV(version: Byte, packageName: String): ByteArray {
.array() .array()
} }
internal fun getADForFull(version: Byte, packageName: String): ByteArray {
val packageNameBytes = packageName.toByteArray()
return ByteBuffer.allocate(2 + packageNameBytes.size)
.put(version)
.put(TYPE_BACKUP_FULL)
.put(packageNameBytes)
.array()
}
internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES
internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt() internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt()
internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int = internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int =

View file

@ -10,7 +10,9 @@ import android.os.ParcelFileDescriptor
import android.util.Log import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderWriter import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
import java.io.EOFException import java.io.EOFException
@ -24,6 +26,9 @@ private class FullBackupState(
val inputStream: InputStream, val inputStream: InputStream,
var outputStreamInit: (suspend () -> OutputStream)? var outputStreamInit: (suspend () -> OutputStream)?
) { ) {
/**
* This is an encrypted stream that can be written to directly.
*/
var outputStream: OutputStream? = null var outputStream: OutputStream? = null
val packageName: String = packageInfo.packageName val packageName: String = packageInfo.packageName
var size: Long = 0 var size: Long = 0
@ -120,14 +125,8 @@ internal class FullBackup(
// store version header // store version header
val state = this.state ?: throw AssertionError() val state = this.state ?: throw AssertionError()
val header = VersionHeader(packageName = state.packageName) val header = VersionHeader(packageName = state.packageName)
try { headerWriter.writeVersion(outputStream, header)
headerWriter.writeVersion(outputStream, header) crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName))
crypto.encryptHeader(outputStream, header)
} catch (e: IOException) {
Log.e(TAG, "Error writing backup header", e)
throw(e)
}
outputStream
} // this lambda is only called before we actually write backup data the first time } // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK return TRANSPORT_OK
} }
@ -159,11 +158,11 @@ internal class FullBackup(
}() }()
state.outputStreamInit = null // the stream init lambda is not needed beyond that point state.outputStreamInit = null // the stream init lambda is not needed beyond that point
// read backup data, encrypt it and write it to output stream // read backup data and write it to encrypted output stream
val payload = ByteArray(numBytes) val payload = ByteArray(numBytes)
val read = state.inputStream.read(payload, 0, numBytes) val read = state.inputStream.read(payload, 0, numBytes)
if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.") if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.")
crypto.encryptSegment(outputStream, payload) outputStream.write(payload)
TRANSPORT_OK TRANSPORT_OK
} catch (e: IOException) { } catch (e: IOException) {
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e) Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)

View file

@ -9,16 +9,21 @@ import android.os.ParcelFileDescriptor
import android.util.Log import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.getADForFull
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
import java.io.EOFException import java.io.EOFException
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream
import java.security.GeneralSecurityException
private class FullRestoreState( private class FullRestoreState(
val token: Long, val token: Long,
val packageInfo: PackageInfo val packageInfo: PackageInfo
) { ) {
var version: Byte? = null
var inputStream: InputStream? = null var inputStream: InputStream? = null
} }
@ -81,7 +86,7 @@ internal class FullRestore(
* Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition * Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition
* that aborts all further restore operations on the current dataset. * that aborts all further restore operations on the current dataset.
*/ */
suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int { suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int = socket.use { pfd ->
val state = this.state ?: throw IllegalStateException("no state") val state = this.state ?: throw IllegalStateException("no state")
val packageName = state.packageInfo.packageName val packageName = state.packageInfo.packageName
@ -90,33 +95,48 @@ internal class FullRestore(
try { try {
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo) val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
val version = headerReader.readVersion(inputStream) val version = headerReader.readVersion(inputStream)
crypto.decryptHeader(inputStream, version, packageName) state.version = version
state.inputStream = inputStream if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName)
state.inputStream = inputStream
} else {
val ad = getADForFull(version, packageName)
state.inputStream = crypto.newDecryptingStream(inputStream, ad)
}
} catch (e: IOException) { } catch (e: IOException) {
Log.w(TAG, "Error getting input stream for $packageName", e) Log.w(TAG, "Error getting input stream for $packageName", e)
return TRANSPORT_PACKAGE_REJECTED return TRANSPORT_PACKAGE_REJECTED
} catch (e: SecurityException) { } catch (e: SecurityException) {
Log.e(TAG, "Security Exception while getting input stream for $packageName", e) Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
return TRANSPORT_ERROR return TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
return TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) { } catch (e: UnsupportedVersionException) {
Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e) Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e)
return TRANSPORT_PACKAGE_REJECTED return TRANSPORT_PACKAGE_REJECTED
} }
} }
return readInputStream(socket) return outputFactory.getOutputStream(pfd).use { outputStream ->
try {
copyInputStream(outputStream)
} catch (e: IOException) {
Log.w(TAG, "Error copying stream for package $packageName.", e)
return TRANSPORT_PACKAGE_REJECTED
}
}
} }
private fun readInputStream(socket: ParcelFileDescriptor): Int = socket.use { fileDescriptor -> @Throws(IOException::class)
private fun copyInputStream(outputStream: OutputStream): Int {
val state = this.state ?: throw IllegalStateException("no state") val state = this.state ?: throw IllegalStateException("no state")
val packageName = state.packageInfo.packageName
val inputStream = state.inputStream ?: throw IllegalStateException("no stream") val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
val outputStream = outputFactory.getOutputStream(fileDescriptor) val version = state.version ?: throw IllegalStateException("no version")
try { if (version == 0.toByte()) {
// read segment from input stream and decrypt it // read segment from input stream and decrypt it
val decrypted = try { val decrypted = try {
// TODO handle IOException
crypto.decryptSegment(inputStream) crypto.decryptSegment(inputStream)
} catch (e: EOFException) { } catch (e: EOFException) {
Log.i(TAG, " EOF") Log.i(TAG, " EOF")
@ -129,12 +149,17 @@ internal class FullRestore(
outputStream.write(decrypted) outputStream.write(decrypted)
// return number of written bytes // return number of written bytes
return decrypted.size return decrypted.size
} catch (e: IOException) { } else {
Log.w(TAG, "Error processing stream for package $packageName.", e) val buffer = ByteArray(MAX_SEGMENT_LENGTH)
closeQuietly(inputStream) val bytesRead = inputStream.read(buffer)
return TRANSPORT_PACKAGE_REJECTED if (bytesRead == -1) {
} finally { Log.i(TAG, " EOF")
closeQuietly(outputStream) // close input stream here as we won't need it anymore
closeQuietly(inputStream)
return NO_MORE_DATA
}
outputStream.write(buffer, 0, bytesRead)
return bytesRead
} }
} }

View file

@ -337,8 +337,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream
// restore data // restore data
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(appData.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
restore.finishRestore() restore.finishRestore()

View file

@ -4,6 +4,8 @@ import android.app.backup.BackupTransport.TRANSPORT_ERROR
import android.app.backup.BackupTransport.TRANSPORT_OK import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
@ -25,8 +27,8 @@ internal class FullBackupTest : BackupTest() {
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto) private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto)
private val bytes = ByteArray(23).apply { Random.nextBytes(this) } private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
private val closeBytes = ByteArray(42).apply { Random.nextBytes(this) }
private val inputStream = mockk<FileInputStream>() private val inputStream = mockk<FileInputStream>()
private val ad = getADForFull(VERSION, packageInfo.packageName)
@Test @Test
fun `has no initial state`() { fun `has no initial state`() {
@ -129,6 +131,7 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream() expectInitializeOutputStream()
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } throws IOException() every { inputStream.read(any(), any(), bytes.size) } throws IOException()
expectClearState() expectClearState()
@ -183,8 +186,9 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream() expectInitializeOutputStream()
every { settingsManager.isQuotaUnlimited() } returns false every { settingsManager.isQuotaUnlimited() } returns false
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
every { crypto.encryptSegment(outputStream, any()) } throws IOException() every { encryptedOutputStream.write(any<ByteArray>()) } throws IOException()
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
@ -256,8 +260,7 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream() expectInitializeOutputStream()
val numBytes = 42 val numBytes = 42
expectSendData(numBytes) expectSendData(numBytes)
every { outputStream.write(closeBytes) } just Runs every { encryptedOutputStream.flush() } throws IOException()
every { outputStream.flush() } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
assertTrue(backup.hasState()) assertTrue(backup.hasState())
@ -314,18 +317,18 @@ internal class FullBackupTest : BackupTest() {
private fun expectInitializeOutputStream() { private fun expectInitializeOutputStream() {
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { headerWriter.writeVersion(outputStream, header) } just Runs every { headerWriter.writeVersion(outputStream, header) } just Runs
every { crypto.encryptHeader(outputStream, header) } just Runs
} }
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) { private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
every { inputStream.read(any(), any(), numBytes) } returns readBytes every { inputStream.read(any(), any(), numBytes) } returns readBytes
every { crypto.encryptSegment(outputStream, any()) } just Runs every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { encryptedOutputStream.write(any<ByteArray>()) } just Runs
} }
private fun expectClearState() { private fun expectClearState() {
every { outputStream.write(closeBytes) } just Runs every { encryptedOutputStream.flush() } just Runs
every { outputStream.flush() } just Runs every { encryptedOutputStream.close() } just Runs
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { inputStream.close() } just Runs every { inputStream.close() } just Runs
every { data.close() } just Runs every { data.close() } just Runs

View file

@ -6,9 +6,12 @@ import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import com.stevesoltys.seedvault.coAssertThrows import com.stevesoltys.seedvault.coAssertThrows
import com.stevesoltys.seedvault.getRandomByteArray import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.CapturingSlot
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.every import io.mockk.every
@ -20,9 +23,10 @@ import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.EOFException
import java.io.IOException import java.io.IOException
import java.security.GeneralSecurityException
import kotlin.random.Random import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
@ -33,7 +37,7 @@ internal class FullRestoreTest : RestoreTest() {
private val encrypted = getRandomByteArray() private val encrypted = getRandomByteArray()
private val outputStream = ByteArrayOutputStream() private val outputStream = ByteArrayOutputStream()
private val versionHeader = VersionHeader(VERSION, packageInfo.packageName) private val ad = getADForFull(VERSION, packageInfo.packageName)
@Test @Test
fun `has no initial state`() { fun `has no initial state`() {
@ -67,6 +71,7 @@ internal class FullRestoreTest : RestoreTest() {
restore.initializeState(token, packageInfo) restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException() coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals( assertEquals(
TRANSPORT_PACKAGE_REJECTED, TRANSPORT_PACKAGE_REJECTED,
@ -80,6 +85,7 @@ internal class FullRestoreTest : RestoreTest() {
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } throws IOException() every { headerReader.readVersion(inputStream) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals( assertEquals(
TRANSPORT_PACKAGE_REJECTED, TRANSPORT_PACKAGE_REJECTED,
@ -95,6 +101,7 @@ internal class FullRestoreTest : RestoreTest() {
every { every {
headerReader.readVersion(inputStream) headerReader.readVersion(inputStream)
} throws UnsupportedVersionException(unsupportedVersion) } throws UnsupportedVersionException(unsupportedVersion)
every { fileDescriptor.close() } just Runs
assertEquals( assertEquals(
TRANSPORT_PACKAGE_REJECTED, TRANSPORT_PACKAGE_REJECTED,
@ -103,18 +110,13 @@ internal class FullRestoreTest : RestoreTest() {
} }
@Test @Test
fun `decrypting version header when getting first chunk throws`() = runBlocking { fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream) } returns VERSION
every { every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
crypto.decryptHeader( every { fileDescriptor.close() } just Runs
inputStream,
VERSION,
packageInfo.packageName
)
} throws IOException()
assertEquals( assertEquals(
TRANSPORT_PACKAGE_REJECTED, TRANSPORT_PACKAGE_REJECTED,
@ -123,54 +125,20 @@ internal class FullRestoreTest : RestoreTest() {
} }
@Test @Test
fun `decrypting version header when getting first chunk throws security exception`() = fun `getting decrypted stream when getting first chunk throws general security exception`() =
runBlocking { runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream) } returns VERSION
every { every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
crypto.decryptHeader( every { fileDescriptor.close() } just Runs
inputStream,
VERSION,
packageInfo.packageName
)
} throws SecurityException()
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor)) assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
} }
@Test @Test
fun `decrypting segment throws IOException`() = runBlocking { fun `full chunk gets decrypted`() = runBlocking {
restore.initializeState(token, packageInfo)
initInputStream()
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } throws IOException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `decrypting segment throws EOFException`() = runBlocking {
restore.initializeState(token, packageInfo)
initInputStream()
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } throws EOFException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
}
@Test
fun `full chunk gets encrypted`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(token, packageInfo)
initInputStream() initInputStream()
@ -183,6 +151,50 @@ internal class FullRestoreTest : RestoreTest() {
assertFalse(restore.hasState()) assertFalse(restore.hasState())
} }
@Test
fun `full chunk gets decrypted from version 0`() = runBlocking {
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns 0.toByte()
every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
} returns VersionHeader(0.toByte(), packageInfo.packageName)
every { crypto.decryptSegment(inputStream) } returns encrypted
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState())
}
@Test
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(1, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encryptedBytes, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState())
}
@Test @Test
fun `aborting full restore closes stream, resets state`() = runBlocking { fun `aborting full restore closes stream, resets state`() = runBlocking {
restore.initializeState(token, packageInfo) restore.initializeState(token, packageInfo)
@ -201,18 +213,17 @@ internal class FullRestoreTest : RestoreTest() {
private fun initInputStream() { private fun initInputStream() {
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION every { headerReader.readVersion(inputStream) } returns VERSION
every { every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
crypto.decryptHeader(
inputStream,
VERSION,
packageInfo.packageName
)
} returns versionHeader
} }
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) { private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } returns encryptedBytes val slot = CapturingSlot<ByteArray>()
every { decryptedInputStream.read(capture(slot)) } answers {
encryptedBytes.copyInto(slot.captured)
encryptedBytes.size
}
every { decryptedInputStream.close() } just Runs
every { fileDescriptor.close() } just Runs every { fileDescriptor.close() } just Runs
} }