Do full backups with new version 1 with new crypto
Restoring still supports version 0 with old crypto
This commit is contained in:
parent
0c3ea7679b
commit
f4dc776ed3
7 changed files with 140 additions and 91 deletions
|
@ -140,6 +140,7 @@ internal class CryptoImpl(
|
||||||
deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray())
|
deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Throws(IOException::class, GeneralSecurityException::class)
|
||||||
override fun newEncryptingStream(
|
override fun newEncryptingStream(
|
||||||
outputStream: OutputStream,
|
outputStream: OutputStream,
|
||||||
associatedData: ByteArray
|
associatedData: ByteArray
|
||||||
|
@ -147,6 +148,7 @@ internal class CryptoImpl(
|
||||||
return StreamCrypto.newEncryptingStream(key, outputStream, associatedData)
|
return StreamCrypto.newEncryptingStream(key, outputStream, associatedData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Throws(IOException::class, GeneralSecurityException::class)
|
||||||
override fun newDecryptingStream(
|
override fun newDecryptingStream(
|
||||||
inputStream: InputStream,
|
inputStream: InputStream,
|
||||||
associatedData: ByteArray
|
associatedData: ByteArray
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package com.stevesoltys.seedvault.header
|
package com.stevesoltys.seedvault.header
|
||||||
|
|
||||||
import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH
|
import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH
|
||||||
|
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_FULL
|
||||||
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV
|
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
|
|
||||||
|
@ -40,6 +41,15 @@ internal fun getADForKV(version: Byte, packageName: String): ByteArray {
|
||||||
.array()
|
.array()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun getADForFull(version: Byte, packageName: String): ByteArray {
|
||||||
|
val packageNameBytes = packageName.toByteArray()
|
||||||
|
return ByteBuffer.allocate(2 + packageNameBytes.size)
|
||||||
|
.put(version)
|
||||||
|
.put(TYPE_BACKUP_FULL)
|
||||||
|
.put(packageNameBytes)
|
||||||
|
.array()
|
||||||
|
}
|
||||||
|
|
||||||
internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES
|
internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES
|
||||||
internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt()
|
internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt()
|
||||||
internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int =
|
internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int =
|
||||||
|
|
|
@ -10,7 +10,9 @@ import android.os.ParcelFileDescriptor
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import com.stevesoltys.seedvault.crypto.Crypto
|
import com.stevesoltys.seedvault.crypto.Crypto
|
||||||
import com.stevesoltys.seedvault.header.HeaderWriter
|
import com.stevesoltys.seedvault.header.HeaderWriter
|
||||||
|
import com.stevesoltys.seedvault.header.VERSION
|
||||||
import com.stevesoltys.seedvault.header.VersionHeader
|
import com.stevesoltys.seedvault.header.VersionHeader
|
||||||
|
import com.stevesoltys.seedvault.header.getADForFull
|
||||||
import com.stevesoltys.seedvault.settings.SettingsManager
|
import com.stevesoltys.seedvault.settings.SettingsManager
|
||||||
import libcore.io.IoUtils.closeQuietly
|
import libcore.io.IoUtils.closeQuietly
|
||||||
import java.io.EOFException
|
import java.io.EOFException
|
||||||
|
@ -24,6 +26,9 @@ private class FullBackupState(
|
||||||
val inputStream: InputStream,
|
val inputStream: InputStream,
|
||||||
var outputStreamInit: (suspend () -> OutputStream)?
|
var outputStreamInit: (suspend () -> OutputStream)?
|
||||||
) {
|
) {
|
||||||
|
/**
|
||||||
|
* This is an encrypted stream that can be written to directly.
|
||||||
|
*/
|
||||||
var outputStream: OutputStream? = null
|
var outputStream: OutputStream? = null
|
||||||
val packageName: String = packageInfo.packageName
|
val packageName: String = packageInfo.packageName
|
||||||
var size: Long = 0
|
var size: Long = 0
|
||||||
|
@ -120,14 +125,8 @@ internal class FullBackup(
|
||||||
// store version header
|
// store version header
|
||||||
val state = this.state ?: throw AssertionError()
|
val state = this.state ?: throw AssertionError()
|
||||||
val header = VersionHeader(packageName = state.packageName)
|
val header = VersionHeader(packageName = state.packageName)
|
||||||
try {
|
headerWriter.writeVersion(outputStream, header)
|
||||||
headerWriter.writeVersion(outputStream, header)
|
crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName))
|
||||||
crypto.encryptHeader(outputStream, header)
|
|
||||||
} catch (e: IOException) {
|
|
||||||
Log.e(TAG, "Error writing backup header", e)
|
|
||||||
throw(e)
|
|
||||||
}
|
|
||||||
outputStream
|
|
||||||
} // this lambda is only called before we actually write backup data the first time
|
} // this lambda is only called before we actually write backup data the first time
|
||||||
return TRANSPORT_OK
|
return TRANSPORT_OK
|
||||||
}
|
}
|
||||||
|
@ -159,11 +158,11 @@ internal class FullBackup(
|
||||||
}()
|
}()
|
||||||
state.outputStreamInit = null // the stream init lambda is not needed beyond that point
|
state.outputStreamInit = null // the stream init lambda is not needed beyond that point
|
||||||
|
|
||||||
// read backup data, encrypt it and write it to output stream
|
// read backup data and write it to encrypted output stream
|
||||||
val payload = ByteArray(numBytes)
|
val payload = ByteArray(numBytes)
|
||||||
val read = state.inputStream.read(payload, 0, numBytes)
|
val read = state.inputStream.read(payload, 0, numBytes)
|
||||||
if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.")
|
if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.")
|
||||||
crypto.encryptSegment(outputStream, payload)
|
outputStream.write(payload)
|
||||||
TRANSPORT_OK
|
TRANSPORT_OK
|
||||||
} catch (e: IOException) {
|
} catch (e: IOException) {
|
||||||
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)
|
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)
|
||||||
|
|
|
@ -9,16 +9,21 @@ import android.os.ParcelFileDescriptor
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import com.stevesoltys.seedvault.crypto.Crypto
|
import com.stevesoltys.seedvault.crypto.Crypto
|
||||||
import com.stevesoltys.seedvault.header.HeaderReader
|
import com.stevesoltys.seedvault.header.HeaderReader
|
||||||
|
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
|
||||||
import com.stevesoltys.seedvault.header.UnsupportedVersionException
|
import com.stevesoltys.seedvault.header.UnsupportedVersionException
|
||||||
|
import com.stevesoltys.seedvault.header.getADForFull
|
||||||
import libcore.io.IoUtils.closeQuietly
|
import libcore.io.IoUtils.closeQuietly
|
||||||
import java.io.EOFException
|
import java.io.EOFException
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
import java.io.OutputStream
|
||||||
|
import java.security.GeneralSecurityException
|
||||||
|
|
||||||
private class FullRestoreState(
|
private class FullRestoreState(
|
||||||
val token: Long,
|
val token: Long,
|
||||||
val packageInfo: PackageInfo
|
val packageInfo: PackageInfo
|
||||||
) {
|
) {
|
||||||
|
var version: Byte? = null
|
||||||
var inputStream: InputStream? = null
|
var inputStream: InputStream? = null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +86,7 @@ internal class FullRestore(
|
||||||
* Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition
|
* Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition
|
||||||
* that aborts all further restore operations on the current dataset.
|
* that aborts all further restore operations on the current dataset.
|
||||||
*/
|
*/
|
||||||
suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int {
|
suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int = socket.use { pfd ->
|
||||||
val state = this.state ?: throw IllegalStateException("no state")
|
val state = this.state ?: throw IllegalStateException("no state")
|
||||||
val packageName = state.packageInfo.packageName
|
val packageName = state.packageInfo.packageName
|
||||||
|
|
||||||
|
@ -90,33 +95,48 @@ internal class FullRestore(
|
||||||
try {
|
try {
|
||||||
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
|
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
|
||||||
val version = headerReader.readVersion(inputStream)
|
val version = headerReader.readVersion(inputStream)
|
||||||
crypto.decryptHeader(inputStream, version, packageName)
|
state.version = version
|
||||||
state.inputStream = inputStream
|
if (version == 0.toByte()) {
|
||||||
|
crypto.decryptHeader(inputStream, version, packageName)
|
||||||
|
state.inputStream = inputStream
|
||||||
|
} else {
|
||||||
|
val ad = getADForFull(version, packageName)
|
||||||
|
state.inputStream = crypto.newDecryptingStream(inputStream, ad)
|
||||||
|
}
|
||||||
} catch (e: IOException) {
|
} catch (e: IOException) {
|
||||||
Log.w(TAG, "Error getting input stream for $packageName", e)
|
Log.w(TAG, "Error getting input stream for $packageName", e)
|
||||||
return TRANSPORT_PACKAGE_REJECTED
|
return TRANSPORT_PACKAGE_REJECTED
|
||||||
} catch (e: SecurityException) {
|
} catch (e: SecurityException) {
|
||||||
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
|
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
|
||||||
return TRANSPORT_ERROR
|
return TRANSPORT_ERROR
|
||||||
|
} catch (e: GeneralSecurityException) {
|
||||||
|
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
|
||||||
|
return TRANSPORT_ERROR
|
||||||
} catch (e: UnsupportedVersionException) {
|
} catch (e: UnsupportedVersionException) {
|
||||||
Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e)
|
Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e)
|
||||||
return TRANSPORT_PACKAGE_REJECTED
|
return TRANSPORT_PACKAGE_REJECTED
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return readInputStream(socket)
|
return outputFactory.getOutputStream(pfd).use { outputStream ->
|
||||||
|
try {
|
||||||
|
copyInputStream(outputStream)
|
||||||
|
} catch (e: IOException) {
|
||||||
|
Log.w(TAG, "Error copying stream for package $packageName.", e)
|
||||||
|
return TRANSPORT_PACKAGE_REJECTED
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun readInputStream(socket: ParcelFileDescriptor): Int = socket.use { fileDescriptor ->
|
@Throws(IOException::class)
|
||||||
|
private fun copyInputStream(outputStream: OutputStream): Int {
|
||||||
val state = this.state ?: throw IllegalStateException("no state")
|
val state = this.state ?: throw IllegalStateException("no state")
|
||||||
val packageName = state.packageInfo.packageName
|
|
||||||
val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
|
val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
|
||||||
val outputStream = outputFactory.getOutputStream(fileDescriptor)
|
val version = state.version ?: throw IllegalStateException("no version")
|
||||||
|
|
||||||
try {
|
if (version == 0.toByte()) {
|
||||||
// read segment from input stream and decrypt it
|
// read segment from input stream and decrypt it
|
||||||
val decrypted = try {
|
val decrypted = try {
|
||||||
// TODO handle IOException
|
|
||||||
crypto.decryptSegment(inputStream)
|
crypto.decryptSegment(inputStream)
|
||||||
} catch (e: EOFException) {
|
} catch (e: EOFException) {
|
||||||
Log.i(TAG, " EOF")
|
Log.i(TAG, " EOF")
|
||||||
|
@ -129,12 +149,17 @@ internal class FullRestore(
|
||||||
outputStream.write(decrypted)
|
outputStream.write(decrypted)
|
||||||
// return number of written bytes
|
// return number of written bytes
|
||||||
return decrypted.size
|
return decrypted.size
|
||||||
} catch (e: IOException) {
|
} else {
|
||||||
Log.w(TAG, "Error processing stream for package $packageName.", e)
|
val buffer = ByteArray(MAX_SEGMENT_LENGTH)
|
||||||
closeQuietly(inputStream)
|
val bytesRead = inputStream.read(buffer)
|
||||||
return TRANSPORT_PACKAGE_REJECTED
|
if (bytesRead == -1) {
|
||||||
} finally {
|
Log.i(TAG, " EOF")
|
||||||
closeQuietly(outputStream)
|
// close input stream here as we won't need it anymore
|
||||||
|
closeQuietly(inputStream)
|
||||||
|
return NO_MORE_DATA
|
||||||
|
}
|
||||||
|
outputStream.write(buffer, 0, bytesRead)
|
||||||
|
return bytesRead
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -337,8 +337,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
|
||||||
every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream
|
every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream
|
||||||
|
|
||||||
// restore data
|
// restore data
|
||||||
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
assertEquals(appData.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
|
||||||
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
restore.finishRestore()
|
restore.finishRestore()
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ import android.app.backup.BackupTransport.TRANSPORT_ERROR
|
||||||
import android.app.backup.BackupTransport.TRANSPORT_OK
|
import android.app.backup.BackupTransport.TRANSPORT_OK
|
||||||
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
|
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
|
||||||
import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
|
import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
|
||||||
|
import com.stevesoltys.seedvault.header.VERSION
|
||||||
|
import com.stevesoltys.seedvault.header.getADForFull
|
||||||
import io.mockk.Runs
|
import io.mockk.Runs
|
||||||
import io.mockk.coEvery
|
import io.mockk.coEvery
|
||||||
import io.mockk.every
|
import io.mockk.every
|
||||||
|
@ -25,8 +27,8 @@ internal class FullBackupTest : BackupTest() {
|
||||||
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto)
|
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto)
|
||||||
|
|
||||||
private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
|
private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
|
||||||
private val closeBytes = ByteArray(42).apply { Random.nextBytes(this) }
|
|
||||||
private val inputStream = mockk<FileInputStream>()
|
private val inputStream = mockk<FileInputStream>()
|
||||||
|
private val ad = getADForFull(VERSION, packageInfo.packageName)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `has no initial state`() {
|
fun `has no initial state`() {
|
||||||
|
@ -129,6 +131,7 @@ internal class FullBackupTest : BackupTest() {
|
||||||
expectInitializeOutputStream()
|
expectInitializeOutputStream()
|
||||||
every { settingsManager.isQuotaUnlimited() } returns false
|
every { settingsManager.isQuotaUnlimited() } returns false
|
||||||
every { plugin.getQuota() } returns quota
|
every { plugin.getQuota() } returns quota
|
||||||
|
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
|
||||||
every { inputStream.read(any(), any(), bytes.size) } throws IOException()
|
every { inputStream.read(any(), any(), bytes.size) } throws IOException()
|
||||||
expectClearState()
|
expectClearState()
|
||||||
|
|
||||||
|
@ -183,8 +186,9 @@ internal class FullBackupTest : BackupTest() {
|
||||||
expectInitializeOutputStream()
|
expectInitializeOutputStream()
|
||||||
every { settingsManager.isQuotaUnlimited() } returns false
|
every { settingsManager.isQuotaUnlimited() } returns false
|
||||||
every { plugin.getQuota() } returns quota
|
every { plugin.getQuota() } returns quota
|
||||||
|
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
|
||||||
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
|
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
|
||||||
every { crypto.encryptSegment(outputStream, any()) } throws IOException()
|
every { encryptedOutputStream.write(any<ByteArray>()) } throws IOException()
|
||||||
expectClearState()
|
expectClearState()
|
||||||
|
|
||||||
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
|
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
|
||||||
|
@ -256,8 +260,7 @@ internal class FullBackupTest : BackupTest() {
|
||||||
expectInitializeOutputStream()
|
expectInitializeOutputStream()
|
||||||
val numBytes = 42
|
val numBytes = 42
|
||||||
expectSendData(numBytes)
|
expectSendData(numBytes)
|
||||||
every { outputStream.write(closeBytes) } just Runs
|
every { encryptedOutputStream.flush() } throws IOException()
|
||||||
every { outputStream.flush() } throws IOException()
|
|
||||||
|
|
||||||
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
|
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
|
||||||
assertTrue(backup.hasState())
|
assertTrue(backup.hasState())
|
||||||
|
@ -314,18 +317,18 @@ internal class FullBackupTest : BackupTest() {
|
||||||
private fun expectInitializeOutputStream() {
|
private fun expectInitializeOutputStream() {
|
||||||
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
|
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
|
||||||
every { headerWriter.writeVersion(outputStream, header) } just Runs
|
every { headerWriter.writeVersion(outputStream, header) } just Runs
|
||||||
every { crypto.encryptHeader(outputStream, header) } just Runs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {
|
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {
|
||||||
every { plugin.getQuota() } returns quota
|
every { plugin.getQuota() } returns quota
|
||||||
every { inputStream.read(any(), any(), numBytes) } returns readBytes
|
every { inputStream.read(any(), any(), numBytes) } returns readBytes
|
||||||
every { crypto.encryptSegment(outputStream, any()) } just Runs
|
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
|
||||||
|
every { encryptedOutputStream.write(any<ByteArray>()) } just Runs
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun expectClearState() {
|
private fun expectClearState() {
|
||||||
every { outputStream.write(closeBytes) } just Runs
|
every { encryptedOutputStream.flush() } just Runs
|
||||||
every { outputStream.flush() } just Runs
|
every { encryptedOutputStream.close() } just Runs
|
||||||
every { outputStream.close() } just Runs
|
every { outputStream.close() } just Runs
|
||||||
every { inputStream.close() } just Runs
|
every { inputStream.close() } just Runs
|
||||||
every { data.close() } just Runs
|
every { data.close() } just Runs
|
||||||
|
|
|
@ -6,9 +6,12 @@ import android.app.backup.BackupTransport.TRANSPORT_OK
|
||||||
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
|
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
|
||||||
import com.stevesoltys.seedvault.coAssertThrows
|
import com.stevesoltys.seedvault.coAssertThrows
|
||||||
import com.stevesoltys.seedvault.getRandomByteArray
|
import com.stevesoltys.seedvault.getRandomByteArray
|
||||||
|
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
|
||||||
import com.stevesoltys.seedvault.header.UnsupportedVersionException
|
import com.stevesoltys.seedvault.header.UnsupportedVersionException
|
||||||
import com.stevesoltys.seedvault.header.VERSION
|
import com.stevesoltys.seedvault.header.VERSION
|
||||||
import com.stevesoltys.seedvault.header.VersionHeader
|
import com.stevesoltys.seedvault.header.VersionHeader
|
||||||
|
import com.stevesoltys.seedvault.header.getADForFull
|
||||||
|
import io.mockk.CapturingSlot
|
||||||
import io.mockk.Runs
|
import io.mockk.Runs
|
||||||
import io.mockk.coEvery
|
import io.mockk.coEvery
|
||||||
import io.mockk.every
|
import io.mockk.every
|
||||||
|
@ -20,9 +23,10 @@ import org.junit.jupiter.api.Assertions.assertEquals
|
||||||
import org.junit.jupiter.api.Assertions.assertFalse
|
import org.junit.jupiter.api.Assertions.assertFalse
|
||||||
import org.junit.jupiter.api.Assertions.assertTrue
|
import org.junit.jupiter.api.Assertions.assertTrue
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
|
import java.io.ByteArrayInputStream
|
||||||
import java.io.ByteArrayOutputStream
|
import java.io.ByteArrayOutputStream
|
||||||
import java.io.EOFException
|
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
|
import java.security.GeneralSecurityException
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@Suppress("BlockingMethodInNonBlockingContext")
|
@Suppress("BlockingMethodInNonBlockingContext")
|
||||||
|
@ -33,7 +37,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
private val encrypted = getRandomByteArray()
|
private val encrypted = getRandomByteArray()
|
||||||
private val outputStream = ByteArrayOutputStream()
|
private val outputStream = ByteArrayOutputStream()
|
||||||
private val versionHeader = VersionHeader(VERSION, packageInfo.packageName)
|
private val ad = getADForFull(VERSION, packageInfo.packageName)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `has no initial state`() {
|
fun `has no initial state`() {
|
||||||
|
@ -67,6 +71,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
TRANSPORT_PACKAGE_REJECTED,
|
TRANSPORT_PACKAGE_REJECTED,
|
||||||
|
@ -80,6 +85,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } throws IOException()
|
every { headerReader.readVersion(inputStream) } throws IOException()
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
TRANSPORT_PACKAGE_REJECTED,
|
TRANSPORT_PACKAGE_REJECTED,
|
||||||
|
@ -95,6 +101,7 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
every {
|
every {
|
||||||
headerReader.readVersion(inputStream)
|
headerReader.readVersion(inputStream)
|
||||||
} throws UnsupportedVersionException(unsupportedVersion)
|
} throws UnsupportedVersionException(unsupportedVersion)
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
TRANSPORT_PACKAGE_REJECTED,
|
TRANSPORT_PACKAGE_REJECTED,
|
||||||
|
@ -103,18 +110,13 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `decrypting version header when getting first chunk throws`() = runBlocking {
|
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream) } returns VERSION
|
||||||
every {
|
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
|
||||||
crypto.decryptHeader(
|
every { fileDescriptor.close() } just Runs
|
||||||
inputStream,
|
|
||||||
VERSION,
|
|
||||||
packageInfo.packageName
|
|
||||||
)
|
|
||||||
} throws IOException()
|
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
TRANSPORT_PACKAGE_REJECTED,
|
TRANSPORT_PACKAGE_REJECTED,
|
||||||
|
@ -123,54 +125,20 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `decrypting version header when getting first chunk throws security exception`() =
|
fun `getting decrypted stream when getting first chunk throws general security exception`() =
|
||||||
runBlocking {
|
runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream) } returns VERSION
|
||||||
every {
|
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
|
||||||
crypto.decryptHeader(
|
every { fileDescriptor.close() } just Runs
|
||||||
inputStream,
|
|
||||||
VERSION,
|
|
||||||
packageInfo.packageName
|
|
||||||
)
|
|
||||||
} throws SecurityException()
|
|
||||||
|
|
||||||
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `decrypting segment throws IOException`() = runBlocking {
|
fun `full chunk gets decrypted`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
|
||||||
|
|
||||||
initInputStream()
|
|
||||||
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
|
||||||
every { crypto.decryptSegment(inputStream) } throws IOException()
|
|
||||||
every { inputStream.close() } just Runs
|
|
||||||
every { fileDescriptor.close() } just Runs
|
|
||||||
|
|
||||||
assertEquals(
|
|
||||||
TRANSPORT_PACKAGE_REJECTED,
|
|
||||||
restore.getNextFullRestoreDataChunk(fileDescriptor)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun `decrypting segment throws EOFException`() = runBlocking {
|
|
||||||
restore.initializeState(token, packageInfo)
|
|
||||||
|
|
||||||
initInputStream()
|
|
||||||
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
|
||||||
every { crypto.decryptSegment(inputStream) } throws EOFException()
|
|
||||||
every { inputStream.close() } just Runs
|
|
||||||
every { fileDescriptor.close() } just Runs
|
|
||||||
|
|
||||||
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun `full chunk gets encrypted`() = runBlocking {
|
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
initInputStream()
|
initInputStream()
|
||||||
|
@ -183,6 +151,50 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
assertFalse(restore.hasState())
|
assertFalse(restore.hasState())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `full chunk gets decrypted from version 0`() = runBlocking {
|
||||||
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
|
every { headerReader.readVersion(inputStream) } returns 0.toByte()
|
||||||
|
every {
|
||||||
|
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
|
||||||
|
} returns VersionHeader(0.toByte(), packageInfo.packageName)
|
||||||
|
every { crypto.decryptSegment(inputStream) } returns encrypted
|
||||||
|
|
||||||
|
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
every { inputStream.close() } just Runs
|
||||||
|
|
||||||
|
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertArrayEquals(encrypted, outputStream.toByteArray())
|
||||||
|
restore.finishRestore()
|
||||||
|
assertFalse(restore.hasState())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
|
||||||
|
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
|
||||||
|
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
|
||||||
|
restore.initializeState(token, packageInfo)
|
||||||
|
|
||||||
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
|
every { headerReader.readVersion(inputStream) } returns VERSION
|
||||||
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
|
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
||||||
|
every { fileDescriptor.close() } just Runs
|
||||||
|
every { inputStream.close() } just Runs
|
||||||
|
|
||||||
|
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertEquals(1, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
|
||||||
|
assertArrayEquals(encryptedBytes, outputStream.toByteArray())
|
||||||
|
restore.finishRestore()
|
||||||
|
assertFalse(restore.hasState())
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `aborting full restore closes stream, resets state`() = runBlocking {
|
fun `aborting full restore closes stream, resets state`() = runBlocking {
|
||||||
restore.initializeState(token, packageInfo)
|
restore.initializeState(token, packageInfo)
|
||||||
|
@ -201,18 +213,17 @@ internal class FullRestoreTest : RestoreTest() {
|
||||||
private fun initInputStream() {
|
private fun initInputStream() {
|
||||||
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
|
||||||
every { headerReader.readVersion(inputStream) } returns VERSION
|
every { headerReader.readVersion(inputStream) } returns VERSION
|
||||||
every {
|
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
|
||||||
crypto.decryptHeader(
|
|
||||||
inputStream,
|
|
||||||
VERSION,
|
|
||||||
packageInfo.packageName
|
|
||||||
)
|
|
||||||
} returns versionHeader
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
|
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
|
||||||
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
|
||||||
every { crypto.decryptSegment(inputStream) } returns encryptedBytes
|
val slot = CapturingSlot<ByteArray>()
|
||||||
|
every { decryptedInputStream.read(capture(slot)) } answers {
|
||||||
|
encryptedBytes.copyInto(slot.captured)
|
||||||
|
encryptedBytes.size
|
||||||
|
}
|
||||||
|
every { decryptedInputStream.close() } just Runs
|
||||||
every { fileDescriptor.close() } just Runs
|
every { fileDescriptor.close() } just Runs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue