Ensure ChunkWriter uses current backend

When changing backends, the ChunkWriter could still use the old one causing data loss, because chunks assumed to exist on new backend, were written to old one.
This commit is contained in:
Torsten Grote 2024-08-28 15:46:16 -03:00
parent c83e8f392e
commit b59da2a805
No known key found for this signature in database
GPG key ID: 3E5F77D92CF891FF
6 changed files with 46 additions and 6 deletions

View file

@ -17,7 +17,7 @@ import java.io.IOException
import java.security.GeneralSecurityException import java.security.GeneralSecurityException
internal class SnapshotRetriever( internal class SnapshotRetriever(
private val storagePlugin: () -> Backend, private val backendGetter: () -> Backend,
private val streamCrypto: StreamCrypto = StreamCrypto, private val streamCrypto: StreamCrypto = StreamCrypto,
) { ) {
@ -27,7 +27,7 @@ internal class SnapshotRetriever(
InvalidProtocolBufferException::class, InvalidProtocolBufferException::class,
) )
suspend fun getSnapshot(streamKey: ByteArray, storedSnapshot: StoredSnapshot): BackupSnapshot { suspend fun getSnapshot(streamKey: ByteArray, storedSnapshot: StoredSnapshot): BackupSnapshot {
return storagePlugin().load(storedSnapshot.snapshotHandle).use { inputStream -> return backendGetter().load(storedSnapshot.snapshotHandle).use { inputStream ->
val version = inputStream.readVersion() val version = inputStream.readVersion()
val timestamp = storedSnapshot.timestamp val timestamp = storedSnapshot.timestamp
val ad = streamCrypto.getAssociatedDataForSnapshot(timestamp, version.toByte()) val ad = streamCrypto.getAssociatedDataForSnapshot(timestamp, version.toByte())

View file

@ -74,7 +74,8 @@ internal class Backup(
} catch (e: GeneralSecurityException) { } catch (e: GeneralSecurityException) {
throw AssertionError(e) throw AssertionError(e)
} }
private val chunkWriter = ChunkWriter(streamCrypto, streamKey, chunksCache, backend, androidId) private val chunkWriter =
ChunkWriter(streamCrypto, streamKey, chunksCache, backendGetter, androidId)
private val hasMediaAccessPerm = private val hasMediaAccessPerm =
context.checkSelfPermission(ACCESS_MEDIA_LOCATION) == PERMISSION_GRANTED context.checkSelfPermission(ACCESS_MEDIA_LOCATION) == PERMISSION_GRANTED
private val fileBackup = FileBackup( private val fileBackup = FileBackup(

View file

@ -31,11 +31,12 @@ internal class ChunkWriter(
private val streamCrypto: StreamCrypto, private val streamCrypto: StreamCrypto,
private val streamKey: ByteArray, private val streamKey: ByteArray,
private val chunksCache: ChunksCache, private val chunksCache: ChunksCache,
private val backend: Backend, private val backendGetter: () -> Backend,
private val androidId: String, private val androidId: String,
private val bufferSize: Int = DEFAULT_BUFFER_SIZE, private val bufferSize: Int = DEFAULT_BUFFER_SIZE,
) { ) {
private val backend get() = backendGetter()
private val buffer = ByteArray(bufferSize) private val buffer = ByteArray(bufferSize)
@Throws(IOException::class, GeneralSecurityException::class) @Throws(IOException::class, GeneralSecurityException::class)

View file

@ -14,6 +14,7 @@ import android.text.format.Formatter
import io.mockk.Runs import io.mockk.Runs
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.coVerifyOrder
import io.mockk.every import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
@ -475,6 +476,39 @@ internal class BackupRestoreTest {
} }
} }
@Test
fun testBackupUpdatesBackend(): Unit = runBlocking {
val backendGetterNew: () -> Backend = mockk()
val backend1: Backend = mockk()
val backend2: Backend = mockk()
val backup = Backup(
context = context,
db = db,
fileScanner = fileScanner,
backendGetter = backendGetterNew,
androidId = androidId,
keyManager = keyManager,
cacheRepopulater = cacheRepopulater,
)
every { backendGetterNew() } returnsMany listOf(backend1, backend2)
coEvery { backend1.list(any(), Blob::class, callback = any()) } just Runs
every { chunksCache.areAllAvailableChunksCached(db, emptySet()) } returns true
every { fileScanner.getFiles() } returns FileScannerResult(emptyList(), emptyList())
every { filesCache.getByUri(any()) } returns null // nothing is cached, all is new
backup.runBackup(null)
// second run uses new backend
coEvery { backend2.list(any(), Blob::class, callback = any()) } just Runs
backup.runBackup(null)
coVerifyOrder {
backend1.list(any(), Blob::class, callback = any())
backend2.list(any(), Blob::class, callback = any())
}
}
private fun getRandomMediaFile(size: Int) = MediaFile( private fun getRandomMediaFile(size: Int) = MediaFile(
uri = mockk(), uri = mockk(),
dir = getRandomString(), dir = getRandomString(),

View file

@ -32,6 +32,7 @@ internal class ChunkWriterTest {
private val streamCrypto: StreamCrypto = mockk() private val streamCrypto: StreamCrypto = mockk()
private val chunksCache: ChunksCache = mockk() private val chunksCache: ChunksCache = mockk()
private val backendGetter: () -> Backend = mockk()
private val backend: Backend = mockk() private val backend: Backend = mockk()
private val androidId: String = getRandomString() private val androidId: String = getRandomString()
private val streamKey: ByteArray = Random.nextBytes(KEY_SIZE_BYTES) private val streamKey: ByteArray = Random.nextBytes(KEY_SIZE_BYTES)
@ -42,7 +43,7 @@ internal class ChunkWriterTest {
streamCrypto = streamCrypto, streamCrypto = streamCrypto,
streamKey = streamKey, streamKey = streamKey,
chunksCache = chunksCache, chunksCache = chunksCache,
backend = backend, backendGetter = backendGetter,
androidId = androidId, androidId = androidId,
bufferSize = Random.nextInt(1, 42), bufferSize = Random.nextInt(1, 42),
) )
@ -53,6 +54,7 @@ internal class ChunkWriterTest {
init { init {
mockLog() mockLog()
every { backendGetter() } returns backend
} }
@Test @Test

View file

@ -39,6 +39,7 @@ internal class SmallFileBackupIntegrationTest {
private val filesCache: FilesCache = mockk() private val filesCache: FilesCache = mockk()
private val mac: Mac = mockk() private val mac: Mac = mockk()
private val chunksCache: ChunksCache = mockk() private val chunksCache: ChunksCache = mockk()
private val backendGetter: () -> Backend = mockk()
private val backend: Backend = mockk() private val backend: Backend = mockk()
private val androidId: String = getRandomString() private val androidId: String = getRandomString()
@ -46,7 +47,7 @@ internal class SmallFileBackupIntegrationTest {
streamCrypto = StreamCrypto, streamCrypto = StreamCrypto,
streamKey = Random.nextBytes(KEY_SIZE_BYTES), streamKey = Random.nextBytes(KEY_SIZE_BYTES),
chunksCache = chunksCache, chunksCache = chunksCache,
backend = backend, backendGetter = backendGetter,
androidId = androidId, androidId = androidId,
) )
private val zipChunker = ZipChunker( private val zipChunker = ZipChunker(
@ -58,6 +59,7 @@ internal class SmallFileBackupIntegrationTest {
init { init {
mockLog() mockLog()
every { backendGetter() } returns backend
} }
/** /**