package com.stevesoltys.backup.header

import com.stevesoltys.backup.Utf8
import com.stevesoltys.backup.assertContains
import com.stevesoltys.backup.getRandomString
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS
import java.io.ByteArrayInputStream
import java.io.IOException
import java.nio.ByteBuffer
import kotlin.random.Random

/**
 * Tests for [HeaderReaderImpl].
 *
 * Byte layouts as constructed by these tests:
 * - VersionHeader: 1 version byte, 2-byte package length, package name bytes,
 *   2-byte key length, key bytes.
 * - SegmentHeader: 2-byte segment length followed by [IV_SIZE] nonce bytes,
 *   [SEGMENT_HEADER_SIZE] bytes in total.
 *
 * A single reader instance is shared by all tests ([PER_CLASS] lifecycle).
 */
@TestInstance(PER_CLASS)
internal class HeaderReaderTest {

    private val reader = HeaderReaderImpl()

    // Version Tests

    @Test
    fun `valid version is read`() {
        val inputStream = ByteArrayInputStream(byteArrayOf(VERSION))

        assertEquals(VERSION, reader.readVersion(inputStream))
    }

    @Test
    fun `too short version stream throws exception`() {
        val inputStream = ByteArrayInputStream(ByteArray(0))

        assertThrows(IOException::class.java) {
            reader.readVersion(inputStream)
        }
    }

    @Test
    fun `unsupported version throws exception`() {
        val input = byteArrayOf((VERSION + 1).toByte())
        val inputStream = ByteArrayInputStream(input)

        assertThrows(UnsupportedVersionException::class.java) {
            reader.readVersion(inputStream)
        }
    }

    @Test
    fun `negative version throws exception`() {
        val input = byteArrayOf((-1).toByte())
        val inputStream = ByteArrayInputStream(input)

        assertThrows(IOException::class.java) {
            reader.readVersion(inputStream)
        }
    }

    @Test
    fun `max version byte throws exception`() {
        val input = byteArrayOf(Byte.MAX_VALUE)
        val inputStream = ByteArrayInputStream(input)

        assertThrows(UnsupportedVersionException::class.java) {
            reader.readVersion(inputStream)
        }
    }

    // VersionHeader Tests

    @Test
    fun `valid VersionHeader is read`() {
        // package name "a" (0x61) and key "b" (0x62), each prefixed by a length of 1
        val input = byteArrayOf(VERSION, 0x00, 0x01, 0x61, 0x00, 0x01, 0x62)
        val expected = VersionHeader(VERSION, "a", "b")

        assertEquals(expected, reader.getVersionHeader(input))
    }

