Clean up after moving all backup code to new v1 version

This commit is contained in:
Torsten Grote 2021-09-09 17:06:40 +02:00 committed by Chirayu Desai
parent f4dc776ed3
commit aeafc80bb9
24 changed files with 71 additions and 454 deletions

View file

@ -1,33 +0,0 @@
package com.stevesoltys.seedvault
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.LargeTest
import com.stevesoltys.seedvault.crypto.CipherFactoryImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
private val TAG = CipherUniqueNonceTest::class.java.simpleName
private const val ITERATIONS = 1_000_000
@LargeTest
@RunWith(AndroidJUnit4::class)
class CipherUniqueNonceTest {
private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager)
private val nonceSet = HashSet<ByteArray>()
@Test
fun testUniqueNonce() {
for (i in 1..ITERATIONS) {
val iv = cipherFactory.createEncryptionCipher().iv
Log.w(TAG, "$i: ${iv.toHexString()}")
assertTrue(nonceSet.add(iv))
}
}
}

View file

@ -2,8 +2,6 @@ package com.stevesoltys.seedvault.crypto
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming import com.google.crypto.tink.subtle.AesGcmHkdfStreaming
import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import com.stevesoltys.seedvault.header.MAX_VERSION_HEADER_SIZE import com.stevesoltys.seedvault.header.MAX_VERSION_HEADER_SIZE
import com.stevesoltys.seedvault.header.SegmentHeader import com.stevesoltys.seedvault.header.SegmentHeader
@ -15,78 +13,46 @@ import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
import javax.crypto.Cipher
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
import kotlin.math.min
/** /**
* A backup stream starts with a version byte followed by an encrypted [VersionHeader]. * A version 1 backup stream uses [AesGcmHkdfStreaming] from the tink library.
*
* A version 0 backup stream starts with a version byte followed by an encrypted [VersionHeader].
* *
* The header will be encrypted with AES/GCM to provide authentication. * The header will be encrypted with AES/GCM to provide authentication.
* It can be written using [encryptHeader] and read using [decryptHeader]. * It can be read using [decryptHeader] which throws a [SecurityException],
* The latter throws a [SecurityException],
* if the expected version and package name do not match the encrypted header. * if the expected version and package name do not match the encrypted header.
* *
* After the header, follows one or more data segments. * After the header, follows one or more data segments.
* Each segment begins with a clear-text [SegmentHeader] * Each segment begins with a clear-text [SegmentHeader]
* that contains the length of the segment * that contains the length of the segment
* and a nonce acting as the initialization vector for the encryption. * and a nonce acting as the initialization vector for the encryption.
* The segment can be written using [encryptSegment] and read using [decryptSegment]. * The segment can be read using [decryptSegment] which throws a [SecurityException],
* The latter throws a [SecurityException],
* if the length of the segment is specified larger than allowed. * if the length of the segment is specified larger than allowed.
*/ */
interface Crypto { internal interface Crypto {
/** /**
* Returns a [AesGcmHkdfStreaming] encrypting stream * Returns a [AesGcmHkdfStreaming] encrypting stream
* that gets encrypted with the given secret. * that gets encrypted and authenticated the given associated data.
*/ */
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)
fun newEncryptingStream( fun newEncryptingStream(
outputStream: OutputStream, outputStream: OutputStream,
associatedData: ByteArray = ByteArray(0) associatedData: ByteArray
): OutputStream ): OutputStream
/**
* Returns a [AesGcmHkdfStreaming] decrypting stream
* that gets decrypted and authenticated the given associated data.
*/
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)
fun newDecryptingStream( fun newDecryptingStream(
inputStream: InputStream, inputStream: InputStream,
associatedData: ByteArray = ByteArray(0) associatedData: ByteArray
): InputStream ): InputStream
/**
* Encrypts a backup stream header ([VersionHeader]) and writes it to the given [OutputStream].
*
* The header using a small segment containing only
* the version number, the package name and (optionally) the key of a key/value stream.
*/
@Throws(IOException::class)
fun encryptHeader(outputStream: OutputStream, versionHeader: VersionHeader)
/**
* Encrypts a new backup segment from the given cleartext payload
* and writes it to the given [OutputStream].
*
* A segment starts with a [SegmentHeader] which includes the length of the segment
* and a nonce which is used as initialization vector for the encryption.
*
* After the header follows the encrypted payload.
* Larger backup streams such as from a full backup are encrypted in multiple segments
* to avoid having to load the entire stream into memory when doing authenticated encryption.
*
* The given cleartext can later be decrypted
* by calling [decryptSegment] on the same byte stream.
*/
@Throws(IOException::class)
fun encryptSegment(outputStream: OutputStream, cleartext: ByteArray)
/**
* Like [encryptSegment],
* but if the given cleartext [ByteArray] is larger than [MAX_SEGMENT_CLEARTEXT_LENGTH],
* multiple segments will be written.
*/
@Throws(IOException::class)
fun encryptMultipleSegments(outputStream: OutputStream, cleartext: ByteArray)
/** /**
* Reads and decrypts a [VersionHeader] from the given [InputStream] * Reads and decrypts a [VersionHeader] from the given [InputStream]
* and ensures that the expected version, package name and key match * and ensures that the expected version, package name and key match
@ -95,6 +61,7 @@ interface Crypto {
* *
* @return The read [VersionHeader] present in the beginning of the given [InputStream]. * @return The read [VersionHeader] present in the beginning of the given [InputStream].
*/ */
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class) @Throws(IOException::class, SecurityException::class)
fun decryptHeader( fun decryptHeader(
inputStream: InputStream, inputStream: InputStream,
@ -106,14 +73,16 @@ interface Crypto {
/** /**
* Reads and decrypts a segment from the given [InputStream]. * Reads and decrypts a segment from the given [InputStream].
* *
* @return The decrypted segment payload as passed into [encryptSegment] * @return The decrypted segment payload.
*/ */
@Deprecated("Use newDecryptingStream instead")
@Throws(EOFException::class, IOException::class, SecurityException::class) @Throws(EOFException::class, IOException::class, SecurityException::class)
fun decryptSegment(inputStream: InputStream): ByteArray fun decryptSegment(inputStream: InputStream): ByteArray
/** /**
* Like [decryptSegment], but decrypts multiple segments and does not throw [EOFException]. * Like [decryptSegment], but decrypts multiple segments and does not throw [EOFException].
*/ */
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class) @Throws(IOException::class, SecurityException::class)
fun decryptMultipleSegments(inputStream: InputStream): ByteArray fun decryptMultipleSegments(inputStream: InputStream): ByteArray
@ -132,7 +101,6 @@ internal const val TYPE_BACKUP_FULL: Byte = 0x02
internal class CryptoImpl( internal class CryptoImpl(
private val keyManager: KeyManager, private val keyManager: KeyManager,
private val cipherFactory: CipherFactory, private val cipherFactory: CipherFactory,
private val headerWriter: HeaderWriter,
private val headerReader: HeaderReader private val headerReader: HeaderReader
) : Crypto { ) : Crypto {
@ -156,45 +124,8 @@ internal class CryptoImpl(
return StreamCrypto.newDecryptingStream(key, inputStream, associatedData) return StreamCrypto.newDecryptingStream(key, inputStream, associatedData)
} }
@Throws(IOException::class)
override fun encryptHeader(outputStream: OutputStream, versionHeader: VersionHeader) {
val bytes = headerWriter.getEncodedVersionHeader(versionHeader)
encryptSegment(outputStream, bytes)
}
@Throws(IOException::class)
override fun encryptSegment(outputStream: OutputStream, cleartext: ByteArray) {
val cipher = cipherFactory.createEncryptionCipher()
check(cipher.getOutputSize(cleartext.size) <= MAX_SEGMENT_LENGTH) {
"Cipher's output size ${cipher.getOutputSize(cleartext.size)} is larger" +
"than maximum segment length ($MAX_SEGMENT_LENGTH)"
}
encryptSegment(cipher, outputStream, cleartext)
}
@Throws(IOException::class)
override fun encryptMultipleSegments(outputStream: OutputStream, cleartext: ByteArray) {
var end = 0
while (end < cleartext.size) {
val start = end
end = min(cleartext.size, start + MAX_SEGMENT_CLEARTEXT_LENGTH)
val segment = cleartext.copyOfRange(start, end)
val cipher = cipherFactory.createEncryptionCipher()
encryptSegment(cipher, outputStream, segment)
}
}
@Throws(IOException::class)
private fun encryptSegment(cipher: Cipher, outputStream: OutputStream, segment: ByteArray) {
val encrypted = cipher.doFinal(segment)
val segmentHeader = SegmentHeader(encrypted.size.toShort(), cipher.iv)
headerWriter.writeSegmentHeader(outputStream, segmentHeader)
outputStream.write(encrypted)
}
@Throws(IOException::class, SecurityException::class) @Throws(IOException::class, SecurityException::class)
@Deprecated("Use newDecryptingStream instead")
override fun decryptHeader( override fun decryptHeader(
inputStream: InputStream, inputStream: InputStream,
expectedVersion: Byte, expectedVersion: Byte,
@ -223,11 +154,13 @@ internal class CryptoImpl(
return header return header
} }
@Deprecated("Use newDecryptingStream instead")
@Throws(EOFException::class, IOException::class, SecurityException::class) @Throws(EOFException::class, IOException::class, SecurityException::class)
override fun decryptSegment(inputStream: InputStream): ByteArray { override fun decryptSegment(inputStream: InputStream): ByteArray {
return decryptSegment(inputStream, MAX_SEGMENT_LENGTH) return decryptSegment(inputStream, MAX_SEGMENT_LENGTH)
} }
@Deprecated("Use newDecryptingStream instead")
@Throws(IOException::class, SecurityException::class) @Throws(IOException::class, SecurityException::class)
override fun decryptMultipleSegments(inputStream: InputStream): ByteArray { override fun decryptMultipleSegments(inputStream: InputStream): ByteArray {
var result = ByteArray(0) var result = ByteArray(0)

View file

@ -15,5 +15,5 @@ val cryptoModule = module {
} }
KeyManagerImpl(keyStore) KeyManagerImpl(keyStore)
} }
single<Crypto> { CryptoImpl(get(), get(), get(), get()) } single<Crypto> { CryptoImpl(get(), get(), get()) }
} }

View file

@ -15,7 +15,8 @@ internal const val MAX_VERSION_HEADER_SIZE =
* After the first version byte of each backup stream * After the first version byte of each backup stream
* must follow followed this header encrypted with authentication. * must follow followed this header encrypted with authentication.
*/ */
data class VersionHeader( @Deprecated("version header is in associated data now")
internal data class VersionHeader(
internal val version: Byte = VERSION, // 1 byte internal val version: Byte = VERSION, // 1 byte
internal val packageName: String, // ?? bytes (max 255) internal val packageName: String, // ?? bytes (max 255)
internal val key: String? = null // ?? bytes internal val key: String? = null // ?? bytes
@ -60,6 +61,7 @@ internal const val SEGMENT_HEADER_SIZE = SEGMENT_LENGTH_SIZE + IV_SIZE
/** /**
* Each data segment must start with this header * Each data segment must start with this header
*/ */
@Deprecated("Don't do manual segments, use Crypto interface instead.")
class SegmentHeader( class SegmentHeader(
internal val segmentLength: Short, // 2 bytes internal val segmentLength: Short, // 2 bytes
internal val nonce: ByteArray // 12 bytes internal val nonce: ByteArray // 12 bytes

View file

@ -3,6 +3,5 @@ package com.stevesoltys.seedvault.header
import org.koin.dsl.module import org.koin.dsl.module
val headerModule = module { val headerModule = module {
single<HeaderWriter> { HeaderWriterImpl() }
single<HeaderReader> { HeaderReaderImpl() } single<HeaderReader> { HeaderReaderImpl() }
} }

View file

@ -6,13 +6,15 @@ import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.nio.ByteBuffer import java.nio.ByteBuffer
interface HeaderReader { internal interface HeaderReader {
@Throws(IOException::class, UnsupportedVersionException::class) @Throws(IOException::class, UnsupportedVersionException::class)
fun readVersion(inputStream: InputStream): Byte fun readVersion(inputStream: InputStream): Byte
@Deprecated("")
@Throws(SecurityException::class) @Throws(SecurityException::class)
fun getVersionHeader(byteArray: ByteArray): VersionHeader fun getVersionHeader(byteArray: ByteArray): VersionHeader
@Deprecated("")
@Throws(EOFException::class, IOException::class) @Throws(EOFException::class, IOException::class)
fun readSegmentHeader(inputStream: InputStream): SegmentHeader fun readSegmentHeader(inputStream: InputStream): SegmentHeader
} }

View file

@ -1,51 +0,0 @@
package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.Utf8
import java.io.IOException
import java.io.OutputStream
import java.nio.ByteBuffer
interface HeaderWriter {
@Throws(IOException::class)
fun writeVersion(outputStream: OutputStream, header: VersionHeader)
fun getEncodedVersionHeader(header: VersionHeader): ByteArray
@Throws(IOException::class)
fun writeSegmentHeader(outputStream: OutputStream, header: SegmentHeader)
}
internal class HeaderWriterImpl : HeaderWriter {
@Throws(IOException::class)
override fun writeVersion(outputStream: OutputStream, header: VersionHeader) {
val headerBytes = ByteArray(1)
headerBytes[0] = header.version
outputStream.write(headerBytes)
}
override fun getEncodedVersionHeader(header: VersionHeader): ByteArray {
val packageBytes = header.packageName.toByteArray(Utf8)
val keyBytes = header.key?.toByteArray(Utf8)
val size = 1 + 2 + packageBytes.size + 2 + (keyBytes?.size ?: 0)
return ByteBuffer.allocate(size).apply {
put(header.version)
putShort(packageBytes.size.toShort())
put(packageBytes)
if (keyBytes == null) {
putShort(0.toShort())
} else {
putShort(keyBytes.size.toShort())
put(keyBytes)
}
}.array()
}
override fun writeSegmentHeader(outputStream: OutputStream, header: SegmentHeader) {
val buffer = ByteBuffer.allocate(SEGMENT_HEADER_SIZE)
.putShort(header.segmentLength)
.put(header.nonce)
outputStream.write(buffer.array())
}
}

View file

@ -23,7 +23,6 @@ val backupModule = module {
plugin = get<BackupPlugin>().kvBackupPlugin, plugin = get<BackupPlugin>().kvBackupPlugin,
settingsManager = get(), settingsManager = get(),
inputFactory = get(), inputFactory = get(),
headerWriter = get(),
crypto = get(), crypto = get(),
nm = get() nm = get()
) )
@ -33,7 +32,6 @@ val backupModule = module {
plugin = get<BackupPlugin>().fullBackupPlugin, plugin = get<BackupPlugin>().fullBackupPlugin,
settingsManager = get(), settingsManager = get(),
inputFactory = get(), inputFactory = get(),
headerWriter = get(),
crypto = get() crypto = get()
) )
} }

View file

@ -9,9 +9,7 @@ import android.content.pm.PackageInfo
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.util.Log import android.util.Log
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForFull import com.stevesoltys.seedvault.header.getADForFull
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import libcore.io.IoUtils.closeQuietly import libcore.io.IoUtils.closeQuietly
@ -43,7 +41,6 @@ internal class FullBackup(
private val plugin: FullBackupPlugin, private val plugin: FullBackupPlugin,
private val settingsManager: SettingsManager, private val settingsManager: SettingsManager,
private val inputFactory: InputFactory, private val inputFactory: InputFactory,
private val headerWriter: HeaderWriter,
private val crypto: Crypto private val crypto: Crypto
) { ) {
@ -124,8 +121,7 @@ internal class FullBackup(
} }
// store version header // store version header
val state = this.state ?: throw AssertionError() val state = this.state ?: throw AssertionError()
val header = VersionHeader(packageName = state.packageName) outputStream.write(ByteArray(1) { VERSION })
headerWriter.writeVersion(outputStream, header)
crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName)) crypto.newEncryptingStream(outputStream, getADForFull(VERSION, state.packageName))
} // this lambda is only called before we actually write backup data the first time } // this lambda is only called before we actually write backup data the first time
return TRANSPORT_OK return TRANSPORT_OK

