diff --git a/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClient.kt b/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClient.kt new file mode 100644 index 000000000..fc7e45831 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClient.kt @@ -0,0 +1,157 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.securecomputation.teesdk.cloudstorage.v1alpha + +import com.google.protobuf.ByteString +import java.util.logging.Logger +import org.wfanet.measurement.storage.StorageClient +import java.io.ByteArrayOutputStream +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +/** + * A wrapper class for the [StorageClient] interface that handles Apache Mesos RecordIO formatted files + * for blob/object storage operations. + * + * This class supports row-based reading and writing, enabling the processing of individual rows + * at the client's pace. The implementation focuses on handling record-level operations for + * RecordIO files. + * + * @param storageClient underlying client for accessing blob/object storage + */ +class RecordIoStorageClient(private val storageClient: StorageClient) : StorageClient { + + /** + * Writes RecordIO rows to storage using the RecordIO format. + * + * This function takes a flow of RecordIO rows (represented as a Flow), formats each row + * by prepending the record size and a newline character (`\n`), and writes the formatted content + * to storage. + * + * The function handles rows emitted at any pace, meaning it can process rows that are emitted asynchronously + * with delays in between. + * + * @param blobKey The key (or name) of the blob where the content will be stored. + * @param content A Flow representing the source of RecordIO rows that will be stored. + * + * @return A Blob object representing the RecordIO data that was written to storage. + */ + override suspend fun writeBlob(blobKey: String, content: Flow): StorageClient.Blob { + val processedContent = flow { + val outputStream = ByteArrayOutputStream() + + content.collect { byteString -> + val rawBytes = byteString.toByteArray() + val recordSize = rawBytes.size.toString() + val fullRecord = recordSize + "\n" + String(rawBytes, Charsets.UTF_8) + val fullRecordBytes = fullRecord.toByteArray(Charsets.UTF_8) + outputStream.write(fullRecordBytes) + emit(ByteString.copyFrom(outputStream.toByteArray())) + outputStream.reset() + } + + val remainingBytes = outputStream.toByteArray() + if (remainingBytes.isNotEmpty()) { + emit(ByteString.copyFrom(remainingBytes)) + } + } + + val wrappedBlob: StorageClient.Blob = storageClient.writeBlob(blobKey, processedContent) + logger.fine { "Wrote content to storage with blobKey: $blobKey" } + return RecordioBlob(wrappedBlob, blobKey) + } + + /** + * Returns a [StorageClient.Blob] with specified blob key, or `null` if not found. + * + * Blob content is read as RecordIO format when [RecordioBlob.read] is called + */ + override suspend fun getBlob(blobKey: String): StorageClient.Blob? { + val blob = storageClient.getBlob(blobKey) + return blob?.let { RecordioBlob(it, blobKey) } + } + + /** A blob that will read the content in RecordIO format */ + private inner class RecordioBlob(private val blob: StorageClient.Blob, private val blobKey: String) : + StorageClient.Blob { + override val storageClient = this@RecordIoStorageClient.storageClient + + override val size: Long + get() = blob.size + + /** + * Reads data from storage in RecordIO format, streaming chunks of data and processing them + * on-the-fly to extract individual records. + * + * The function reads each row from an Apache Mesos RecordIO file and emits them individually. + * Each record in the RecordIO format begins with its size followed by a newline character, + * then the actual record data. + * + * @return A Flow of ByteString, where each emission represents a complete record from the + * RecordIO formatted data. + * + * @throws java.io.IOException If there is an issue reading from the stream. + */ + override fun read(): Flow = flow { + val buffer = StringBuilder() + var currentRecordSize = -1 + var recordBuffer = ByteArrayOutputStream() + + blob.read().collect { chunk -> + var position = 0 + val chunkString = chunk.toByteArray().toString(Charsets.UTF_8) + + while (position < chunkString.length) { + if (currentRecordSize == -1) { + while (position < chunkString.length) { + val char = chunkString[position++] + if (char == '\n') { + currentRecordSize = buffer.toString().toInt() + buffer.clear() + recordBuffer = ByteArrayOutputStream(currentRecordSize) + break + } + buffer.append(char) + } + } + if (currentRecordSize > 0) { + val remainingBytes = chunkString.length - position + val bytesToRead = minOf(remainingBytes, currentRecordSize - recordBuffer.size()) + + if (bytesToRead > 0) { + recordBuffer.write( + chunkString.substring(position, position + bytesToRead) + .toByteArray(Charsets.UTF_8) + ) + position += bytesToRead + } + if (recordBuffer.size() == currentRecordSize) { + emit(ByteString.copyFrom(recordBuffer.toByteArray())) + currentRecordSize = -1 + recordBuffer = ByteArrayOutputStream() + } + } + } + } + } + + override suspend fun delete() = blob.delete() + + } + + companion object { + internal val logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageclient.kt b/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClient.kt similarity index 53% rename from src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageclient.kt rename to src/main/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClient.kt index 4f6cad9bf..531b605aa 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageclient.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClient.kt @@ -23,7 +23,6 @@ import java.io.ByteArrayOutputStream import java.nio.channels.ReadableByteChannel import java.nio.ByteBuffer import java.nio.channels.Channels -import java.nio.charset.StandardCharsets import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.* import kotlinx.coroutines.flow.* @@ -31,50 +30,41 @@ import org.jetbrains.annotations.BlockingExecutor /** * A wrapper class for the [StorageClient] interface that leverages Tink AEAD encryption/decryption - * for blob/object storage operations on files formatted using Apache Mesos RecordIO. + * for blob/object storage operations. * - * This class supports row-based encryption and decryption, enabling the processing of individual - * rows at the client’s pace. Unlike [KmsStorageClient], - * which encrypts entire blobs, this class focuses on handling encryption and decryption at the - * record level inside RecordIO files. + * This class provides streaming encryption and decryption of data using StreamingAead, + * enabling secure storage of large files by processing them in chunks. * * @param storageClient underlying client for accessing blob/object storage - * @param dataKey a base64-encoded symmetric data key + * @param streamingAead the StreamingAead instance used for encryption/decryption + * @param streamingAeadContext coroutine context for encryption/decryption operations */ -class RecordIoStorageClient( +class StreamingAeadStorageClient( private val storageClient: StorageClient, private val streamingAead: StreamingAead, private val streamingAeadContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, ) : StorageClient { -/** - * Encrypts and writes RecordIO rows to Google Cloud Storage using StreamingAead. - * - * This function takes a flow of RecordIO rows (represented as a Flow), formats each row - * by prepending the record size and a newline character (`\n`), and encrypts the entire formatted row - * using StreamingAead before writing the encrypted content to Google Cloud Storage. - * - * The function handles rows emitted at any pace, meaning it can process rows that are emitted asynchronously - * with delays in between. - * - * @param blobKey The key (or name) of the blob where the encrypted content will be stored. - * @param content A Flow representing the source of RecordIO rows that will be encrypted and stored. - * - * @return A Blob object representing the encrypted RecordIO data that was written to Google Cloud Storage. - */ + /** + * Encrypts and writes data to storage using StreamingAead. + * + * This function takes a flow of data chunks, encrypts them using StreamingAead, + * and writes the encrypted content to storage. + * + * @param blobKey The key (or name) of the blob where the encrypted content will be stored. + * @param content A Flow representing the source data that will be encrypted and stored. + * + * @return A Blob object representing the encrypted data that was written to storage. + */ override suspend fun writeBlob(blobKey: String, content: Flow): StorageClient.Blob { val encryptedContent = flow { val outputStream = ByteArrayOutputStream() - val ciphertextChannel = this@RecordIoStorageClient.streamingAead.newEncryptingChannel( + val ciphertextChannel = this@StreamingAeadStorageClient.streamingAead.newEncryptingChannel( Channels.newChannel(outputStream), blobKey.encodeToByteArray() ) content.collect { byteString -> - val rawBytes = byteString.toByteArray() - val recordSize = rawBytes.size.toString() - val fullRecord = recordSize + "\n" + String(rawBytes, Charsets.UTF_8) - val fullRecordBytes = fullRecord.toByteArray(Charsets.UTF_8) - val buffer = ByteBuffer.wrap(fullRecordBytes) + val buffer = ByteBuffer.wrap(byteString.toByteArray()) while (buffer.hasRemaining()) { ciphertextChannel.write(buffer) } @@ -92,36 +82,34 @@ class RecordIoStorageClient( } val wrappedBlob: StorageClient.Blob = storageClient.writeBlob(blobKey, encryptedContent) logger.fine { "Wrote encrypted content to storage with blobKey: $blobKey" } - return RecordioBlob(wrappedBlob, blobKey) + return EncryptedBlob(wrappedBlob, blobKey) } /** * Returns a [StorageClient.Blob] with specified blob key, or `null` if not found. * - * Blob content is not decrypted until [RecordioBlob.read] + * Blob content is not decrypted until [EncryptedBlob.read] */ override suspend fun getBlob(blobKey: String): StorageClient.Blob? { val blob = storageClient.getBlob(blobKey) - return blob?.let { RecordioBlob(it, blobKey) } + return blob?.let { EncryptedBlob(it, blobKey) } } /** A blob that will decrypt the content when read */ - private inner class RecordioBlob(private val blob: StorageClient.Blob, private val blobKey: String) : + private inner class EncryptedBlob(private val blob: StorageClient.Blob, private val blobKey: String) : StorageClient.Blob { - override val storageClient = this@RecordIoStorageClient.storageClient + override val storageClient = this@StreamingAeadStorageClient.storageClient override val size: Long get() = blob.size /** - * This method handles the decryption of data, streaming chunks - * of encrypted data and decrypting them on-the-fly using a StreamingAead instance. - * - * The function then reads each row from an Apache Mesos RecordIO file and emits each one individually. + * Reads and decrypts the blob's content. * - * @return The number of bytes read from the encrypted stream and written into the buffer. - * Returns -1 when the end of the stream is reached. + * This method handles the decryption of data by collecting all encrypted data first, + * then decrypting it as a single operation. * + * @return A Flow of ByteString containing the decrypted data. * @throws java.io.IOException If there is an issue reading from the stream or during decryption. */ override fun read(): Flow = flow { @@ -134,7 +122,7 @@ class RecordIoStorageClient( chunkChannel.close() } - val plaintextChannel = this@RecordIoStorageClient.streamingAead.newDecryptingChannel( + val plaintextChannel = this@StreamingAeadStorageClient.streamingAead.newDecryptingChannel( object : ReadableByteChannel { private var currentChunk: ByteString? = null private var bufferOffset = 0 @@ -164,42 +152,13 @@ class RecordIoStorageClient( blobKey.encodeToByteArray() ) - val byteBuffer = ByteBuffer.allocate(4096) - val sizeBuffer = ByteArrayOutputStream() - + val buffer = ByteBuffer.allocate(8192) while (true) { - byteBuffer.clear() - if (plaintextChannel.read(byteBuffer) <= 0) break - byteBuffer.flip() - - while (byteBuffer.hasRemaining()) { - val b = byteBuffer.get() - - if (b.toInt().toChar() == '\n') { - val recordSize = sizeBuffer.toString(StandardCharsets.UTF_8).trim().toInt() - sizeBuffer.reset() - val recordData = ByteBuffer.allocate(recordSize) - var totalBytesRead = 0 - if (byteBuffer.hasRemaining()) { - val bytesToRead = minOf(recordSize, byteBuffer.remaining()) - val oldLimit = byteBuffer.limit() - byteBuffer.limit(byteBuffer.position() + bytesToRead) - recordData.put(byteBuffer) - byteBuffer.limit(oldLimit) - totalBytesRead += bytesToRead - } - while (recordData.hasRemaining()) { - val bytesRead = plaintextChannel.read(recordData) - if (bytesRead <= 0) break - totalBytesRead += bytesRead - } - - recordData.flip() - emit(ByteString.copyFrom(recordData.array())) - } else { - sizeBuffer.write(b.toInt()) - } - } + buffer.clear() + val bytesRead = plaintextChannel.read(buffer) + if (bytesRead <= 0) break + buffer.flip() + emit(ByteString.copyFrom(buffer.array(), 0, buffer.limit())) } } diff --git a/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/BUILD.bazel index d6c660440..b2b7496fc 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/BUILD.bazel @@ -40,6 +40,19 @@ kt_jvm_test( ], ) +kt_jvm_test( + name = "StreamingAeadStorageClientTest", + srcs = ["StreamingAeadStorageClientTest.kt"], + deps = [ + "//imports/java/com/google/common/truth", + "//imports/java/com/google/crypto/tink", + "//imports/java/org/junit", + "//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", + "//src/main/kotlin/org/wfanet/measurement/common/crypto/tink/testing", + "//src/main/kotlin/org/wfanet/measurement/storage/testing", + ], +) + kt_jvm_test( name = "SelfIssuedIdTokensTest", srcs = ["SelfIssuedIdTokensTest.kt"], diff --git a/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClientTest.kt b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClientTest.kt index d1fc49ef0..d3ac20628 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClientTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/RecordIoStorageClientTest.kt @@ -17,138 +17,92 @@ package org.wfanet.measurement.common.crypto.tink import com.google.common.truth.Truth.assertThat -import com.google.crypto.tink.KeyTemplates -import com.google.crypto.tink.KeysetHandle -import com.google.crypto.tink.StreamingAead -import com.google.crypto.tink.streamingaead.StreamingAeadConfig import com.google.protobuf.ByteString -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer -import java.nio.channels.Channels -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.* import kotlinx.coroutines.runBlocking import org.junit.Before import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.measurement.securecomputation.teesdk.cloudstorage.v1alpha.RecordIoStorageClient -import org.wfanet.measurement.storage.testing.AbstractStreamingStorageClientTest -import org.wfanet.measurement.storage.testing.BlobSubject +import org.wfanet.measurement.storage.StorageClient import org.wfanet.measurement.storage.testing.InMemoryStorageClient @RunWith(JUnit4::class) -class RecordIoStorageClientTest : AbstractStreamingStorageClientTest() { - private val wrappedStorageClient = InMemoryStorageClient() +class RecordIoStorageClientTest { - @Test - fun `Blob size returns content size`() = runBlocking { - val blobKey = "blob-to-check-size" - - val blob = storageClient.writeBlob(blobKey, testBlobContent) - val wrappedBlob = wrappedStorageClient.getBlob(blobKey) - - BlobSubject.assertThat(blob).hasSize(wrappedBlob!!.size.toInt()) - } + private lateinit var wrappedStorageClient: StorageClient + private lateinit var recordIoStorageClient: RecordIoStorageClient @Before fun initStorageClient() { - storageClient = RecordIoStorageClient( - wrappedStorageClient, - streamingAead - ) + wrappedStorageClient = InMemoryStorageClient() + recordIoStorageClient = RecordIoStorageClient(wrappedStorageClient) } - @Test - fun `test write and read single record`() = runBlocking { - val blobKey = "test-key" - - val record = """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""" - val inputFlow = flow { emit(ByteString.copyFromUtf8(record)) } - val blob = storageClient.writeBlob(blobKey, inputFlow) - val readRecords = mutableListOf() - blob.read().collect { byteString -> - readRecords.add(byteString.toStringUtf8()) - } - assertThat(readRecords).hasSize(1) - assertThat(readRecords[0]).isEqualTo(record) + fun `test writing and reading single record`() = runBlocking { + val testData = "Hello World" + val blobKey = "test-single-record" + recordIoStorageClient.writeBlob( + blobKey, + flowOf(ByteString.copyFromUtf8(testData)) + ) + val blob = recordIoStorageClient.getBlob(blobKey) + requireNotNull(blob) { "Blob should exist" } + val records = blob.read().toList() + assertThat(1).isEqualTo(records.size) + assertThat(testData).isEqualTo(records[0].toStringUtf8()) } @Test - fun `test write and read multiple records`() = runBlocking { - val blobKey = "test-key" - val records = listOf( - """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""", - """{"type":"HEARTBEAT"}""", - """{"type":"HEARTBEAT_ACK"}""" + fun `test writing and reading large records`() = runBlocking { + val largeString = """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""".repeat(130000) // ~4MB + val testData = listOf(largeString) + val blobKey = "test-large-records" + recordIoStorageClient.writeBlob( + blobKey, + testData.map { ByteString.copyFromUtf8(it) }.asFlow() ) - val inputFlow = flow { - records.forEach { record -> - emit(ByteString.copyFromUtf8(record)) - } - } - val blob = storageClient.writeBlob(blobKey, inputFlow) - val readRecords = mutableListOf() - blob.read().collect { byteString -> - readRecords.add(byteString.toStringUtf8()) + val blob = recordIoStorageClient.getBlob(blobKey) + requireNotNull(blob) { "Blob should exist" } + val records = blob.read().toList() + assertThat(testData.size).isEqualTo(records.size) + records.forEachIndexed { index, record -> + assertThat(testData[index]).isEqualTo(record.toStringUtf8()) } - assertThat(readRecords).hasSize(records.size) - assertThat(readRecords).containsExactlyElementsIn(records).inOrder() } - @Test - fun `test write and read large records`() = runBlocking { - val blobKey = "test-key" - val largeRecord = buildString { - repeat(130000) {// ~ 4MB - append("""{"type": "LARGE_RECORD", "index": $it},""") - } - } - val inputFlow = flow { emit(ByteString.copyFromUtf8(largeRecord)) } - val blob = storageClient.writeBlob(blobKey, inputFlow) - val readRecords = mutableListOf() - blob.read().collect { byteString -> - readRecords.add(byteString.toStringUtf8()) - } - assertThat(readRecords).hasSize(1) - assertThat(readRecords[0]).isEqualTo(largeRecord) + @Test + fun `test writing empty flow`() = runBlocking { + val blobKey = "test-empty-flow" + recordIoStorageClient.writeBlob(blobKey, emptyFlow()) + val blob = recordIoStorageClient.getBlob(blobKey) + requireNotNull(blob) { "Blob should exist" } + val records = blob.read().toList() + assertThat(records).isEmpty() } @Test - fun `wrapped blob is encrypted`() = runBlocking { - val blobKey = "test-blob" - val testContent = """{"type": "TEST_RECORD", "data": "test content"}""" - val inputFlow = flow { emit(ByteString.copyFromUtf8(testContent)) } - storageClient.writeBlob(blobKey, inputFlow) - val encryptedBlob = wrappedStorageClient.getBlob(blobKey) - val decryptedContent = ByteArrayOutputStream() - val decryptingChannel = streamingAead.newDecryptingChannel( - Channels.newChannel(ByteArrayInputStream(encryptedBlob?.read()?.first()?.toByteArray())), - blobKey.encodeToByteArray() + fun `test deleting blob`() = runBlocking { + val blobKey = "test-delete" + val testData = "Test Data" + recordIoStorageClient.writeBlob( + blobKey, + flowOf(ByteString.copyFromUtf8(testData)) ) - val buffer = ByteBuffer.allocate(8192) - while (decryptingChannel.read(buffer) != -1) { - buffer.flip() - decryptedContent.write(buffer.array(), 0, buffer.limit()) - buffer.clear() - } - val recordContent = String(decryptedContent.toByteArray()).split('\n')[1] // Skip the size line - assertThat(recordContent).isEqualTo(testContent) + val blob = recordIoStorageClient.getBlob(blobKey) + requireNotNull(blob) { "Blob should exist" } + blob.delete() + val deletedBlob = recordIoStorageClient.getBlob(blobKey) + assertThat(deletedBlob).isNull() } - companion object { - - init { - StreamingAeadConfig.register() - } - - private val AEAD_KEY_TEMPLATE = KeyTemplates.get("AES128_GCM_HKDF_1MB") - private val KEY_ENCRYPTION_KEY = KeysetHandle.generateNew(AEAD_KEY_TEMPLATE) - private val streamingAead = KEY_ENCRYPTION_KEY.getPrimitive(StreamingAead::class.java) - + @Test + fun `test non-existent blob returns null`() = runBlocking { + val nonExistentBlob = recordIoStorageClient.getBlob("non-existent-key") + assertThat(nonExistentBlob).isNull() } } diff --git a/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClientTest.kt b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClientTest.kt new file mode 100644 index 000000000..763586852 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/common/crypto/tink/StreamingAeadStorageClientTest.kt @@ -0,0 +1,165 @@ +/* + * Copyright 2021 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.crypto.tink + +import com.google.common.truth.Truth.assertThat +import com.google.crypto.tink.KeyTemplates +import com.google.crypto.tink.KeysetHandle +import com.google.crypto.tink.StreamingAead +import com.google.crypto.tink.streamingaead.StreamingAeadConfig +import com.google.protobuf.ByteString +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels +import kotlin.test.assertNotNull +import kotlinx.coroutines.flow.emptyFlow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.securecomputation.teesdk.cloudstorage.v1alpha.StreamingAeadStorageClient +import org.wfanet.measurement.storage.StorageClient +import org.wfanet.measurement.storage.testing.BlobSubject +import org.wfanet.measurement.storage.testing.InMemoryStorageClient + +@RunWith(JUnit4::class) +class StreamingAeadStorageClientTest { + + private lateinit var wrappedStorageClient: StorageClient + private lateinit var streamingAeadStorageClient: StreamingAeadStorageClient + + @Before + fun initStorageClient() { + wrappedStorageClient = InMemoryStorageClient() + streamingAeadStorageClient = StreamingAeadStorageClient( + wrappedStorageClient, + streamingAead + ) + } + + @Test + fun `test write and read single record`() = runBlocking { + val blobKey = "test-key" + val record = """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""" + val inputFlow = flow { emit(ByteString.copyFromUtf8(record)) } + val blob = streamingAeadStorageClient.writeBlob(blobKey, inputFlow) + val readRecords = mutableListOf() + blob.read().collect { byteString -> + readRecords.add(byteString.toStringUtf8()) + } + assertThat(readRecords).hasSize(1) + assertThat(readRecords[0]).isEqualTo(record) + } + + @Test + fun `test write and read multiple records`() = runBlocking { + val blobKey = "test-key" + val records = listOf( + """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""", + """{"type":"HEARTBEAT"}""", + """{"type":"HEARTBEAT_ACK"}""" + ) + val combinedRecords = records.joinToString("") + val inputFlow = flow { + records.forEach { record -> + emit(ByteString.copyFromUtf8(record)) + } + } + val blob = streamingAeadStorageClient.writeBlob(blobKey, inputFlow) + val combinedContent = ByteArrayOutputStream() + blob.read().collect { byteString -> + combinedContent.write(byteString.toByteArray()) + } + val readContent = combinedContent.toString(Charsets.UTF_8) + assertThat(readContent).isEqualTo(combinedRecords) + } + + @Test + fun `test write and read large records`() = runBlocking { + val blobKey = "test-key" + + val largeRecord = buildString { + repeat(130000) {// ~ 4MB + append("""{"type": "LARGE_RECORD", "index": $it},""") + } + } + val inputFlow = flow { emit(ByteString.copyFromUtf8(largeRecord)) } + val blob = streamingAeadStorageClient.writeBlob(blobKey, inputFlow) + val combinedContent = ByteArrayOutputStream() + blob.read().collect { byteString -> + combinedContent.write(byteString.toByteArray()) + } + val readContent = combinedContent.toString(Charsets.UTF_8) + assertThat(readContent).isEqualTo(largeRecord) + } + + @Test + fun `wrapped blob is encrypted`() = runBlocking { + val blobKey = "test-blob" + val testContent = """{"type": "TEST_RECORD", "data": "test content"}""" + val inputFlow = flow { emit(ByteString.copyFromUtf8(testContent)) } + streamingAeadStorageClient.writeBlob(blobKey, inputFlow) + val encryptedBlob = wrappedStorageClient.getBlob(blobKey) + val decryptedContent = ByteArrayOutputStream() + val decryptingChannel = streamingAead.newDecryptingChannel( + Channels.newChannel(ByteArrayInputStream(encryptedBlob?.read()?.first()?.toByteArray())), + blobKey.encodeToByteArray() + ) + val buffer = ByteBuffer.allocate(8192) + while (decryptingChannel.read(buffer) != -1) { + buffer.flip() + decryptedContent.write(buffer.array(), 0, buffer.limit()) + buffer.clear() + } + assertThat(String(decryptedContent.toByteArray())).isEqualTo(testContent) + } + + @Test + fun `Blob delete deletes blob`() = runBlocking { + val blobKey = "blob-to-delete" + val record = """{"type": "SUBSCRIBED","subscribed": {"framework_id": {"value":"12220-3440-12532-2345"}}}""" + val inputFlow = flow { emit(ByteString.copyFromUtf8(record)) } + val blob = streamingAeadStorageClient.writeBlob(blobKey, inputFlow) + blob.delete() + assertThat(streamingAeadStorageClient.getBlob(blobKey)).isNull() + } + + @Test + fun `Write and read empty blob`() = runBlocking { + val blobKey = "empty-blob" + streamingAeadStorageClient.writeBlob(blobKey, emptyFlow()) + val blob = assertNotNull(streamingAeadStorageClient.getBlob(blobKey)) + BlobSubject.assertThat(blob).contentEqualTo(ByteString.EMPTY) + } + + companion object { + + init { + StreamingAeadConfig.register() + } + + private val AEAD_KEY_TEMPLATE = KeyTemplates.get("AES128_GCM_HKDF_1MB") + private val KEY_ENCRYPTION_KEY = KeysetHandle.generateNew(AEAD_KEY_TEMPLATE) + private val streamingAead = KEY_ENCRYPTION_KEY.getPrimitive(StreamingAead::class.java) + + } +} +