Clean up after moving all backup code to new v1 version

This commit is contained in:
Torsten Grote 2021-09-09 17:06:40 +02:00 committed by Chirayu Desai
parent f4dc776ed3
commit aeafc80bb9
24 changed files with 71 additions and 454 deletions

View file

@ -1,33 +0,0 @@
package com.stevesoltys.seedvault
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.LargeTest
import com.stevesoltys.seedvault.crypto.CipherFactoryImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
private val TAG = CipherUniqueNonceTest::class.java.simpleName
private const val ITERATIONS = 1_000_000
@LargeTest
@RunWith(AndroidJUnit4::class)
class CipherUniqueNonceTest {
private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager)
private val nonceSet = HashSet<ByteArray>()
@Test
fun testUniqueNonce() {
for (i in 1..ITERATIONS) {
val iv = cipherFactory.createEncryptionCipher().iv
Log.w(TAG, "$i: ${iv.toHexString()}")
assertTrue(nonceSet.add(iv))
}
}
}

View file

@ -2,8 +2,6 @@ package com.stevesoltys.seedvault.crypto
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming
import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.MAX_VERSION_HEADER_SIZE
import com.stevesoltys.seedvault.header.SegmentHeader
@ -15,78 +13,46 @@ import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.security.GeneralSecurityException
import javax.crypto.Cipher
import javax.crypto.spec.SecretKeySpec
import kotlin.math.min
/**
* A backup stream starts with a version byte followed by an encrypted [VersionHeader].
* A version 1 backup stream uses [AesGcmHkdfStreaming] from the tink library.
*
* A version 0 backup stream starts with a version byte followed by an encrypted [VersionHeader].
*
* The header will be encrypted with AES/GCM to provide authentication.
* It can be written using [encryptHeader] and read using [decryptHeader].
* The latter throws a [SecurityException],
* It can be read using [decryptHeader] which throws a [SecurityException],
* if the expected version and package name do not match the encrypted header.
*
* After the header, follows one or more data segments.
* Each segment begins with a clear-text [SegmentHeader]
* that contains the length of the segment
* and a nonce acting as the initialization vector for the encryption.
* The segment can be written using [encryptSegment] and read using [decryptSegment].
* The latter throws a [SecurityException],
* The segment can be read using [decryptSegment] which throws a [SecurityException],
* if the length of the segment is specified larger than allowed.
*/
interface Crypto {
internal interface Crypto {
/**
* Returns a [AesGcmHkdfStreaming] encrypting stream
* that gets encrypted with the given secret.
* that gets encrypted and authenticated the given associated data.
*/
@Throws(IOException::class, GeneralSecurityException::class)
fun newEncryptingStream(
outputStream: OutputStream,
associatedData: ByteArray = ByteArray(0)
associatedData: ByteArray
): OutputStream
/**
* Returns a [AesGcmHkdfStreaming] decrypting stream
* that gets decrypted and authenticated the given associated data.
*/
@Throws(IOException::class, GeneralSecurityException::class)
fun newDecryptingStream(
inputStream: InputStream,
associatedData: ByteArray = ByteArray(0)
associatedData: ByteArray
): InputStream
/**
* Encrypts a backup stream header ([VersionHeader]) and writes it to the given [OutputStream].
*
* The header using a small segment containing only
* the version number, the package name and (optionally) the key of a key/value stream.
*/
@Throws(IOException::class)
fun encryptHeader(outputStream: OutputStream, versionHeader: VersionHeader)
/**
* Encrypts a new backup segment from the given cleartext payload
* and writes it to the given [OutputStream].
*
* A segment starts with a [SegmentHeader] which includes the length of the segment
* and a nonce which is used as initialization vector for the encryption.
*
* After the header follows the encrypted payload.
* Larger backup streams such as from a full backup are encrypted in multiple segments
* to avoid having to load the entire stream into memory when doing authenticated encryption.
*
* The given cleartext can later be decrypted
* by calling [decryptSegment] on the same byte stream.
*/
@Throws(IOException::class)
fun encryptSegment(outputStream: OutputStream, cleartext: ByteArray)
/**
* Like [encryptSegment],
* but if the given cleartext [ByteArray] is larger than [MAX_SEGMENT_CLEARTEXT_LENGTH],
* multiple segments will be written.
*/
@Throws(IOException::class)
fun encryptMultipleSegments(outputStream: OutputStream, cleartext: ByteArray)
/**
* Reads and decrypts a [VersionHeader] from the given [InputStream]
* and ensures that the expected version, package name and key match
@ -95,6 +61,7 @@ interface Crypto {
*
* @return The read [VersionHeader] present in the beginning of the given [InputStream].
*/
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class)
fun decryptHeader(
inputStream: InputStream,
@ -106,14 +73,16 @@ interface Crypto {
/**
* Reads and decrypts a segment from the given [InputStream].
*
* @return The decrypted segment payload as passed into [encryptSegment]
* @return The decrypted segment payload.
*/
@Deprecated("Use newDecryptingStream instead")
@Throws(EOFException::class, IOException::class, SecurityException::class)
fun decryptSegment(inputStream: InputStream): ByteArray
/**
* Like [decryptSegment], but decrypts multiple segments and does not throw [EOFException].
*/
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class)
fun decryptMultipleSegments(inputStream: InputStream): ByteArray
@ -132,7 +101,6 @@ internal const val TYPE_BACKUP_FULL: Byte = 0x02
internal class CryptoImpl(
private val keyManager: KeyManager,
private val cipherFactory: CipherFactory,
private val headerWriter: HeaderWriter,
private val headerReader: HeaderReader
) : Crypto {
@ -156,45 +124,8 @@ internal class CryptoImpl(
return StreamCrypto.newDecryptingStream(key, inputStream, associatedData)
}
@Throws(IOException::class)
override fun encryptHeader(outputStream: OutputStream, versionHeader: VersionHeader) {
val bytes = headerWriter.getEncodedVersionHeader(versionHeader)
encryptSegment(outputStream, bytes)
}
@Throws(IOException::class)
override fun encryptSegment(outputStream: OutputStream, cleartext: ByteArray) {
val cipher = cipherFactory.createEncryptionCipher()
check(cipher.getOutputSize(cleartext.size) <= MAX_SEGMENT_LENGTH) {
"Cipher's output size ${cipher.getOutputSize(cleartext.size)} is larger" +
"than maximum segment length ($MAX_SEGMENT_LENGTH)"
}
encryptSegment(cipher, outputStream, cleartext)
}
@Throws(IOException::class)
override fun encryptMultipleSegments(outputStream: OutputStream, cleartext: ByteArray) {
var end = 0
while (end < cleartext.size) {
val start = end
end = min(cleartext.size, start + MAX_SEGMENT_CLEARTEXT_LENGTH)
val segment = cleartext.copyOfRange(start, end)
val cipher = cipherFactory.createEncryptionCipher()
encryptSegment(cipher, outputStream, segment)
}
}
@Throws(IOException::class)
private fun encryptSegment(cipher: Cipher, outputStream: OutputStream, segment: ByteArray) {
val encrypted = cipher.doFinal(segment)
val segmentHeader = SegmentHeader(encrypted.size.toShort(), cipher.iv)
headerWriter.writeSegmentHeader(outputStream, segmentHeader)
outputStream.write(encrypted)
}
@Throws(IOException::class, SecurityException::class)
@Deprecated("Use newDecryptingStream instead")
override fun decryptHeader(
inputStream: InputStream,
expectedVersion: Byte,
@ -223,11 +154,13 @@ internal class CryptoImpl(
return header
}
@Deprecated("Use newDecryptingStream instead")
@Throws(EOFException::class, IOException::class, SecurityException::class)
override fun decryptSegment(inputStream: InputStream): ByteArray {
return decryptSegment(inputStream, MAX_SEGMENT_LENGTH)
}
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class)
override fun decryptMultipleSegments(inputStream: InputStream): ByteArray {
var result = ByteArray(0)

View file

@ -15,5 +15,5 @@ val cryptoModule = module {
}
KeyManagerImpl(keyStore)
}
single<Crypto> { CryptoImpl(get(), get(), get(), get()) }
single<Crypto> { CryptoImpl(get(), get(), get()) }
}

View file

@ -15,7 +15,8 @@ internal const val MAX_VERSION_HEADER_SIZE =
* After the first version byte of each backup stream
* must follow followed this header encrypted with authentication.
*/
data class VersionHeader(
@Deprecated("version header is in associated data now")
internal data class VersionHeader(
internal val version: Byte = VERSION, // 1 byte
internal val packageName: String, // ?? bytes (max 255)
internal val key: String? = null // ?? bytes
@ -60,6 +61,7 @@ internal const val SEGMENT_HEADER_SIZE = SEGMENT_LENGTH_SIZE + IV_SIZE
/**
* Each data segment must start with this header
*/
@Deprecated("Don't do manual segments, use Crypto interface instead.")
class SegmentHeader(
internal val segmentLength: Short, // 2 bytes
internal val nonce: ByteArray // 12 bytes

View file

@ -3,6 +3,5 @@ package com.stevesoltys.seedvault.header
import org.koin.dsl.module
val headerModule = module {
single<HeaderWriter> { HeaderWriterImpl() }
single<HeaderReader> { HeaderReaderImpl() }
}

View file

@ -6,13 +6,15 @@ import java.io.IOException
import java.io.InputStream
import java.nio.ByteBuffer
interface HeaderReader {
internal interface HeaderReader {
@Throws(IOException::class, UnsupportedVersionException::class)
fun readVersion(inputStream: InputStream): Byte
@Deprecated("")
@Throws(SecurityException::class)
fun getVersionHeader(byteArray: ByteArray): VersionHeader
@Deprecated("")
@Throws(EOFException::class, IOException::class)
fun readSegmentHeader(inputStream: InputStream): SegmentHeader
}

View file

@ -1,51 +0,0 @@
package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.Utf8
import java.io.IOException
import java.io.OutputStream
import java.nio.ByteBuffer
interface HeaderWriter {
@Throws(IOException::class)
fun writeVersion(outputStream: OutputStream, header: VersionHeader)
fun getEncodedVersionHeader(header: VersionHeader): ByteArray
@Throws(IOException::class)
fun writeSegmentHeader(outputStream: OutputStream, header: SegmentHeader)
}
internal class HeaderWriterImpl : HeaderWriter {
@Throws(IOException::class)
override fun writeVersion(outputStream: OutputStream, header: VersionHeader) {
val headerBytes = ByteArray(1)
headerBytes[0] = header.version
outputStream.write(headerBytes)
}
override fun getEncodedVersionHeader(header: VersionHeader): ByteArray {
val packageBytes = header.packageName.toByteArray(Utf8)
val keyBytes = header.key?.toByteArray(Utf8)
val size = 1 + 2 + packageBytes.size + 2 + (keyBytes?.size ?: 0)
return ByteBuffer.allocate(size).apply {
put(header.version)
putShort(packageBytes.size.toShort())
put(packageBytes)
if (keyBytes == null) {
putShort(0.toShort())
} else {
putShort(keyBytes.size.toShort())
put(keyBytes)
}
}.array()
}
override fun writeSegmentHeader(outputStream: OutputStream, header: SegmentHeader) {
val buffer = ByteBuffer.allocate(SEGMENT_HEADER_SIZE)
.putShort(header.segmentLength)
.put(header.nonce)
outputStream.write(buffer.array())
}
}

View file

@ -23,7 +23,6 @@ val backupModule = module {
plugin = get<BackupPlugin>().kvBackupPlugin,
settingsManager = get(),
inputFactory = get(),
headerWriter = get(),
crypto = get(),
nm = get()
)
@ -33,7 +32,6 @@ val backupModule = module {
plugin = get<BackupPlugin>().fullBackupPlugin,
settingsManager = get(),
inputFactory = get(),
headerWriter = get(),
crypto = get()
)
}

View file

@ -9,9 +9,7 @@ import android.content.pm.PackageInfo
import android.os.ParcelFileDescriptor
import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.settings.SettingsManager
import libcore.io.IoUtils.closeQuietly
@ -43,7 +41,6 @@ internal class FullBackup(
private val plugin: FullBackupPlugin,
private val settingsManager: SettingsManager,
private val inputFactory: InputFactory,
private val headerWriter: HeaderWriter,
private val crypto: Crypto
) {
@ -124,8 +121,7 @@ internal class FullBackup(
}
// store version header
val state = this.state ?: throw AssertionError()
val header = VersionHeader(packageName = state.packageName)
headerWriter.writeVersion(outputStream, header)
outputStream.write(ByteArray(1) { VERSION })
crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName))
} // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK

View file

@ -12,9 +12,7 @@ import android.util.Log
import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.settings.SettingsManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
@ -31,7 +29,6 @@ internal class KVBackup(
private val plugin: KVBackupPlugin,
private val settingsManager: SettingsManager,
private val inputFactory: InputFactory,
private val headerWriter: HeaderWriter,
private val crypto: Crypto,
private val nm: BackupNotificationManager
) {
@ -168,11 +165,7 @@ internal class KVBackup(
plugin.deleteRecord(packageInfo, op.base64Key)
} else {
plugin.getOutputStreamForRecord(packageInfo, op.base64Key).use { outputStream ->
val header = VersionHeader(
packageName = packageInfo.packageName,
key = op.key
)
headerWriter.writeVersion(outputStream, header)
outputStream.write(ByteArray(1) { VERSION })
val ad = getADForKV(VERSION, packageInfo.packageName)
crypto.newEncryptingStream(outputStream, ad).use { encryptedStream ->
encryptedStream.write(op.value)

View file

@ -21,7 +21,7 @@ class TestApp : App() {
private val testCryptoModule = module {
factory<CipherFactory> { CipherFactoryImpl(get()) }
single<KeyManager> { KeyManagerTestImpl() }
single<Crypto> { CryptoImpl(get(), get(), get(), get()) }
single<Crypto> { CryptoImpl(get(), get(), get()) }
}
private val appModule = module {
single { Clock() }

View file

@ -1,57 +1,22 @@
package com.stevesoltys.seedvault.crypto
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.IV_SIZE
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import javax.crypto.Cipher
import kotlin.random.Random
@TestInstance(PER_METHOD)
class CryptoImplTest {
private val keyManager = mockk<KeyManager>()
private val cipherFactory = mockk<CipherFactory>()
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cipher = mockk<Cipher>()
private val iv = ByteArray(IV_SIZE).apply { Random.nextBytes(this) }
private val cleartext = ByteArray(Random.nextInt(Short.MAX_VALUE.toInt()))
.apply { Random.nextBytes(this) }
private val ciphertext = ByteArray(Random.nextInt(Short.MAX_VALUE.toInt()))
.apply { Random.nextBytes(this) }
private val outputStream = ByteArrayOutputStream()
@Test
fun `encrypted cleartext gets decrypted as expected`() {
every { cipherFactory.createEncryptionCipher() } returns cipher
every { cipher.getOutputSize(cleartext.size) } returns MAX_SEGMENT_LENGTH
every { cipher.doFinal(cleartext) } returns ciphertext
every { cipher.iv } returns iv
crypto.encryptSegment(outputStream, cleartext)
val inputStream = ByteArrayInputStream(outputStream.toByteArray())
every { cipherFactory.createDecryptionCipher(iv) } returns cipher
every { cipher.doFinal(ciphertext) } returns cleartext
assertArrayEquals(cleartext, crypto.decryptSegment(inputStream))
}
private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
@Test
fun `decrypting multiple segments on empty stream throws`() {

View file

@ -1,16 +1,14 @@
package com.stevesoltys.seedvault.crypto
import com.stevesoltys.seedvault.assertReadEquals
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import kotlin.random.Random
@TestInstance(PER_METHOD)
@ -18,50 +16,31 @@ class CryptoIntegrationTest {
private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cleartext = byteArrayOf(0x01, 0x02, 0x03)
private val outputStream = ByteArrayOutputStream()
private val cleartext = Random.nextBytes(Random.nextInt(1, 422300))
@Test
fun `the plain crypto works`() {
val eCipher = cipherFactory.createEncryptionCipher()
val encrypted = eCipher.doFinal(cleartext)
val dCipher = cipherFactory.createDecryptionCipher(eCipher.iv)
val decrypted = dCipher.doFinal(encrypted)
assertArrayEquals(cleartext, decrypted)
}
@Test
fun `encrypted cleartext gets decrypted as expected`() {
crypto.encryptSegment(outputStream, cleartext)
fun `decrypting encrypted cleartext works`() {
val ad = Random.nextBytes(42)
val outputStream = ByteArrayOutputStream()
crypto.newEncryptingStream(outputStream, ad).use { it.write(cleartext) }
val inputStream = ByteArrayInputStream(outputStream.toByteArray())
assertArrayEquals(cleartext, crypto.decryptSegment(inputStream))
crypto.newDecryptingStream(inputStream, ad).use {
assertReadEquals(cleartext, it)
}
}
@Test
fun `multiple segments get encrypted and decrypted as expected`() {
val size = Random.nextInt(5) * MAX_SEGMENT_CLEARTEXT_LENGTH + Random.nextInt(0, 1337)
val cleartext = ByteArray(size).apply { Random.nextBytes(this) }
crypto.encryptMultipleSegments(outputStream, cleartext)
fun `decrypting encrypted cleartext fails with different AD`() {
val outputStream = ByteArrayOutputStream()
crypto.newEncryptingStream(outputStream, Random.nextBytes(42)).use { it.write(cleartext) }
val inputStream = ByteArrayInputStream(outputStream.toByteArray())
assertArrayEquals(cleartext, crypto.decryptMultipleSegments(inputStream))
}
@Test
fun `test maximum lengths`() {
val cipher = cipherFactory.createEncryptionCipher()
val expectedDiff = MAX_SEGMENT_LENGTH - MAX_SEGMENT_CLEARTEXT_LENGTH
for (i in 1..(3 * MAX_SEGMENT_LENGTH + 42)) {
val outputSize = cipher.getOutputSize(i)
assertEquals(expectedDiff, outputSize - i)
assertThrows(IOException::class.java) {
crypto.newDecryptingStream(inputStream, Random.nextBytes(41)).use {
it.read()
}
}
}

View file

@ -4,7 +4,6 @@ import com.stevesoltys.seedvault.assertContains
import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.IV_SIZE
import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE
import com.stevesoltys.seedvault.header.MAX_PACKAGE_LENGTH_SIZE
@ -13,10 +12,7 @@ import com.stevesoltys.seedvault.header.MAX_VERSION_HEADER_SIZE
import com.stevesoltys.seedvault.header.SegmentHeader
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import io.mockk.CapturingSlot
import io.mockk.Runs
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
@ -26,7 +22,6 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.EOFException
import java.io.IOException
import java.io.InputStream
@ -38,10 +33,9 @@ class CryptoTest {
private val keyManager = mockk<KeyManager>()
private val cipherFactory = mockk<CipherFactory>()
private val headerWriter = mockk<HeaderWriter>()
private val headerReader = mockk<HeaderReader>()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
private val cipher = mockk<Cipher>()
@ -55,49 +49,12 @@ class CryptoTest {
)
private val versionCiphertext = getRandomByteArray(MAX_VERSION_HEADER_SIZE)
private val versionSegmentHeader = SegmentHeader(versionCiphertext.size.toShort(), iv)
private val outputStream = ByteArrayOutputStream()
private val segmentHeader = SegmentHeader(ciphertext.size.toShort(), iv)
// the headerReader will not actually read the header, so only insert cipher text
private val inputStream = ByteArrayInputStream(ciphertext)
private val versionInputStream = ByteArrayInputStream(versionCiphertext)
// encrypting
@Test
fun `encrypt header works as expected`() {
val segmentHeader = CapturingSlot<SegmentHeader>()
every { headerWriter.getEncodedVersionHeader(versionHeader) } returns ciphertext
encryptSegmentHeader(ciphertext, segmentHeader)
crypto.encryptHeader(outputStream, versionHeader)
assertArrayEquals(iv, segmentHeader.captured.nonce)
assertEquals(ciphertext.size, segmentHeader.captured.segmentLength.toInt())
}
@Test
fun `encrypting segment works as expected`() {
val segmentHeader = CapturingSlot<SegmentHeader>()
encryptSegmentHeader(cleartext, segmentHeader)
crypto.encryptSegment(outputStream, cleartext)
assertArrayEquals(ciphertext, outputStream.toByteArray())
assertArrayEquals(iv, segmentHeader.captured.nonce)
assertEquals(ciphertext.size, segmentHeader.captured.segmentLength.toInt())
}
private fun encryptSegmentHeader(
toEncrypt: ByteArray,
segmentHeader: CapturingSlot<SegmentHeader>
) {
every { cipherFactory.createEncryptionCipher() } returns cipher
every { cipher.getOutputSize(toEncrypt.size) } returns toEncrypt.size
every { cipher.iv } returns iv
every { headerWriter.writeSegmentHeader(outputStream, capture(segmentHeader)) } just Runs
every { cipher.doFinal(toEncrypt) } returns ciphertext
}
// decrypting
@Test

View file

@ -1,104 +0,0 @@
package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.getRandomString
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import kotlin.random.Random
@TestInstance(PER_CLASS)
internal class HeaderWriterReaderTest {
private val writer = HeaderWriterImpl()
private val reader = HeaderReaderImpl()
private val packageName = getRandomString(MAX_PACKAGE_LENGTH_SIZE)
private val key = getRandomString(MAX_KEY_LENGTH_SIZE)
private val versionHeader = VersionHeader(VERSION, packageName, key)
private val unsupportedVersionHeader = VersionHeader((VERSION + 1).toByte(), packageName)
private val segmentLength = getRandomValidSegmentLength()
private val nonce = getRandomByteArray(IV_SIZE)
private val segmentHeader = SegmentHeader(segmentLength, nonce)
@Test
fun `written version matches read input`() {
assertEquals(versionHeader.version, readWriteVersion(versionHeader))
}
@Test
fun `reading unsupported version throws exception`() {
assertThrows(UnsupportedVersionException::class.javaObjectType) {
readWriteVersion(unsupportedVersionHeader)
}
}
@Test
fun `VersionHeader output matches read input`() {
assertEquals(versionHeader, readWrite(versionHeader))
}
@Test
fun `VersionHeader with no key output matches read input`() {
val versionHeader = VersionHeader(VERSION, packageName, null)
assertEquals(versionHeader, readWrite(versionHeader))
}
@Test
fun `VersionHeader with empty package name throws`() {
val versionHeader = VersionHeader(VERSION, "")
assertThrows(SecurityException::class.java) {
readWrite(versionHeader)
}
}
@Test
fun `SegmentHeader constructor needs right IV size`() {
val nonceTooBig = ByteArray(IV_SIZE + 1).apply { Random.nextBytes(this) }
assertThrows(IllegalStateException::class.javaObjectType) {
SegmentHeader(segmentLength, nonceTooBig)
}
val nonceTooSmall = ByteArray(IV_SIZE - 1).apply { Random.nextBytes(this) }
assertThrows(IllegalStateException::class.javaObjectType) {
SegmentHeader(segmentLength, nonceTooSmall)
}
}
@Test
fun `SegmentHeader output matches read input`() {
assertEquals(segmentHeader, readWriteVersion(segmentHeader))
}
private fun readWriteVersion(header: VersionHeader): Byte {
val outputStream = ByteArrayOutputStream()
writer.writeVersion(outputStream, header)
val written = outputStream.toByteArray()
val inputStream = ByteArrayInputStream(written)
return reader.readVersion(inputStream)
}
private fun readWrite(header: VersionHeader): VersionHeader {
val written = writer.getEncodedVersionHeader(header)
return reader.getVersionHeader(written)
}
private fun readWriteVersion(header: SegmentHeader): SegmentHeader {
val outputStream = ByteArrayOutputStream()
writer.writeSegmentHeader(outputStream, header)
val written = outputStream.toByteArray()
val inputStream = ByteArrayInputStream(written)
return reader.readSegmentHeader(inputStream)
}
private fun assertEquals(expected: SegmentHeader, actual: SegmentHeader) {
assertEquals(expected.segmentLength, actual.segmentLength)
assertArrayEquals(expected.nonce, actual.nonce)
}
}

View file

@ -6,7 +6,6 @@ import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA
import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED
@ -27,9 +26,8 @@ internal class MetadataReadWriteTest {
)
private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val writer = MetadataWriterImpl(cryptoImpl)
private val reader = MetadataReaderImpl(cryptoImpl)

View file

@ -5,7 +5,6 @@ import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA
import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED
import com.stevesoltys.seedvault.toByteArrayFromHex
@ -27,9 +26,8 @@ internal class MetadataV0ReadTest {
)
private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val reader = MetadataReaderImpl(cryptoImpl)

View file

@ -10,7 +10,7 @@ import org.junit.Assert.assertTrue
import org.junit.jupiter.api.Test
import kotlin.random.Random
class ApkSplitCompatibilityCheckerTest : TransportTest() {
internal class ApkSplitCompatibilityCheckerTest : TransportTest() {
private val deviceInfo: DeviceInfo = mockk()
private val deviceName = getRandomString()

View file

@ -12,7 +12,6 @@ import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.metadata.PackageMetadata
@ -57,9 +56,8 @@ internal class CoordinatorIntegrationTest : TransportTest() {
private val outputFactory = mockk<OutputFactory>()
private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>()
@ -69,7 +67,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
plugin = kvBackupPlugin,
settingsManager = settingsManager,
inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = cryptoImpl,
nm = notificationManager
)
@ -78,7 +75,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
plugin = fullBackupPlugin,
settingsManager = settingsManager,
inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = cryptoImpl
)
private val apkBackup = mockk<ApkBackup>()

View file

@ -20,7 +20,7 @@ import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import kotlin.random.Random
@TestInstance(PER_METHOD)
abstract class TransportTest {
internal abstract class TransportTest {
protected val clock: Clock = mockk()
protected val crypto = mockk<Crypto>()

View file

@ -1,8 +1,6 @@
package com.stevesoltys.seedvault.transport.backup
import android.os.ParcelFileDescriptor
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.transport.TransportTest
import io.mockk.mockk
import java.io.OutputStream
@ -10,12 +8,10 @@ import java.io.OutputStream
internal abstract class BackupTest : TransportTest() {
protected val inputFactory = mockk<InputFactory>()
protected val headerWriter = mockk<HeaderWriter>()
protected val data = mockk<ParcelFileDescriptor>()
protected val outputStream = mockk<OutputStream>()
protected val encryptedOutputStream = mockk<OutputStream>()
protected val header = VersionHeader(packageName = packageInfo.packageName)
protected val quota = 42L
}

View file

@ -24,7 +24,7 @@ import kotlin.random.Random
internal class FullBackupTest : BackupTest() {
private val plugin = mockk<FullBackupPlugin>()
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto)
private val backup = FullBackup(plugin, settingsManager, inputFactory, crypto)
private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
private val inputStream = mockk<FileInputStream>()
@ -168,7 +168,7 @@ internal class FullBackupTest : BackupTest() {
every { plugin.getQuota() } returns quota
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream
every { headerWriter.writeVersion(outputStream, header) } throws IOException()
every { outputStream.write(ByteArray(1) { VERSION }) } throws IOException()
expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
@ -316,7 +316,7 @@ internal class FullBackupTest : BackupTest() {
private fun expectInitializeOutputStream() {
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { headerWriter.writeVersion(outputStream, header) } just Runs
every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
}
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {

View file

@ -11,7 +11,7 @@ import android.content.pm.PackageInfo
import com.stevesoltys.seedvault.Utf8
import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.CapturingSlot
@ -42,7 +42,6 @@ internal class KVBackupTest : BackupTest() {
plugin = plugin,
settingsManager = settingsManager,
inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = crypto,
nm = notificationManager
)
@ -50,7 +49,6 @@ internal class KVBackupTest : BackupTest() {
private val key = getRandomString(MAX_KEY_LENGTH_SIZE)
private val key64 = Base64.getEncoder().encodeToString(key.toByteArray(Utf8))
private val dataValue = Random.nextBytes(23)
private val versionHeader = VersionHeader(packageName = packageInfo.packageName, key = key)
@Test
fun `has no initial state`() {
@ -81,14 +79,11 @@ internal class KVBackupTest : BackupTest() {
// store first record and show notification for it
every { notificationManager.onPmKvBackup("key1", 1, 2) } just Runs
coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5MQ") } returns outputStream
val versionHeader1 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key1")
every { headerWriter.writeVersion(outputStream, versionHeader1) } just Runs
every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
// store second record and show notification for it
every { notificationManager.onPmKvBackup("key2", 2, 2) } just Runs
coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5Mg") } returns outputStream
val versionHeader2 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key2")
every { headerWriter.writeVersion(outputStream, versionHeader2) } just Runs
// encrypt to and close output stream
every { crypto.newEncryptingStream(outputStream, any()) } returns encryptedOutputStream
@ -213,11 +208,11 @@ internal class KVBackupTest : BackupTest() {
}
@Test
fun `exception while writing version header`() = runBlocking {
fun `exception while writing version`() = runBlocking {
initPlugin(false)
getDataInput(listOf(true))
coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream
every { headerWriter.writeVersion(outputStream, versionHeader) } throws IOException()
every { outputStream.write(ByteArray(1) { VERSION }) } throws IOException()
every { outputStream.close() } just Runs
every { plugin.packageFinished(packageInfo) } just Runs
@ -231,7 +226,7 @@ internal class KVBackupTest : BackupTest() {
fun `exception while writing encrypted value to output stream`() = runBlocking {
initPlugin(false)
getDataInput(listOf(true))
writeHeaderAndEncrypt()
writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } throws IOException()
every { plugin.packageFinished(packageInfo) } just Runs
@ -245,7 +240,7 @@ internal class KVBackupTest : BackupTest() {
fun `exception while flushing output stream`() = runBlocking {
initPlugin(false)
getDataInput(listOf(true))
writeHeaderAndEncrypt()
writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } throws IOException()
every { encryptedOutputStream.close() } just Runs
@ -262,7 +257,7 @@ internal class KVBackupTest : BackupTest() {
fun `ignoring exception while closing output stream`() = runBlocking {
initPlugin(false)
getDataInput(listOf(true, false))
writeHeaderAndEncrypt()
writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } just Runs
every { encryptedOutputStream.close() } just Runs
@ -278,7 +273,7 @@ internal class KVBackupTest : BackupTest() {
private fun singleRecordBackup(hasDataForPackage: Boolean = false) {
initPlugin(hasDataForPackage)
getDataInput(listOf(true, false))
writeHeaderAndEncrypt()
writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } just Runs
every { encryptedOutputStream.close() } just Runs
@ -306,10 +301,10 @@ internal class KVBackupTest : BackupTest() {
}
}
private fun writeHeaderAndEncrypt() {
private fun writeVersionAndEncrypt() {
coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream
every { headerWriter.writeVersion(outputStream, versionHeader) } just Runs
val ad = getADForKV(versionHeader.version, packageInfo.packageName)
every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
val ad = getADForKV(VERSION, packageInfo.packageName)
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
}

View file

@ -12,7 +12,6 @@ import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.transport.TransportTest
@ -42,9 +41,8 @@ internal class RestoreV0IntegrationTest : TransportTest() {
)
private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader)
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>()