View file

@ -12,9 +12,7 @@ import android.util.Log
import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER import com.stevesoltys.seedvault.MAGIC_PACKAGE_MANAGER
import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.encodeBase64 import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.settings.SettingsManager import com.stevesoltys.seedvault.settings.SettingsManager
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
@ -31,7 +29,6 @@ internal class KVBackup(
private val plugin: KVBackupPlugin, private val plugin: KVBackupPlugin,
private val settingsManager: SettingsManager, private val settingsManager: SettingsManager,
private val inputFactory: InputFactory, private val inputFactory: InputFactory,
private val headerWriter: HeaderWriter,
private val crypto: Crypto, private val crypto: Crypto,
private val nm: BackupNotificationManager private val nm: BackupNotificationManager
) { ) {
@ -168,11 +165,7 @@ internal class KVBackup(
plugin.deleteRecord(packageInfo, op.base64Key) plugin.deleteRecord(packageInfo, op.base64Key)
} else { } else {
plugin.getOutputStreamForRecord(packageInfo, op.base64Key).use { outputStream -> plugin.getOutputStreamForRecord(packageInfo, op.base64Key).use { outputStream ->
val header = VersionHeader( outputStream.write(ByteArray(1) { VERSION })
packageName = packageInfo.packageName,
key = op.key
)
headerWriter.writeVersion(outputStream, header)
val ad = getADForKV(VERSION, packageInfo.packageName) val ad = getADForKV(VERSION, packageInfo.packageName)
crypto.newEncryptingStream(outputStream, ad).use { encryptedStream -> crypto.newEncryptingStream(outputStream, ad).use { encryptedStream ->
encryptedStream.write(op.value) encryptedStream.write(op.value)

View file

@ -21,7 +21,7 @@ class TestApp : App() {
private val testCryptoModule = module { private val testCryptoModule = module {
factory<CipherFactory> { CipherFactoryImpl(get()) } factory<CipherFactory> { CipherFactoryImpl(get()) }
single<KeyManager> { KeyManagerTestImpl() } single<KeyManager> { KeyManagerTestImpl() }
single<Crypto> { CryptoImpl(get(), get(), get(), get()) } single<Crypto> { CryptoImpl(get(), get(), get()) }
} }
private val appModule = module { private val appModule = module {
single { Clock() } single { Clock() }

View file

@ -1,57 +1,22 @@
package com.stevesoltys.seedvault.crypto package com.stevesoltys.seedvault.crypto
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.IV_SIZE
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import io.mockk.every
import io.mockk.mockk import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertThrows import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException import java.io.IOException
import javax.crypto.Cipher
import kotlin.random.Random
@TestInstance(PER_METHOD) @TestInstance(PER_METHOD)
class CryptoImplTest { class CryptoImplTest {
private val keyManager = mockk<KeyManager>() private val keyManager = mockk<KeyManager>()
private val cipherFactory = mockk<CipherFactory>() private val cipherFactory = mockk<CipherFactory>()
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
private val cipher = mockk<Cipher>()
private val iv = ByteArray(IV_SIZE).apply { Random.nextBytes(this) }
private val cleartext = ByteArray(Random.nextInt(Short.MAX_VALUE.toInt()))
.apply { Random.nextBytes(this) }
private val ciphertext = ByteArray(Random.nextInt(Short.MAX_VALUE.toInt()))
.apply { Random.nextBytes(this) }
private val outputStream = ByteArrayOutputStream()
@Test
fun `encrypted cleartext gets decrypted as expected`() {
every { cipherFactory.createEncryptionCipher() } returns cipher
every { cipher.getOutputSize(cleartext.size) } returns MAX_SEGMENT_LENGTH
every { cipher.doFinal(cleartext) } returns ciphertext
every { cipher.iv } returns iv
crypto.encryptSegment(outputStream, cleartext)
val inputStream = ByteArrayInputStream(outputStream.toByteArray())
every { cipherFactory.createDecryptionCipher(iv) } returns cipher
every { cipher.doFinal(ciphertext) } returns cleartext
assertArrayEquals(cleartext, crypto.decryptSegment(inputStream))
}
@Test @Test
fun `decrypting multiple segments on empty stream throws`() { fun `decrypting multiple segments on empty stream throws`() {

View file

@ -1,16 +1,14 @@
package com.stevesoltys.seedvault.crypto package com.stevesoltys.seedvault.crypto
import com.stevesoltys.seedvault.assertReadEquals
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl import org.junit.jupiter.api.Assertions.assertThrows
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.header.MAX_SEGMENT_LENGTH
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.IOException
import kotlin.random.Random import kotlin.random.Random
@TestInstance(PER_METHOD) @TestInstance(PER_METHOD)
@ -18,50 +16,31 @@ class CryptoIntegrationTest {
private val keyManager = KeyManagerTestImpl() private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val cleartext = Random.nextBytes(Random.nextInt(1, 422300))
private val cleartext = byteArrayOf(0x01, 0x02, 0x03)
private val outputStream = ByteArrayOutputStream()
@Test @Test
fun `the plain crypto works`() { fun `decrypting encrypted cleartext works`() {
val eCipher = cipherFactory.createEncryptionCipher() val ad = Random.nextBytes(42)
val encrypted = eCipher.doFinal(cleartext) val outputStream = ByteArrayOutputStream()
crypto.newEncryptingStream(outputStream, ad).use { it.write(cleartext) }
val dCipher = cipherFactory.createDecryptionCipher(eCipher.iv)
val decrypted = dCipher.doFinal(encrypted)
assertArrayEquals(cleartext, decrypted)
}
@Test
fun `encrypted cleartext gets decrypted as expected`() {
crypto.encryptSegment(outputStream, cleartext)
val inputStream = ByteArrayInputStream(outputStream.toByteArray()) val inputStream = ByteArrayInputStream(outputStream.toByteArray())
assertArrayEquals(cleartext, crypto.decryptSegment(inputStream)) crypto.newDecryptingStream(inputStream, ad).use {
assertReadEquals(cleartext, it)
}
} }
@Test @Test
fun `multiple segments get encrypted and decrypted as expected`() { fun `decrypting encrypted cleartext fails with different AD`() {
val size = Random.nextInt(5) * MAX_SEGMENT_CLEARTEXT_LENGTH + Random.nextInt(0, 1337) val outputStream = ByteArrayOutputStream()
val cleartext = ByteArray(size).apply { Random.nextBytes(this) } crypto.newEncryptingStream(outputStream, Random.nextBytes(42)).use { it.write(cleartext) }
crypto.encryptMultipleSegments(outputStream, cleartext)
val inputStream = ByteArrayInputStream(outputStream.toByteArray()) val inputStream = ByteArrayInputStream(outputStream.toByteArray())
assertArrayEquals(cleartext, crypto.decryptMultipleSegments(inputStream)) assertThrows(IOException::class.java) {
crypto.newDecryptingStream(inputStream, Random.nextBytes(41)).use {
it.read()
} }
@Test
fun `test maximum lengths`() {
val cipher = cipherFactory.createEncryptionCipher()
val expectedDiff = MAX_SEGMENT_LENGTH - MAX_SEGMENT_CLEARTEXT_LENGTH
for (i in 1..(3 * MAX_SEGMENT_LENGTH + 42)) {
val outputSize = cipher.getOutputSize(i)
assertEquals(expectedDiff, outputSize - i)
} }
} }

View file

@ -4,7 +4,6 @@ import com.stevesoltys.seedvault.assertContains
import com.stevesoltys.seedvault.getRandomByteArray import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.HeaderReader import com.stevesoltys.seedvault.header.HeaderReader
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.IV_SIZE import com.stevesoltys.seedvault.header.IV_SIZE
import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE
import com.stevesoltys.seedvault.header.MAX_PACKAGE_LENGTH_SIZE import com.stevesoltys.seedvault.header.MAX_PACKAGE_LENGTH_SIZE
@ -13,10 +12,7 @@ import com.stevesoltys.seedvault.header.MAX_VERSION_HEADER_SIZE
import com.stevesoltys.seedvault.header.SegmentHeader import com.stevesoltys.seedvault.header.SegmentHeader
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.VersionHeader
import io.mockk.CapturingSlot
import io.mockk.Runs
import io.mockk.every import io.mockk.every
import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
@ -26,7 +22,6 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.EOFException import java.io.EOFException
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
@ -38,10 +33,9 @@ class CryptoTest {
private val keyManager = mockk<KeyManager>() private val keyManager = mockk<KeyManager>()
private val cipherFactory = mockk<CipherFactory>() private val cipherFactory = mockk<CipherFactory>()
private val headerWriter = mockk<HeaderWriter>()
private val headerReader = mockk<HeaderReader>() private val headerReader = mockk<HeaderReader>()
private val crypto = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val crypto = CryptoImpl(keyManager, cipherFactory, headerReader)
private val cipher = mockk<Cipher>() private val cipher = mockk<Cipher>()
@ -55,49 +49,12 @@ class CryptoTest {
) )
private val versionCiphertext = getRandomByteArray(MAX_VERSION_HEADER_SIZE) private val versionCiphertext = getRandomByteArray(MAX_VERSION_HEADER_SIZE)
private val versionSegmentHeader = SegmentHeader(versionCiphertext.size.toShort(), iv) private val versionSegmentHeader = SegmentHeader(versionCiphertext.size.toShort(), iv)
private val outputStream = ByteArrayOutputStream()
private val segmentHeader = SegmentHeader(ciphertext.size.toShort(), iv) private val segmentHeader = SegmentHeader(ciphertext.size.toShort(), iv)
// the headerReader will not actually read the header, so only insert cipher text // the headerReader will not actually read the header, so only insert cipher text
private val inputStream = ByteArrayInputStream(ciphertext) private val inputStream = ByteArrayInputStream(ciphertext)
private val versionInputStream = ByteArrayInputStream(versionCiphertext) private val versionInputStream = ByteArrayInputStream(versionCiphertext)
// encrypting
@Test
fun `encrypt header works as expected`() {
val segmentHeader = CapturingSlot<SegmentHeader>()
every { headerWriter.getEncodedVersionHeader(versionHeader) } returns ciphertext
encryptSegmentHeader(ciphertext, segmentHeader)
crypto.encryptHeader(outputStream, versionHeader)
assertArrayEquals(iv, segmentHeader.captured.nonce)
assertEquals(ciphertext.size, segmentHeader.captured.segmentLength.toInt())
}
@Test
fun `encrypting segment works as expected`() {
val segmentHeader = CapturingSlot<SegmentHeader>()
encryptSegmentHeader(cleartext, segmentHeader)
crypto.encryptSegment(outputStream, cleartext)
assertArrayEquals(ciphertext, outputStream.toByteArray())
assertArrayEquals(iv, segmentHeader.captured.nonce)
assertEquals(ciphertext.size, segmentHeader.captured.segmentLength.toInt())
}
private fun encryptSegmentHeader(
toEncrypt: ByteArray,
segmentHeader: CapturingSlot<SegmentHeader>
) {
every { cipherFactory.createEncryptionCipher() } returns cipher
every { cipher.getOutputSize(toEncrypt.size) } returns toEncrypt.size
every { cipher.iv } returns iv
every { headerWriter.writeSegmentHeader(outputStream, capture(segmentHeader)) } just Runs
every { cipher.doFinal(toEncrypt) } returns ciphertext
}
// decrypting // decrypting
@Test @Test

View file

@ -1,104 +0,0 @@
package com.stevesoltys.seedvault.header
import com.stevesoltys.seedvault.getRandomByteArray
import com.stevesoltys.seedvault.getRandomString
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import kotlin.random.Random
@TestInstance(PER_CLASS)
internal class HeaderWriterReaderTest {
private val writer = HeaderWriterImpl()
private val reader = HeaderReaderImpl()
private val packageName = getRandomString(MAX_PACKAGE_LENGTH_SIZE)
private val key = getRandomString(MAX_KEY_LENGTH_SIZE)
private val versionHeader = VersionHeader(VERSION, packageName, key)
private val unsupportedVersionHeader = VersionHeader((VERSION + 1).toByte(), packageName)
private val segmentLength = getRandomValidSegmentLength()
private val nonce = getRandomByteArray(IV_SIZE)
private val segmentHeader = SegmentHeader(segmentLength, nonce)
@Test
fun `written version matches read input`() {
assertEquals(versionHeader.version, readWriteVersion(versionHeader))
}
@Test
fun `reading unsupported version throws exception`() {
assertThrows(UnsupportedVersionException::class.javaObjectType) {
readWriteVersion(unsupportedVersionHeader)
}
}
@Test
fun `VersionHeader output matches read input`() {
assertEquals(versionHeader, readWrite(versionHeader))
}
@Test
fun `VersionHeader with no key output matches read input`() {
val versionHeader = VersionHeader(VERSION, packageName, null)
assertEquals(versionHeader, readWrite(versionHeader))
}
@Test
fun `VersionHeader with empty package name throws`() {
val versionHeader = VersionHeader(VERSION, "")
assertThrows(SecurityException::class.java) {
readWrite(versionHeader)
}
}
@Test
fun `SegmentHeader constructor needs right IV size`() {
val nonceTooBig = ByteArray(IV_SIZE + 1).apply { Random.nextBytes(this) }
assertThrows(IllegalStateException::class.javaObjectType) {
SegmentHeader(segmentLength, nonceTooBig)
}
val nonceTooSmall = ByteArray(IV_SIZE - 1).apply { Random.nextBytes(this) }
assertThrows(IllegalStateException::class.javaObjectType) {
SegmentHeader(segmentLength, nonceTooSmall)
}
}
@Test
fun `SegmentHeader output matches read input`() {
assertEquals(segmentHeader, readWriteVersion(segmentHeader))
}
private fun readWriteVersion(header: VersionHeader): Byte {
val outputStream = ByteArrayOutputStream()
writer.writeVersion(outputStream, header)
val written = outputStream.toByteArray()
val inputStream = ByteArrayInputStream(written)
return reader.readVersion(inputStream)
}
private fun readWrite(header: VersionHeader): VersionHeader {
val written = writer.getEncodedVersionHeader(header)
return reader.getVersionHeader(written)
}
private fun readWriteVersion(header: SegmentHeader): SegmentHeader {
val outputStream = ByteArrayOutputStream()
writer.writeSegmentHeader(outputStream, header)
val written = outputStream.toByteArray()
val inputStream = ByteArrayInputStream(written)
return reader.readSegmentHeader(inputStream)
}
private fun assertEquals(expected: SegmentHeader, actual: SegmentHeader) {
assertEquals(expected.segmentLength, actual.segmentLength)
assertArrayEquals(expected.nonce, actual.nonce)
}
}

View file

@ -6,7 +6,6 @@ import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA
import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED
@ -27,9 +26,8 @@ internal class MetadataReadWriteTest {
) )
private val keyManager = KeyManagerTestImpl(secretKey) private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val writer = MetadataWriterImpl(cryptoImpl) private val writer = MetadataWriterImpl(cryptoImpl)
private val reader = MetadataReaderImpl(cryptoImpl) private val reader = MetadataReaderImpl(cryptoImpl)

View file

@ -5,7 +5,6 @@ import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA import com.stevesoltys.seedvault.metadata.PackageState.APK_AND_DATA
import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED import com.stevesoltys.seedvault.metadata.PackageState.WAS_STOPPED
import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.toByteArrayFromHex
@ -27,9 +26,8 @@ internal class MetadataV0ReadTest {
) )
private val keyManager = KeyManagerTestImpl(secretKey) private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val reader = MetadataReaderImpl(cryptoImpl) private val reader = MetadataReaderImpl(cryptoImpl)

View file

@ -10,7 +10,7 @@ import org.junit.Assert.assertTrue
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import kotlin.random.Random import kotlin.random.Random
class ApkSplitCompatibilityCheckerTest : TransportTest() { internal class ApkSplitCompatibilityCheckerTest : TransportTest() {
private val deviceInfo: DeviceInfo = mockk() private val deviceInfo: DeviceInfo = mockk()
private val deviceName = getRandomString() private val deviceName = getRandomString()

View file

@ -12,7 +12,6 @@ import com.stevesoltys.seedvault.crypto.CryptoImpl
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64 import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH import com.stevesoltys.seedvault.header.MAX_SEGMENT_CLEARTEXT_LENGTH
import com.stevesoltys.seedvault.metadata.MetadataReaderImpl import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.metadata.PackageMetadata import com.stevesoltys.seedvault.metadata.PackageMetadata
@ -57,9 +56,8 @@ internal class CoordinatorIntegrationTest : TransportTest() {
private val outputFactory = mockk<OutputFactory>() private val outputFactory = mockk<OutputFactory>()
private val keyManager = KeyManagerTestImpl() private val keyManager = KeyManagerTestImpl()
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val metadataReader = MetadataReaderImpl(cryptoImpl) private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>() private val notificationManager = mockk<BackupNotificationManager>()
@ -69,7 +67,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
plugin = kvBackupPlugin, plugin = kvBackupPlugin,
settingsManager = settingsManager, settingsManager = settingsManager,
inputFactory = inputFactory, inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = cryptoImpl, crypto = cryptoImpl,
nm = notificationManager nm = notificationManager
) )
@ -78,7 +75,6 @@ internal class CoordinatorIntegrationTest : TransportTest() {
plugin = fullBackupPlugin, plugin = fullBackupPlugin,
settingsManager = settingsManager, settingsManager = settingsManager,
inputFactory = inputFactory, inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = cryptoImpl crypto = cryptoImpl
) )
private val apkBackup = mockk<ApkBackup>() private val apkBackup = mockk<ApkBackup>()

View file

@ -20,7 +20,7 @@ import org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD
import kotlin.random.Random import kotlin.random.Random
@TestInstance(PER_METHOD) @TestInstance(PER_METHOD)
abstract class TransportTest { internal abstract class TransportTest {
protected val clock: Clock = mockk() protected val clock: Clock = mockk()
protected val crypto = mockk<Crypto>() protected val crypto = mockk<Crypto>()

View file

@ -1,8 +1,6 @@
package com.stevesoltys.seedvault.transport.backup package com.stevesoltys.seedvault.transport.backup
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import com.stevesoltys.seedvault.header.HeaderWriter
import com.stevesoltys.seedvault.header.VersionHeader
import com.stevesoltys.seedvault.transport.TransportTest import com.stevesoltys.seedvault.transport.TransportTest
import io.mockk.mockk import io.mockk.mockk
import java.io.OutputStream import java.io.OutputStream
@ -10,12 +8,10 @@ import java.io.OutputStream
internal abstract class BackupTest : TransportTest() { internal abstract class BackupTest : TransportTest() {
protected val inputFactory = mockk<InputFactory>() protected val inputFactory = mockk<InputFactory>()
protected val headerWriter = mockk<HeaderWriter>()
protected val data = mockk<ParcelFileDescriptor>() protected val data = mockk<ParcelFileDescriptor>()
protected val outputStream = mockk<OutputStream>() protected val outputStream = mockk<OutputStream>()
protected val encryptedOutputStream = mockk<OutputStream>() protected val encryptedOutputStream = mockk<OutputStream>()
protected val header = VersionHeader(packageName = packageInfo.packageName)
protected val quota = 42L protected val quota = 42L
} }

View file

@ -24,7 +24,7 @@ import kotlin.random.Random
internal class FullBackupTest : BackupTest() { internal class FullBackupTest : BackupTest() {
private val plugin = mockk<FullBackupPlugin>() private val plugin = mockk<FullBackupPlugin>()
private val backup = FullBackup(plugin, settingsManager, inputFactory, headerWriter, crypto) private val backup = FullBackup(plugin, settingsManager, inputFactory, crypto)
private val bytes = ByteArray(23).apply { Random.nextBytes(this) } private val bytes = ByteArray(23).apply { Random.nextBytes(this) }
private val inputStream = mockk<FileInputStream>() private val inputStream = mockk<FileInputStream>()
@ -168,7 +168,7 @@ internal class FullBackupTest : BackupTest() {
every { plugin.getQuota() } returns quota every { plugin.getQuota() } returns quota
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { inputFactory.getInputStream(data) } returns inputStream every { inputFactory.getInputStream(data) } returns inputStream
every { headerWriter.writeVersion(outputStream, header) } throws IOException() every { outputStream.write(ByteArray(1) { VERSION }) } throws IOException()
expectClearState() expectClearState()
assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data)) assertEquals(TRANSPORT_OK, backup.performFullBackup(packageInfo, data))
@ -316,7 +316,7 @@ internal class FullBackupTest : BackupTest() {
private fun expectInitializeOutputStream() { private fun expectInitializeOutputStream() {
coEvery { plugin.getOutputStream(packageInfo) } returns outputStream coEvery { plugin.getOutputStream(packageInfo) } returns outputStream
every { headerWriter.writeVersion(outputStream, header) } just Runs every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
} }
private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) { private fun expectSendData(numBytes: Int, readBytes: Int = numBytes) {

View file

@ -11,7 +11,7 @@ import android.content.pm.PackageInfo
import com.stevesoltys.seedvault.Utf8 import com.stevesoltys.seedvault.Utf8
import com.stevesoltys.seedvault.getRandomString import com.stevesoltys.seedvault.getRandomString
import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE import com.stevesoltys.seedvault.header.MAX_KEY_LENGTH_SIZE
import com.stevesoltys.seedvault.header.VersionHeader import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.header.getADForKV import com.stevesoltys.seedvault.header.getADForKV
import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager import com.stevesoltys.seedvault.ui.notification.BackupNotificationManager
import io.mockk.CapturingSlot import io.mockk.CapturingSlot
@ -42,7 +42,6 @@ internal class KVBackupTest : BackupTest() {
plugin = plugin, plugin = plugin,
settingsManager = settingsManager, settingsManager = settingsManager,
inputFactory = inputFactory, inputFactory = inputFactory,
headerWriter = headerWriter,
crypto = crypto, crypto = crypto,
nm = notificationManager nm = notificationManager
) )
@ -50,7 +49,6 @@ internal class KVBackupTest : BackupTest() {
private val key = getRandomString(MAX_KEY_LENGTH_SIZE) private val key = getRandomString(MAX_KEY_LENGTH_SIZE)
private val key64 = Base64.getEncoder().encodeToString(key.toByteArray(Utf8)) private val key64 = Base64.getEncoder().encodeToString(key.toByteArray(Utf8))
private val dataValue = Random.nextBytes(23) private val dataValue = Random.nextBytes(23)
private val versionHeader = VersionHeader(packageName = packageInfo.packageName, key = key)
@Test @Test
fun `has no initial state`() { fun `has no initial state`() {
@ -81,14 +79,11 @@ internal class KVBackupTest : BackupTest() {
// store first record and show notification for it // store first record and show notification for it
every { notificationManager.onPmKvBackup("key1", 1, 2) } just Runs every { notificationManager.onPmKvBackup("key1", 1, 2) } just Runs
coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5MQ") } returns outputStream coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5MQ") } returns outputStream
val versionHeader1 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key1") every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
every { headerWriter.writeVersion(outputStream, versionHeader1) } just Runs
// store second record and show notification for it // store second record and show notification for it
every { notificationManager.onPmKvBackup("key2", 2, 2) } just Runs every { notificationManager.onPmKvBackup("key2", 2, 2) } just Runs
coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5Mg") } returns outputStream coEvery { plugin.getOutputStreamForRecord(pmPackageInfo, "a2V5Mg") } returns outputStream
val versionHeader2 = VersionHeader(packageName = pmPackageInfo.packageName, key = "key2")
every { headerWriter.writeVersion(outputStream, versionHeader2) } just Runs
// encrypt to and close output stream // encrypt to and close output stream
every { crypto.newEncryptingStream(outputStream, any()) } returns encryptedOutputStream every { crypto.newEncryptingStream(outputStream, any()) } returns encryptedOutputStream
@ -213,11 +208,11 @@ internal class KVBackupTest : BackupTest() {
} }
@Test @Test
fun `exception while writing version header`() = runBlocking { fun `exception while writing version`() = runBlocking {
initPlugin(false) initPlugin(false)
getDataInput(listOf(true)) getDataInput(listOf(true))
coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream
every { headerWriter.writeVersion(outputStream, versionHeader) } throws IOException() every { outputStream.write(ByteArray(1) { VERSION }) } throws IOException()
every { outputStream.close() } just Runs every { outputStream.close() } just Runs
every { plugin.packageFinished(packageInfo) } just Runs every { plugin.packageFinished(packageInfo) } just Runs
@ -231,7 +226,7 @@ internal class KVBackupTest : BackupTest() {
fun `exception while writing encrypted value to output stream`() = runBlocking { fun `exception while writing encrypted value to output stream`() = runBlocking {
initPlugin(false) initPlugin(false)
getDataInput(listOf(true)) getDataInput(listOf(true))
writeHeaderAndEncrypt() writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } throws IOException() every { encryptedOutputStream.write(dataValue) } throws IOException()
every { plugin.packageFinished(packageInfo) } just Runs every { plugin.packageFinished(packageInfo) } just Runs
@ -245,7 +240,7 @@ internal class KVBackupTest : BackupTest() {
fun `exception while flushing output stream`() = runBlocking { fun `exception while flushing output stream`() = runBlocking {
initPlugin(false) initPlugin(false)
getDataInput(listOf(true)) getDataInput(listOf(true))
writeHeaderAndEncrypt() writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } throws IOException() every { encryptedOutputStream.flush() } throws IOException()
every { encryptedOutputStream.close() } just Runs every { encryptedOutputStream.close() } just Runs
@ -262,7 +257,7 @@ internal class KVBackupTest : BackupTest() {
fun `ignoring exception while closing output stream`() = runBlocking { fun `ignoring exception while closing output stream`() = runBlocking {
initPlugin(false) initPlugin(false)
getDataInput(listOf(true, false)) getDataInput(listOf(true, false))
writeHeaderAndEncrypt() writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } just Runs every { encryptedOutputStream.flush() } just Runs
every { encryptedOutputStream.close() } just Runs every { encryptedOutputStream.close() } just Runs
@ -278,7 +273,7 @@ internal class KVBackupTest : BackupTest() {
private fun singleRecordBackup(hasDataForPackage: Boolean = false) { private fun singleRecordBackup(hasDataForPackage: Boolean = false) {
initPlugin(hasDataForPackage) initPlugin(hasDataForPackage)
getDataInput(listOf(true, false)) getDataInput(listOf(true, false))
writeHeaderAndEncrypt() writeVersionAndEncrypt()
every { encryptedOutputStream.write(dataValue) } just Runs every { encryptedOutputStream.write(dataValue) } just Runs
every { encryptedOutputStream.flush() } just Runs every { encryptedOutputStream.flush() } just Runs
every { encryptedOutputStream.close() } just Runs every { encryptedOutputStream.close() } just Runs
@ -306,10 +301,10 @@ internal class KVBackupTest : BackupTest() {
} }
} }
private fun writeHeaderAndEncrypt() { private fun writeVersionAndEncrypt() {
coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream coEvery { plugin.getOutputStreamForRecord(packageInfo, key64) } returns outputStream
every { headerWriter.writeVersion(outputStream, versionHeader) } just Runs every { outputStream.write(ByteArray(1) { VERSION }) } just Runs
val ad = getADForKV(versionHeader.version, packageInfo.packageName) val ad = getADForKV(VERSION, packageInfo.packageName)
every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream every { crypto.newEncryptingStream(outputStream, ad) } returns encryptedOutputStream
} }

View file

@ -12,7 +12,6 @@ import com.stevesoltys.seedvault.crypto.KEY_SIZE_BYTES
import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl import com.stevesoltys.seedvault.crypto.KeyManagerTestImpl
import com.stevesoltys.seedvault.encodeBase64 import com.stevesoltys.seedvault.encodeBase64
import com.stevesoltys.seedvault.header.HeaderReaderImpl import com.stevesoltys.seedvault.header.HeaderReaderImpl
import com.stevesoltys.seedvault.header.HeaderWriterImpl
import com.stevesoltys.seedvault.metadata.MetadataReaderImpl import com.stevesoltys.seedvault.metadata.MetadataReaderImpl
import com.stevesoltys.seedvault.toByteArrayFromHex import com.stevesoltys.seedvault.toByteArrayFromHex
import com.stevesoltys.seedvault.transport.TransportTest import com.stevesoltys.seedvault.transport.TransportTest
@ -42,9 +41,8 @@ internal class RestoreV0IntegrationTest : TransportTest() {
) )
private val keyManager = KeyManagerTestImpl(secretKey) private val keyManager = KeyManagerTestImpl(secretKey)
private val cipherFactory = CipherFactoryImpl(keyManager) private val cipherFactory = CipherFactoryImpl(keyManager)
private val headerWriter = HeaderWriterImpl()
private val headerReader = HeaderReaderImpl() private val headerReader = HeaderReaderImpl()
private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerWriter, headerReader) private val cryptoImpl = CryptoImpl(keyManager, cipherFactory, headerReader)
private val metadataReader = MetadataReaderImpl(cryptoImpl) private val metadataReader = MetadataReaderImpl(cryptoImpl)
private val notificationManager = mockk<BackupNotificationManager>() private val notificationManager = mockk<BackupNotificationManager>()