Skip to content

Commit

Permalink
Add tests for Readable and Writable byte channel
Browse files Browse the repository at this point in the history
  • Loading branch information
marcopremier committed Nov 4, 2024
1 parent 2a19e32 commit b92c8a9
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,53 @@ import kotlinx.coroutines.channels.ReceiveChannel
* coroutine channel.
*
* @property delegate The [ReceiveChannel] from which this [ReadableByteChannel] will read data.
* @return The number of bytes read into the buffer, or -1 if the channel is closed and no data
* remains.
* @constructor Creates a readable channel that reads each [ByteString] from the provided
* [ReceiveChannel] and writes it to the specified [ByteBuffer].
*/
class CoroutineReadableByteChannel(private val delegate: ReceiveChannel<ByteString>) :
ReadableByteChannel {

private var remainingBytes: ByteArray? = null
private var remainingOffset = 0

/**
* Reads bytes from the [ReceiveChannel] and transfers them into the provided buffer. If there are
* leftover bytes from a previous read that didn’t fit in the buffer, they are written first.
*
* When the [ReceiveChannel] has no data available:
* - Returns `0` if the channel is open but temporarily empty, allowing the caller to retry later.
* - Returns `-1` if the channel is closed and no more data will arrive.
*
* If only part of a [ByteString] fits into `destination`, the unread portion is saved for future
* reads, ensuring data is preserved between calls.
*
* @param destination The [ByteBuffer] where data will be written.
* @return The number of bytes written to `destination`, `0` if no data is available, or `-1` if
* the channel is closed and all data has been read.
*/
override fun read(destination: ByteBuffer): Int {
remainingBytes?.let {
val bytesToWrite = minOf(destination.remaining(), it.size - remainingOffset)
destination.put(it, remainingOffset, bytesToWrite)
remainingOffset += bytesToWrite

if (remainingOffset >= it.size) {
remainingBytes = null
remainingOffset = 0
}
return bytesToWrite
}

val result = delegate.tryReceive()
val byteString = result.getOrNull() ?: return -1 // -1 indicates the end of the stream
destination.put(byteString.toByteArray())
return byteString.size()
val byteString = result.getOrNull() ?: return if (result.isClosed) -1 else 0
val bytesToWrite = minOf(destination.remaining(), byteString.size())
byteString.substring(0, bytesToWrite).copyTo(destination)
if (bytesToWrite < byteString.size()) {
remainingBytes = byteString.toByteArray()
remainingOffset = bytesToWrite
}

return bytesToWrite
}

override fun isOpen(): Boolean = !delegate.isClosedForReceive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,34 @@ import kotlinx.coroutines.channels.SendChannel
* @property delegate The [SendChannel] to which this [WritableByteChannel] will send data.
* @constructor Creates a writable channel that writes each [ByteBuffer] as a [ByteString] to the
* provided [SendChannel].
* @throws ClosedChannelException if the channel is closed and cannot accept more data.
*/
class CoroutineWritableByteChannel(private val delegate: SendChannel<ByteString>) :
WritableByteChannel {

/**
* Writes the contents of the provided [ByteBuffer] to the [SendChannel] as a [ByteString].
*
* If the channel is not ready to accept data, the method returns `0` without consuming any bytes
* from the buffer, allowing the caller to retry. If the channel is closed, a
* [ClosedChannelException] is thrown.
*
* @param source The [ByteBuffer] containing the data to write.
* @return The number of bytes written, or `0` if the channel is temporarily unable to accept
* data.
* @throws ClosedChannelException if the channel is closed and cannot accept more data.
*/
override fun write(source: ByteBuffer): Int {
val originalPosition = source.position()
val bytesToWrite = source.remaining()
val byteString = ByteString.copyFrom(source)
if (delegate.trySend(byteString).isClosed) {
val result = delegate.trySend(byteString)
if (result.isClosed) {
throw ClosedChannelException()
} else if (result.isFailure) {
source.position(originalPosition)
return 0
}
return byteString.size()
return bytesToWrite
}

override fun isOpen(): Boolean = !delegate.isClosedForSend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ package org.wfanet.measurement.common.crypto.tink

import com.google.crypto.tink.StreamingAead
import com.google.protobuf.ByteString
import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.ReadableByteChannel
import java.util.logging.Logger
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineScope
Expand All @@ -31,6 +29,7 @@ import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.produceIn
import org.jetbrains.annotations.BlockingExecutor
import org.wfanet.measurement.common.BYTES_PER_MIB
import org.wfanet.measurement.common.CoroutineReadableByteChannel
import org.wfanet.measurement.common.CoroutineWritableByteChannel
import org.wfanet.measurement.common.asFlow
import org.wfanet.measurement.storage.StorageClient
Expand Down Expand Up @@ -118,43 +117,10 @@ class StreamingAeadStorageClient(
val scope = CoroutineScope(streamingAeadContext)
try {
val chunkChannel = blob.read().produceIn(scope)
val readableChannel = CoroutineReadableByteChannel(chunkChannel)

val plaintextChannel =
this@StreamingAeadStorageClient.streamingAead.newDecryptingChannel(
object : ReadableByteChannel {
private var currentChunk: ByteString? = null
private var bufferOffset = 0

override fun isOpen(): Boolean = true

override fun close() {}

override fun read(buffer: ByteBuffer): Int {
if (currentChunk == null || bufferOffset >= currentChunk!!.size()) {
val result = chunkChannel.tryReceive()
when {
result.isSuccess -> {
currentChunk = result.getOrNull()
if (currentChunk == null) return -1
bufferOffset = 0
}
chunkChannel.isClosedForReceive -> return -1
else -> return 0
}
}

val nextChunkBuffer = currentChunk!!.asReadOnlyByteBuffer()
nextChunkBuffer.position(bufferOffset)
val bytesToRead = minOf(buffer.remaining(), nextChunkBuffer.remaining())
nextChunkBuffer.limit(bufferOffset + bytesToRead)
buffer.put(nextChunkBuffer)
bufferOffset += bytesToRead
return bytesToRead
}
},
blobKey.encodeToByteArray(),
)

streamingAead.newDecryptingChannel(readableChannel, blobKey.encodeToByteArray())
emitAll(plaintextChannel.asFlow(BYTES_PER_MIB, streamingAeadContext))
} finally {
scope.cancel()
Expand Down
30 changes: 30 additions & 0 deletions src/test/kotlin/org/wfanet/measurement/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@ kt_jvm_test(
],
)

kt_jvm_test(
name = "CoroutineReadableByteChannelTest",
srcs = ["CoroutineReadableByteChannelTest.kt"],
test_class = "org.wfanet.measurement.common.CoroutineReadableByteChannelTest",
deps = [
"//imports/java/com/google/common/truth",
"//imports/java/com/google/protobuf",
"//imports/java/org/junit",
"//imports/kotlin/kotlin/test",
"//imports/kotlin/kotlinx/coroutines:core",
"//imports/kotlin/kotlinx/coroutines/test",
"//src/main/kotlin/org/wfanet/measurement/common",
],
)

kt_jvm_test(
name = "CoroutineWritableByteChannelTest",
srcs = ["CoroutineWritableByteChannelTest.kt"],
test_class = "org.wfanet.measurement.common.CoroutineWritableByteChannelTest",
deps = [
"//imports/java/com/google/common/truth",
"//imports/java/com/google/protobuf",
"//imports/java/org/junit",
"//imports/kotlin/kotlin/test",
"//imports/kotlin/kotlinx/coroutines:core",
"//imports/kotlin/kotlinx/coroutines/test",
"//src/main/kotlin/org/wfanet/measurement/common",
],
)

kt_jvm_test(
name = "ProtoUtilsTest",
srcs = ["ProtoUtilsTest.kt"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.common

import com.google.common.truth.Truth.assertThat
import com.google.protobuf.ByteString
import java.nio.ByteBuffer
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@RunWith(JUnit4::class)
class CoroutineReadableByteChannelTest {

private lateinit var channel: Channel<ByteString>
private lateinit var coroutineReadableByteChannel: CoroutineReadableByteChannel

@Before
fun setUp() {
channel = Channel()
coroutineReadableByteChannel = CoroutineReadableByteChannel(channel)
}

@After
fun tearDown() {
coroutineReadableByteChannel.close()
}

@Test
fun `read - reads data from channel`() = runBlocking {
val testData = ByteString.copyFromUtf8("hello")
val buffer = ByteBuffer.allocate(5)
launch { channel.send(testData) }
delay(100)
val bytesRead = coroutineReadableByteChannel.read(buffer)
assertThat(testData.size()).isEqualTo(bytesRead)
assertThat(testData.toByteArray().toList()).isEqualTo(buffer.array().toList())
}

@Test
fun `read - handles remaining bytes correctly`() = runBlocking {
val testData = ByteString.copyFromUtf8("hello world")
val buffer = ByteBuffer.allocate(5)
launch { channel.send(testData) }
delay(100)
val bytesRead = coroutineReadableByteChannel.read(buffer)
assertThat(5).isEqualTo(bytesRead)
buffer.flip()
assertThat("hello").isEqualTo(ByteString.copyFrom(buffer).toStringUtf8())
buffer.clear()
var remainingBytesRead = coroutineReadableByteChannel.read(buffer)
buffer.flip()
assertThat(" worl").isEqualTo(ByteString.copyFrom(buffer).toStringUtf8())
assertThat(5).isEqualTo(remainingBytesRead)
buffer.clear()
remainingBytesRead = coroutineReadableByteChannel.read(buffer)
buffer.flip()
assertThat("d").isEqualTo(ByteString.copyFrom(buffer).toStringUtf8())
assertThat(1).isEqualTo(remainingBytesRead)
}

@Test
fun `read - returns 0 when no data available`() = runBlocking {
val buffer = ByteBuffer.allocate(5)
val bytesRead = coroutineReadableByteChannel.read(buffer)
assertThat(0).isEqualTo(bytesRead)
}

@Test
fun `read - returns -1 when channel is closed and empty`() = runBlocking {
channel.close()
val buffer = ByteBuffer.allocate(5)
val bytesRead = coroutineReadableByteChannel.read(buffer)
assertThat(-1).isEqualTo(bytesRead)
}

@Test
fun `isOpen - returns false after close`() = runBlocking {
coroutineReadableByteChannel.close()
assertThat(coroutineReadableByteChannel.isOpen()).isFalse()
}

@Test
fun `isOpen - returns true when channel is open`() = runBlocking {
assertThat(coroutineReadableByteChannel.isOpen()).isTrue()
}
}
Loading

0 comments on commit b92c8a9

Please sign in to comment.