    @Test
    fun `zero package length in VersionHeader throws`() {
        val input = byteArrayOf(VERSION, 0x00, 0x00, 0x00, 0x01, 0x62)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `negative package length in VersionHeader throws`() {
        // 0xFFFF is -1 when interpreted as a signed 16-bit length
        val input = byteArrayOf(VERSION.toInt(), 0xFF, 0xFF, 0x00, 0x01, 0x62)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `too large package length in VersionHeader throws`() {
        val size = MAX_PACKAGE_LENGTH_SIZE + 1
        // 3 = version byte + 2-byte package length
        val input = ByteBuffer.allocate(3 + size)
            .put(VERSION)
            .putShort(size.toShort())
            .put(ByteArray(size))
            .array()

        val e = assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
        assertContains(e.message, size.toString())
    }

    @Test
    fun `insufficient bytes for package in VersionHeader throws`() {
        // announces 0x50 package bytes, but none follow
        val input = byteArrayOf(VERSION, 0x00, 0x50)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `zero key length in VersionHeader gets accepted`() {
        val input = byteArrayOf(VERSION, 0x00, 0x01, 0x61, 0x00, 0x00)
        val expected = VersionHeader(VERSION, "a", null)

        assertEquals(expected, reader.getVersionHeader(input))
    }

    @Test
    fun `negative key length in VersionHeader throws`() {
        // 0xFFFF is -1 when interpreted as a signed 16-bit length
        val input = byteArrayOf(VERSION.toInt(), 0x00, 0x01, 0x61, 0xFF, 0xFF)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `too large key length in VersionHeader throws`() {
        val size = MAX_KEY_LENGTH_SIZE + 1
        // 6 = version byte + 2-byte package length + 1-byte package name + 2-byte key length;
        // the key bytes are written in full so that only the announced length is invalid
        val input = ByteBuffer.allocate(6 + size)
            .put(VERSION)
            .putShort(1.toShort())
            .put("a".toByteArray(Utf8))
            .putShort(size.toShort())
            .put(ByteArray(size))
            .array()

        val e = assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
        assertContains(e.message, size.toString())
    }

    @Test
    fun `insufficient bytes for key in VersionHeader throws`() {
        // announces 0x50 key bytes, but none follow
        val input = byteArrayOf(VERSION.toInt(), 0x00, 0x01, 0x61, 0x00, 0x50)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `extra bytes in VersionHeader throws`() {
        // a valid header followed by one trailing byte
        val input = byteArrayOf(VERSION, 0x00, 0x01, 0x61, 0x00, 0x01, 0x62, 0x00)

        assertThrows(SecurityException::class.java) {
            reader.getVersionHeader(input)
        }
    }

    @Test
    fun `max sized VersionHeader gets accepted`() {
        val packageName = getRandomString(MAX_PACKAGE_LENGTH_SIZE)
        val key = getRandomString(MAX_KEY_LENGTH_SIZE)
        val input = ByteBuffer.allocate(MAX_VERSION_HEADER_SIZE)
            .put(VERSION)
            .putShort(MAX_PACKAGE_LENGTH_SIZE.toShort())
            .put(packageName.toByteArray(Utf8))
            .putShort(MAX_KEY_LENGTH_SIZE.toShort())
            .put(key.toByteArray(Utf8))
            .array()
        assertEquals(MAX_VERSION_HEADER_SIZE, input.size)

        val header = reader.getVersionHeader(input)

        assertEquals(VERSION, header.version)
        assertEquals(packageName, header.packageName)
        assertEquals(key, header.key)
    }

    // SegmentHeader Tests

    @Test
    fun `too short SegmentHeader throws exception`() {
        // 13 bytes: one fewer than the 14-byte headers used by the other segment tests
        val input = byteArrayOf(
            0x00, 0x01,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertThrows(IOException::class.java) {
            reader.readSegmentHeader(inputStream)
        }
    }

    @Test
    fun `segment length of zero is rejected`() {
        val input = byteArrayOf(
            0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertThrows(IOException::class.java) {
            reader.readSegmentHeader(inputStream)
        }
    }

    @Test
    fun `negative segment length is rejected`() {
        // 0xFFFF is -1 when interpreted as a signed 16-bit length
        val input = byteArrayOf(
            0xFF, 0xFF,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertThrows(IOException::class.java) {
            reader.readSegmentHeader(inputStream)
        }
    }

    @Test
    fun `minimum negative segment length is rejected`() {
        // 0x8000 is the smallest signed 16-bit value
        val input = byteArrayOf(
            0x80, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertThrows(IOException::class.java) {
            reader.readSegmentHeader(inputStream)
        }
    }

    @Test
    fun `max segment length is accepted`() {
        // 0x7FFF is the largest signed 16-bit value
        val input = byteArrayOf(
            0x7F, 0xFF,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertEquals(MAX_SEGMENT_LENGTH, reader.readSegmentHeader(inputStream).segmentLength.toInt())
    }

    @Test
    fun `min segment length of 1 is accepted`() {
        val input = byteArrayOf(
            0x00, 0x01,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
        )
        val inputStream = ByteArrayInputStream(input)

        assertEquals(1, reader.readSegmentHeader(inputStream).segmentLength.toInt())
    }

    @Test
    fun `segment length is always read correctly`() {
        val segmentLength = getRandomValidSegmentLength()
        val input = ByteBuffer.allocate(SEGMENT_HEADER_SIZE)
            .putShort(segmentLength)
            .put(ByteArray(IV_SIZE))
            .array()
        val inputStream = ByteArrayInputStream(input)

        assertEquals(segmentLength, reader.readSegmentHeader(inputStream).segmentLength)
    }

    @Test
    fun `nonce is read in big endian`() {
        val nonce = byteArrayOf(
            0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
        )
        // segment length of 1 followed by the same twelve nonce bytes
        val input = byteArrayOf(
            0x00, 0x01,
            0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
        )
        val inputStream = ByteArrayInputStream(input)

        assertArrayEquals(nonce, reader.readSegmentHeader(inputStream).nonce)
    }

    @Test
    fun `nonce is always read correctly`() {
        val nonce = ByteArray(IV_SIZE).apply { Random.nextBytes(this) }
        val input = ByteBuffer.allocate(SEGMENT_HEADER_SIZE)
            .putShort(1)
            .put(nonce)
            .array()
        val inputStream = ByteArrayInputStream(input)

        assertArrayEquals(nonce, reader.readSegmentHeader(inputStream).nonce)
    }

    /**
     * Builds a [ByteArray] from [Int] values, truncating each with [Int.toByte].
     *
     * This lets tests write literals such as `0xFF` that do not fit into a signed [Byte].
     * Calls whose arguments include a [Byte] (e.g. `VERSION`) do not match this overload
     * and resolve to the standard library `byteArrayOf` instead.
     */
    private fun byteArrayOf(vararg elements: Int): ByteArray =
        ByteArray(elements.size) { index -> elements[index].toByte() }

}

/**
 * Returns a random segment length between 1 and [Short.MAX_VALUE], both inclusive.
 *
 * The upper bound passed to [Random.nextInt] is exclusive, hence the `+ 1`;
 * without it the largest length could never be produced.
 */
internal fun getRandomValidSegmentLength(): Short =
    Random.nextInt(1, Short.MAX_VALUE.toInt() + 1).toShort()