Split class into RecordIoStorageClient and StreamingAeadStorageClient
marcopremier committed Oct 25, 2024
1 parent b09152b commit 49e080f
Showing 5 changed files with 425 additions and 177 deletions.
@@ -0,0 +1,157 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.securecomputation.teesdk.cloudstorage.v1alpha

import com.google.protobuf.ByteString
import java.io.ByteArrayOutputStream
import java.util.logging.Logger
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import org.wfanet.measurement.storage.StorageClient

/**
* A wrapper class for the [StorageClient] interface that handles Apache Mesos RecordIO-formatted files
* for blob/object storage operations.
*
* This class supports row-based reading and writing, enabling the processing of individual rows
* at the client's pace. The implementation focuses on handling record-level operations for
* RecordIO files.
*
* @param storageClient underlying client for accessing blob/object storage
*/
class RecordIoStorageClient(private val storageClient: StorageClient) : StorageClient {

/**
* Writes RecordIO rows to storage using the RecordIO format.
*
* This function takes a flow of RecordIO rows (represented as a Flow<ByteString>), formats each row
* by prepending the record size and a newline character (`\n`), and writes the formatted content
* to storage.
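*
* For example, two 7-byte rows `{"a":1}` and `{"b":2}` are laid out back-to-back as the byte
* sequence `7\n{"a":1}7\n{"b":2}`; no separator follows a record's payload.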
*
* Rows may be emitted at any pace; the function processes them as they arrive, including rows
* emitted asynchronously with delays in between.
*
* @param blobKey The key (or name) of the blob where the content will be stored.
* @param content A Flow<ByteString> representing the source of RecordIO rows that will be stored.
*
* @return A Blob object representing the RecordIO data that was written to storage.
*/
override suspend fun writeBlob(blobKey: String, content: Flow<ByteString>): StorageClient.Blob {
val processedContent = flow {
val outputStream = ByteArrayOutputStream()

content.collect { byteString ->
val rawBytes = byteString.toByteArray()
// Write the "<record size>\n" header followed by the raw record bytes, avoiding a lossy
// bytes -> String -> bytes round trip for records that are not valid UTF-8 text.
outputStream.write("${rawBytes.size}\n".toByteArray(Charsets.UTF_8))
outputStream.write(rawBytes)
emit(ByteString.copyFrom(outputStream.toByteArray()))
outputStream.reset()
}

val remainingBytes = outputStream.toByteArray()
if (remainingBytes.isNotEmpty()) {
emit(ByteString.copyFrom(remainingBytes))
}
}

val wrappedBlob: StorageClient.Blob = storageClient.writeBlob(blobKey, processedContent)
logger.fine { "Wrote content to storage with blobKey: $blobKey" }
return RecordioBlob(wrappedBlob, blobKey)
}

/**
* Returns a [StorageClient.Blob] with specified blob key, or `null` if not found.
*
* Blob content is parsed as RecordIO records when [RecordioBlob.read] is called.
*/
override suspend fun getBlob(blobKey: String): StorageClient.Blob? {
val blob = storageClient.getBlob(blobKey)
return blob?.let { RecordioBlob(it, blobKey) }
}

/** A blob that will read the content in RecordIO format */
private inner class RecordioBlob(private val blob: StorageClient.Blob, private val blobKey: String) :
StorageClient.Blob {
override val storageClient = this@RecordIoStorageClient.storageClient

override val size: Long
get() = blob.size

/**
* Reads data from storage in RecordIO format, streaming chunks of data and processing them
* on-the-fly to extract individual records.
*
* The function reads each row from an Apache Mesos RecordIO file and emits it individually.
* Each record in the RecordIO format begins with its size followed by a newline character,
* then the actual record data.
*
* @return A Flow of ByteString, where each emission represents a complete record from the
* RecordIO formatted data.
*
* @throws java.io.IOException If there is an issue reading from the stream.
*/
override fun read(): Flow<ByteString> = flow {
val buffer = StringBuilder()
var currentRecordSize = -1
var recordBuffer = ByteArrayOutputStream()

blob.read().collect { chunk ->
// Operate on raw bytes so records containing multi-byte UTF-8 (or arbitrary binary data) are
// not corrupted when a chunk boundary falls inside a character.
val chunkBytes = chunk.toByteArray()
var position = 0

while (position < chunkBytes.size) {
if (currentRecordSize == -1) {
// The record size header is ASCII digits terminated by a newline.
while (position < chunkBytes.size) {
val b = chunkBytes[position++]
if (b == '\n'.code.toByte()) {
currentRecordSize = buffer.toString().toInt()
buffer.clear()
recordBuffer = ByteArrayOutputStream(currentRecordSize)
break
}
buffer.append(b.toInt().toChar())
}
}
if (currentRecordSize >= 0) {
val bytesToRead = minOf(chunkBytes.size - position, currentRecordSize - recordBuffer.size())
if (bytesToRead > 0) {
recordBuffer.write(chunkBytes, position, bytesToRead)
position += bytesToRead
}
// Emit once the declared number of bytes has been accumulated (possibly across chunks);
// zero-length records are emitted immediately instead of stalling the parser.
if (recordBuffer.size() == currentRecordSize) {
emit(ByteString.copyFrom(recordBuffer.toByteArray()))
currentRecordSize = -1
recordBuffer = ByteArrayOutputStream()
}
}
}
}
}

override suspend fun delete() = blob.delete()

}

companion object {
internal val logger = Logger.getLogger(this::class.java.name)
}
}
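A minimal usage sketch of the new client (not part of the committed files): it assumes the code runs in the same package as RecordIoStorageClient, that `underlyingClient` is any existing StorageClient implementation, and that the blob key and row contents are arbitrary examples.

import com.google.protobuf.ByteString
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import org.wfanet.measurement.storage.StorageClient

suspend fun recordIoRoundTrip(underlyingClient: StorageClient) {
  val recordIoClient = RecordIoStorageClient(underlyingClient)

  // Each emitted ByteString becomes one RecordIO record ("<size>\n<payload>").
  val rows = flowOf(
    ByteString.copyFromUtf8("""{"id":1}"""),
    ByteString.copyFromUtf8("""{"id":2}"""),
  )
  recordIoClient.writeBlob("example-records", rows)

  // Reading yields the original rows back, one ByteString per record.
  val readBack = recordIoClient.getBlob("example-records")?.read()?.toList().orEmpty()
  println(readBack.map { it.toStringUtf8() })
}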
@@ -23,58 +23,48 @@ import java.io.ByteArrayOutputStream
import java.nio.channels.ReadableByteChannel
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import org.jetbrains.annotations.BlockingExecutor

/**
* A wrapper class for the [StorageClient] interface that leverages Tink AEAD encryption/decryption
* for blob/object storage operations on files formatted using Apache Mesos RecordIO.
* for blob/object storage operations.
*
* This class supports row-based encryption and decryption, enabling the processing of individual
* rows at the client’s pace. Unlike [KmsStorageClient],
* which encrypts entire blobs, this class focuses on handling encryption and decryption at the
* record level inside RecordIO files.
* This class provides streaming encryption and decryption of data using StreamingAead,
* enabling secure storage of large files by processing them in chunks.
*
* @param storageClient underlying client for accessing blob/object storage
* @param dataKey a base64-encoded symmetric data key
* @param streamingAead the StreamingAead instance used for encryption/decryption
* @param streamingAeadContext coroutine context for encryption/decryption operations
*/
class RecordIoStorageClient(
class StreamingAeadStorageClient(
private val storageClient: StorageClient,
private val streamingAead: StreamingAead,
private val streamingAeadContext: @BlockingExecutor CoroutineContext = Dispatchers.IO,
) : StorageClient {

/**
* Encrypts and writes RecordIO rows to Google Cloud Storage using StreamingAead.
*
* This function takes a flow of RecordIO rows (represented as a Flow<ByteString>), formats each row
* by prepending the record size and a newline character (`\n`), and encrypts the entire formatted row
* using StreamingAead before writing the encrypted content to Google Cloud Storage.
*
* The function handles rows emitted at any pace, meaning it can process rows that are emitted asynchronously
* with delays in between.
*
* @param blobKey The key (or name) of the blob where the encrypted content will be stored.
* @param content A Flow<ByteString> representing the source of RecordIO rows that will be encrypted and stored.
*
* @return A Blob object representing the encrypted RecordIO data that was written to Google Cloud Storage.
*/
/**
* Encrypts and writes data to storage using StreamingAead.
*
* This function takes a flow of data chunks, encrypts them using StreamingAead,
* and writes the encrypted content to storage.
*
* @param blobKey The key (or name) of the blob where the encrypted content will be stored.
* @param content A Flow<ByteString> representing the source data that will be encrypted and stored.
*
* @return A Blob object representing the encrypted data that was written to storage.
*/
override suspend fun writeBlob(blobKey: String, content: Flow<ByteString>): StorageClient.Blob {
val encryptedContent = flow {
val outputStream = ByteArrayOutputStream()
val ciphertextChannel = this@RecordIoStorageClient.streamingAead.newEncryptingChannel(
val ciphertextChannel = this@StreamingAeadStorageClient.streamingAead.newEncryptingChannel(
Channels.newChannel(outputStream), blobKey.encodeToByteArray()
)

content.collect { byteString ->
val rawBytes = byteString.toByteArray()
val recordSize = rawBytes.size.toString()
val fullRecord = recordSize + "\n" + String(rawBytes, Charsets.UTF_8)
val fullRecordBytes = fullRecord.toByteArray(Charsets.UTF_8)
val buffer = ByteBuffer.wrap(fullRecordBytes)
val buffer = ByteBuffer.wrap(byteString.toByteArray())
while (buffer.hasRemaining()) {
ciphertextChannel.write(buffer)
}
@@ -92,36 +82,34 @@ class RecordIoStorageClient(
}
val wrappedBlob: StorageClient.Blob = storageClient.writeBlob(blobKey, encryptedContent)
logger.fine { "Wrote encrypted content to storage with blobKey: $blobKey" }
return RecordioBlob(wrappedBlob, blobKey)
return EncryptedBlob(wrappedBlob, blobKey)
}

/**
* Returns a [StorageClient.Blob] with specified blob key, or `null` if not found.
*
* Blob content is not decrypted until [RecordioBlob.read]
* Blob content is not decrypted until [EncryptedBlob.read]
*/
override suspend fun getBlob(blobKey: String): StorageClient.Blob? {
val blob = storageClient.getBlob(blobKey)
return blob?.let { RecordioBlob(it, blobKey) }
return blob?.let { EncryptedBlob(it, blobKey) }
}

/** A blob that will decrypt the content when read */
private inner class RecordioBlob(private val blob: StorageClient.Blob, private val blobKey: String) :
private inner class EncryptedBlob(private val blob: StorageClient.Blob, private val blobKey: String) :
StorageClient.Blob {
override val storageClient = this@RecordIoStorageClient.storageClient
override val storageClient = this@StreamingAeadStorageClient.storageClient

override val size: Long
get() = blob.size

/**
* This method handles the decryption of data, streaming chunks
* of encrypted data and decrypting them on-the-fly using a StreamingAead instance.
*
* The function then reads each row from an Apache Mesos RecordIO file and emits each one individually.
* Reads and decrypts the blob's content.
*
* @return The number of bytes read from the encrypted stream and written into the buffer.
* Returns -1 when the end of the stream is reached.
* This method handles the decryption of data by collecting all encrypted data first,
* then decrypting it as a single operation.
*
* @return A Flow of ByteString containing the decrypted data.
* @throws java.io.IOException If there is an issue reading from the stream or during decryption.
*/
override fun read(): Flow<ByteString> = flow {
@@ -134,7 +122,7 @@ class RecordIoStorageClient(
chunkChannel.close()
}

val plaintextChannel = this@RecordIoStorageClient.streamingAead.newDecryptingChannel(
val plaintextChannel = this@StreamingAeadStorageClient.streamingAead.newDecryptingChannel(
object : ReadableByteChannel {
private var currentChunk: ByteString? = null
private var bufferOffset = 0
@@ -164,42 +152,13 @@ blobKey.encodeToByteArray()
blobKey.encodeToByteArray()
)

val byteBuffer = ByteBuffer.allocate(4096)
val sizeBuffer = ByteArrayOutputStream()

val buffer = ByteBuffer.allocate(8192)
while (true) {
byteBuffer.clear()
if (plaintextChannel.read(byteBuffer) <= 0) break
byteBuffer.flip()

while (byteBuffer.hasRemaining()) {
val b = byteBuffer.get()

if (b.toInt().toChar() == '\n') {
val recordSize = sizeBuffer.toString(StandardCharsets.UTF_8).trim().toInt()
sizeBuffer.reset()
val recordData = ByteBuffer.allocate(recordSize)
var totalBytesRead = 0
if (byteBuffer.hasRemaining()) {
val bytesToRead = minOf(recordSize, byteBuffer.remaining())
val oldLimit = byteBuffer.limit()
byteBuffer.limit(byteBuffer.position() + bytesToRead)
recordData.put(byteBuffer)
byteBuffer.limit(oldLimit)
totalBytesRead += bytesToRead
}
while (recordData.hasRemaining()) {
val bytesRead = plaintextChannel.read(recordData)
if (bytesRead <= 0) break
totalBytesRead += bytesRead
}

recordData.flip()
emit(ByteString.copyFrom(recordData.array()))
} else {
sizeBuffer.write(b.toInt())
}
}
buffer.clear()
val bytesRead = plaintextChannel.read(buffer)
if (bytesRead <= 0) break
buffer.flip()
emit(ByteString.copyFrom(buffer.array(), 0, buffer.limit()))
}
}
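A minimal wiring sketch (not from the commit) showing how a Tink StreamingAead primitive could be constructed and handed to the new StreamingAeadStorageClient; the key template name, the `underlyingClient` parameter, and the function itself are illustrative assumptions, and the sketch assumes it lives in the same package as the new class.

import com.google.crypto.tink.KeyTemplates
import com.google.crypto.tink.KeysetHandle
import com.google.crypto.tink.StreamingAead
import com.google.crypto.tink.streamingaead.StreamingAeadConfig
import org.wfanet.measurement.storage.StorageClient

fun wrapWithStreamingAead(underlyingClient: StorageClient): StreamingAeadStorageClient {
  // Register Tink's streaming AEAD key managers before generating or loading keys.
  StreamingAeadConfig.register()
  val keysetHandle = KeysetHandle.generateNew(KeyTemplates.get("AES128_GCM_HKDF_4KB"))
  val streamingAead = keysetHandle.getPrimitive(StreamingAead::class.java)
  // The blob key is passed as associated data during encryption, so a blob must be read back
  // under the same key it was written with.
  return StreamingAeadStorageClient(underlyingClient, streamingAead)
}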

@@ -40,6 +40,19 @@ kt_jvm_test(
],
)

kt_jvm_test(
name = "StreamingAeadStorageClientTest",
srcs = ["StreamingAeadStorageClientTest.kt"],
deps = [
"//imports/java/com/google/common/truth",
"//imports/java/com/google/crypto/tink",
"//imports/java/org/junit",
"//src/main/kotlin/org/wfanet/measurement/common/crypto/tink",
"//src/main/kotlin/org/wfanet/measurement/common/crypto/tink/testing",
"//src/main/kotlin/org/wfanet/measurement/storage/testing",
],
)

kt_jvm_test(
name = "SelfIssuedIdTokensTest",
srcs = ["SelfIssuedIdTokensTest.kt"],