Do full backups with new version 1 with new crypto

Restoring still supports version 0 with old crypto
This commit is contained in:
Torsten Grote 2021-09-09 15:55:13 +02:00 committed by Chirayu Desai
parent 0c3ea7679b
commit f4dc776ed3
7 changed files with 140 additions and 91 deletions

View file

@ -140,6 +140,7 @@ internal class CryptoImpl(
deriveStreamKey(keyManager.getMainKey(), "app data key".toByteArray())
}
@Throws(IOException::class, GeneralSecurityException::class)
override fun newEncryptingStream(
outputStream: OutputStream,
associatedData: ByteArray
@ -147,6 +148,7 @@ internal class CryptoImpl(
return StreamCrypto.newEncryptingStream(key, outputStream, associatedData)
}
@Throws(IOException::class, GeneralSecurityException::class)
override fun newDecryptingStream(
inputStream: InputStream,
associatedData: ByteArray

View file

@ -1,6 +1,7 @@
package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.crypto.GCM_AUTHENTICATION_TAG_LENGTH
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_FULL
import com.stevesoltys.seedvault.crypto.TYPE_BACKUP_KV
import java.nio.ByteBuffer
@ -40,6 +41,15 @@ internal fun getADForKV(version: Byte, packageName: String): ByteArray {
.array()
}
internal fun getADForFull(version: Byte, packageName: String): ByteArray {
val packageNameBytes = packageName.toByteArray()
return ByteBuffer.allocate(2 + packageNameBytes.size)
.put(version)
.put(TYPE_BACKUP_FULL)
.put(packageNameBytes)
.array()
}
internal const val SEGMENT_LENGTH_SIZE: Int = Short.SIZE_BYTES
internal const val MAX_SEGMENT_LENGTH: Int = Short.MAX_VALUE.toInt()
internal const val MAX_SEGMENT_CLEARTEXT_LENGTH: Int =

View file

@ -10,7 +10,9 @@ import android.os.ParcelFileDescriptor
import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.settings.SettingsManager
import libcore.io.IoUtils.closeQuietly
import java.io.EOFException
@ -24,6 +26,9 @@ private class FullBackupState(
val inputStream: InputStream,
var outputStreamInit: (suspend () -> OutputStream)?
) {
/**
* This is an encrypted stream that can be written to directly.
*/
var outputStream: OutputStream? = null
val packageName: String = packageInfo.packageName
var size: Long = 0
@ -120,14 +125,8 @@ internal class FullBackup(
// store version header
val state = this.state ?: throw AssertionError()
val header = VersionHeader(packageName = state.packageName)
try {
headerWriter.writeVersion(outputStream, header)
crypto.encryptHeader(outputStream, header)
} catch (e: IOException) {
Log.e(TAG, "Error writing backup header", e)
throw(e)
}
outputStream
crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName))
} // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK
}
@ -159,11 +158,11 @@ internal class FullBackup(
}()
state.outputStreamInit = null // the stream init lambda is not needed beyond that point
// read backup data, encrypt it and write it to output stream
// read backup data and write it to encrypted output stream
val payload = ByteArray(numBytes)
val read = state.inputStream.read(payload, 0, numBytes)
if (read != numBytes) throw EOFException("Read $read bytes instead of $numBytes.")
crypto.encryptSegment(outputStream, payload)
outputStream.write(payload)
TRANSPORT_OK
} catch (e: IOException) {
Log.e(TAG, "Error handling backup data for ${state.packageName}: ", e)

View file

@ -9,16 +9,21 @@ import android.os.ParcelFileDescriptor
import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.getADForFull
import libcore.io.IoUtils.closeQuietly
import java.io.EOFException
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.security.GeneralSecurityException
private class FullRestoreState(
val token: Long,
val packageInfo: PackageInfo
) {
var version: Byte? = null
var inputStream: InputStream? = null
}
@ -81,7 +86,7 @@ internal class FullRestore(
* Any other negative value such as [TRANSPORT_ERROR] is treated as a fatal error condition
* that aborts all further restore operations on the current dataset.
*/
suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int {
suspend fun getNextFullRestoreDataChunk(socket: ParcelFileDescriptor): Int = socket.use { pfd ->
val state = this.state ?: throw IllegalStateException("no state")
val packageName = state.packageInfo.packageName
@ -90,33 +95,48 @@ internal class FullRestore(
try {
val inputStream = plugin.getInputStreamForPackage(state.token, state.packageInfo)
val version = headerReader.readVersion(inputStream)
state.version = version
if (version == 0.toByte()) {
crypto.decryptHeader(inputStream, version, packageName)
state.inputStream = inputStream
} else {
val ad = getADForFull(version, packageName)
state.inputStream = crypto.newDecryptingStream(inputStream, ad)
}
} catch (e: IOException) {
Log.w(TAG, "Error getting input stream for $packageName", e)
return TRANSPORT_PACKAGE_REJECTED
} catch (e: SecurityException) {
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
return TRANSPORT_ERROR
} catch (e: GeneralSecurityException) {
Log.e(TAG, "Security Exception while getting input stream for $packageName", e)
return TRANSPORT_ERROR
} catch (e: UnsupportedVersionException) {
Log.e(TAG, "Backup data for $packageName uses unsupported version ${e.version}.", e)
return TRANSPORT_PACKAGE_REJECTED
}
}
return readInputStream(socket)
return outputFactory.getOutputStream(pfd).use { outputStream ->
try {
copyInputStream(outputStream)
} catch (e: IOException) {
Log.w(TAG, "Error copying stream for package $packageName.", e)
return TRANSPORT_PACKAGE_REJECTED
}
}
}
private fun readInputStream(socket: ParcelFileDescriptor): Int = socket.use { fileDescriptor ->
@Throws(IOException::class)
private fun copyInputStream(outputStream: OutputStream): Int {
val state = this.state ?: throw IllegalStateException("no state")
val packageName = state.packageInfo.packageName
val inputStream = state.inputStream ?: throw IllegalStateException("no stream")
val outputStream = outputFactory.getOutputStream(fileDescriptor)
val version = state.version ?: throw IllegalStateException("no version")
try {
if (version == 0.toByte()) {
// read segment from input stream and decrypt it
val decrypted = try {
// TODO handle IOException
crypto.decryptSegment(inputStream)
} catch (e: EOFException) {
Log.i(TAG, " EOF")
@ -129,12 +149,17 @@ internal class FullRestore(
outputStream.write(decrypted)
// return number of written bytes
return decrypted.size
} catch (e: IOException) {
Log.w(TAG, "Error processing stream for package $packageName.", e)
} else {
val buffer = ByteArray(MAX_SEGMENT_LENGTH)
val bytesRead = inputStream.read(buffer)
if (bytesRead == -1) {
Log.i(TAG, " EOF")
// close input stream here as we won't need it anymore
closeQuietly(inputStream)
return TRANSPORT_PACKAGE_REJECTED
} finally {
closeQuietly(outputStream)
return NO_MORE_DATA
}
outputStream.write(buffer, 0, bytesRead)
return bytesRead
}
}

View file

@ -337,8 +337,7 @@ internal class CoordinatorIntegrationTest : TransportTest() {
every { outputFactory.getOutputStream(fileDescriptor) } returns rOutputStream
// restore data
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(appData.size / 2, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(appData.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
restore.finishRestore()

View file

@ -4,6 +4,8 @@ import android.app.backup.BackupTransport.TRANSPORT_ERROR
import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import android.app.backup.BackupTransport.TRANSPORT_QUOTA_EXCEEDED
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
@ -25,8 +27,8 @@ internal class FullBackupTest : BackupTest() {
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto)
private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
private val closeBytes = ByteArray(42).apply { Random.nextBytes(this) }
private val inputStream = mockk<FileInputStream>()
private val ad = getADForFull(VERSION, packageInfo.packageName)
@Test
fun `has no initial state`() {
@ -129,6 +131,7 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream()
every { settingsManager.isQuotaUnlimited() } returns false
every { plugin.getQuota() } returns quota
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } throws IOException()
expectClearState()
@ -183,8 +186,9 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream()
every { settingsManager.isQuotaUnlimited() } returns false
every { plugin.getQuota() } returns quota
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { inputStream.read(any(), any(), bytes.size) } returns bytes.size
every { crypto.encryptSegment(outputStream, any()) } throws IOException()
every { encryptedOutputStream.write(any<ByteArray>()) } throws IOException()
expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
@ -256,8 +260,7 @@ internal class FullBackupTest : BackupTest() {
expectInitializeOutputStream()
val numBytes = 42
expectSendData(numBytes)
every { outputStream.write(closeBytes) } just Runs
every { outputStream.flush() } throws IOException()
every { encryptedOutputStream.flush() } throws IOException()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
assertTrue(backup.hasState())
@ -314,18 +317,18 @@ internal class FullBackupTest : BackupTest() {
private fun expectInitializeOutputStream() {
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { headerWriter.writeVersion(outputStream, header) } just Runs
every { crypto.encryptHeader(outputStream, header) } just Runs
}
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {
every { plugin.getQuota() } returns quota
every { inputStream.read(any(), any(), numBytes) } returns readBytes
every { crypto.encryptSegment(outputStream, any()) } just Runs
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
every { encryptedOutputStream.write(any<ByteArray>()) } just Runs
}
private fun expectClearState() {
every { outputStream.write(closeBytes) } just Runs
every { outputStream.flush() } just Runs
every { encryptedOutputStream.flush() } just Runs
every { encryptedOutputStream.close() } just Runs
every { outputStream.close() } just Runs
every { inputStream.close() } just Runs
every { data.close() } just Runs

View file

@ -6,9 +6,12 @@ import android.app.backup.BackupTransport.TRANSPORT_OK
import android.app.backup.BackupTransport.TRANSPORT_PACKAGE_REJECTED
import com.stevesoltys.seedvault.coAssertThrows
import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import io.mockk.CapturingSlot
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
@ -20,9 +23,10 @@ import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.EOFException
import java.io.IOException
import java.security.GeneralSecurityException
import kotlin.random.Random
@Suppress("BlockingMethodInNonBlockingContext")
@ -33,7 +37,7 @@ internal class FullRestoreTest : RestoreTest() {
private val encrypted = getRandomByteArray()
private val outputStream = ByteArrayOutputStream()
private val versionHeader = VersionHeader(VERSION, packageInfo.packageName)
private val ad = getADForFull(VERSION, packageInfo.packageName)
@Test
fun `has no initial state`() {
@ -67,6 +71,7 @@ internal class FullRestoreTest : RestoreTest() {
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
@ -80,6 +85,7 @@ internal class FullRestoreTest : RestoreTest() {
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
@ -95,6 +101,7 @@ internal class FullRestoreTest : RestoreTest() {
every {
headerReader.readVersion(inputStream)
} throws UnsupportedVersionException(unsupportedVersion)
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
@ -103,18 +110,13 @@ internal class FullRestoreTest : RestoreTest() {
}
@Test
fun `decrypting version header when getting first chunk throws`() = runBlocking {
fun `getting decrypted stream when getting first chunk throws`() = runBlocking {
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION
every {
crypto.decryptHeader(
inputStream,
VERSION,
packageInfo.packageName
)
} throws IOException()
every { crypto.newDecryptingStream(inputStream, ad) } throws IOException()
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
@ -123,54 +125,20 @@ internal class FullRestoreTest : RestoreTest() {
}
@Test
fun `decrypting version header when getting first chunk throws security exception`() =
fun `getting decrypted stream when getting first chunk throws general security exception`() =
runBlocking {
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION
every {
crypto.decryptHeader(
inputStream,
VERSION,
packageInfo.packageName
)
} throws SecurityException()
every { crypto.newDecryptingStream(inputStream, ad) } throws GeneralSecurityException()
every { fileDescriptor.close() } just Runs
assertEquals(TRANSPORT_ERROR, restore.getNextFullRestoreDataChunk(fileDescriptor))
}
@Test
fun `decrypting segment throws IOException`() = runBlocking {
restore.initializeState(token, packageInfo)
initInputStream()
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } throws IOException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(
TRANSPORT_PACKAGE_REJECTED,
restore.getNextFullRestoreDataChunk(fileDescriptor)
)
}
@Test
fun `decrypting segment throws EOFException`() = runBlocking {
restore.initializeState(token, packageInfo)
initInputStream()
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } throws EOFException()
every { inputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
}
@Test
fun `full chunk gets encrypted`() = runBlocking {
fun `full chunk gets decrypted`() = runBlocking {
restore.initializeState(token, packageInfo)
initInputStream()
@ -183,6 +151,50 @@ internal class FullRestoreTest : RestoreTest() {
assertFalse(restore.hasState())
}
@Test
fun `full chunk gets decrypted from version 0`() = runBlocking {
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns 0.toByte()
every {
crypto.decryptHeader(inputStream, 0.toByte(), packageInfo.packageName)
} returns VersionHeader(0.toByte(), packageInfo.packageName)
every { crypto.decryptSegment(inputStream) } returns encrypted
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(encrypted.size, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encrypted, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState())
}
@Test
fun `three full chunk get decrypted and then return no more data`() = runBlocking {
val encryptedBytes = Random.nextBytes(MAX_SEGMENT_LENGTH * 2 + 1)
val decryptedInputStream = ByteArrayInputStream(encryptedBytes)
restore.initializeState(token, packageInfo)
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { fileDescriptor.close() } just Runs
every { inputStream.close() } just Runs
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(MAX_SEGMENT_LENGTH, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(1, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertEquals(NO_MORE_DATA, restore.getNextFullRestoreDataChunk(fileDescriptor))
assertArrayEquals(encryptedBytes, outputStream.toByteArray())
restore.finishRestore()
assertFalse(restore.hasState())
}
@Test
fun `aborting full restore closes stream, resets state`() = runBlocking {
restore.initializeState(token, packageInfo)
@ -201,18 +213,17 @@ internal class FullRestoreTest : RestoreTest() {
private fun initInputStream() {
coEvery { plugin.getInputStreamForPackage(token, packageInfo) } returns inputStream
every { headerReader.readVersion(inputStream) } returns VERSION
every {
crypto.decryptHeader(
inputStream,
VERSION,
packageInfo.packageName
)
} returns versionHeader
every { crypto.newDecryptingStream(inputStream, ad) } returns decryptedInputStream
}
private fun readAndEncryptInputStream(encryptedBytes: ByteArray) {
every { outputFactory.getOutputStream(fileDescriptor) } returns outputStream
every { crypto.decryptSegment(inputStream) } returns encryptedBytes
val slot = CapturingSlot<ByteArray>()
every { decryptedInputStream.read(capture(slot)) } answers {
encryptedBytes.copyInto(slot.captured)
encryptedBytes.size
}
every { decryptedInputStream.close() } just Runs
every { fileDescriptor.close() } just Runs
}