Skip to content

Commit

Permalink
Makes ResponseStream actually implement ReceiveChannel (#217)
Browse files Browse the repository at this point in the history
Also adds `copyFromChannel` helper functions.
  • Loading branch information
jhump authored Feb 6, 2024
1 parent 3f69420 commit c6e55d3
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 29 deletions.
9 changes: 9 additions & 0 deletions conformance/client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ plugins {
kotlin("jvm")
}

tasks {
compileKotlin {
kotlinOptions {
// Generated Kotlin code for protobuf uses OptIn annotation
freeCompilerArgs += "-opt-in=kotlin.RequiresOptIn"
}
}
}

dependencies {
implementation(project(":okhttp"))
implementation(libs.kotlin.coroutines.core)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class Client(
stream.close()
}
try {
val resp = stream.responses.messages.receive()
val resp = stream.responses.receive()
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.close()
Expand Down Expand Up @@ -363,7 +363,7 @@ class Client(
var connEx: ConnectException? = null
var trailers: Headers
try {
for (resp in stream.responses.messages) {
for (resp in stream.responses) {
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.close()
Expand Down Expand Up @@ -418,7 +418,7 @@ class Client(
if (cancel is Cancel.AfterNumResponses && cancel.num == 0) {
stream.close()
}
for (resp in stream.messages) {
for (resp in stream) {
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.Headers
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.coroutineScope

/**
Expand Down Expand Up @@ -60,12 +61,12 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface BidiStream<Req : MessageLite, Resp : MessageLite> : SuspendCloseable {
interface BidiStream<Req, Resp> : SuspendCloseable {
val requests: RequestStream<Req>
val responses: ResponseStream<Resp>

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): BidiStream<Req, Resp> {
fun <Req, Resp> new(underlying: BidirectionalStreamInterface<Req, Resp>): BidiStream<Req, Resp> {
val reqStream = RequestStream.new(underlying)
val respStream = ResponseStream.new(underlying)
return object : BidiStream<Req, Resp> {
Expand All @@ -83,3 +84,42 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(
}
}
}

/**
* Copies the contents of the given channel to the given
* stream's requests, closing the request stream at the
* end (via half-closing the stream).
*/
suspend fun <Req, Resp> copyFromChannel(
stream: BidiStreamClient.BidiStream<Req, Resp>,
requests: ReceiveChannel<Req>,
) {
copyFromChannel(stream, requests) { it }
}

/**
* Copies the contents of the given channel to the given
* stream's requests, transforming each element using the
* given lambda, closing the request stream at the end
* (via half-closing the stream).
*/
suspend fun <Req, Resp, T> copyFromChannel(
stream: BidiStreamClient.BidiStream<Req, Resp>,
requests: ReceiveChannel<T>,
toRequest: (T) -> Req,
) {
stream.requests.use {
try {
for (req in requests) {
it.send(toRequest(req))
}
} catch (ex: Throwable) {
try {
stream.close()
} catch (closeEx: Throwable) {
ex.addSuppressed(closeEx)
}
throw ex
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.connectrpc.Headers
import com.connectrpc.ResponseMessage
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.coroutineScope

/**
Expand Down Expand Up @@ -60,13 +61,29 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface ClientStream<Req : MessageLite, Resp : MessageLite> : SuspendCloseable {
suspend fun send(req: Req)
interface ClientStream<Req, Resp> :
RequestStream<Req>,
SuspendCloseable {

/**
* Closes the request stream **and** cancels the RPC. To close
* the request stream normally, without canceling the RPC,
* call closeAndReceive() instead.
*/
override suspend fun close()

/**
* Closes the request stream and then awaits and returns
* the RPC result.
*/
suspend fun closeAndReceive(): ResponseMessage<Resp>

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: ClientOnlyStreamInterface<Req, Resp>): ClientStream<Req, Resp> {
fun <Req, Resp> new(underlying: ClientOnlyStreamInterface<Req, Resp>): ClientStream<Req, Resp> {
return object : ClientStream<Req, Resp> {
override val isClosedForSend: Boolean
get() = underlying.isSendClosed()

override suspend fun send(req: Req) {
val result = underlying.send(req)
if (result.isFailure) {
Expand All @@ -84,11 +101,10 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
trailers = underlying.responseTrailers().await(),
)
} catch (e: Exception) {
val connectException: ConnectException
if (e is ConnectException) {
connectException = e
val connectException = if (e is ConnectException) {
e
} else {
connectException = ConnectException(code = Code.UNKNOWN, exception = e)
ConnectException(code = Code.UNKNOWN, exception = e)
}
return ResponseMessage.Failure(
cause = connectException,
Expand All @@ -107,3 +123,42 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
}
}
}

/**
* Copies the contents of the given channel to the given
* stream, closing the stream at the end and returning
* the RPC result.
*/
suspend fun <Req, Resp> copyFromChannel(
stream: ClientStreamClient.ClientStream<Req, Resp>,
requests: ReceiveChannel<Req>,
): ResponseMessage<Resp> {
return copyFromChannel(stream, requests) { it }
}

/**
* Copies the contents of the given channel to the given
* stream, transforming each element using the given lambda,
* closing the stream at the end and returning the RPC result.
*/
suspend fun <Req, Resp, T> copyFromChannel(
stream: ClientStreamClient.ClientStream<Req, Resp>,
requests: ReceiveChannel<T>,
toRequest: (T) -> Req,
): ResponseMessage<Resp> {
stream.use {
try {
for (req in requests) {
it.send(toRequest(req))
}
return it.closeAndReceive()
} catch (ex: Throwable) {
try {
stream.close()
} catch (closeEx: Throwable) {
ex.addSuppressed(closeEx)
}
throw ex
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package com.connectrpc.conformance.client.adapt

import com.connectrpc.BidirectionalStreamInterface
import com.google.protobuf.MessageLite

/**
* RequestStream is a stream that allows a client to upload
Expand All @@ -28,7 +27,9 @@ import com.google.protobuf.MessageLite
* requests "half-closes" the stream; closing the responses
* "fully closes" it.
*/
interface RequestStream<Req : MessageLite> : SuspendCloseable {
interface RequestStream<Req> : SuspendCloseable {
val isClosedForSend: Boolean

/**
* Sends a message on the stream.
* @throws Exception when the request cannot be sent
Expand All @@ -37,8 +38,11 @@ interface RequestStream<Req : MessageLite> : SuspendCloseable {
suspend fun send(req: Req)

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): RequestStream<Req> {
fun <Req, Resp> new(underlying: BidirectionalStreamInterface<Req, Resp>): RequestStream<Req> {
return object : RequestStream<Req> {
override val isClosedForSend: Boolean
get() = underlying.isSendClosed()

override suspend fun send(req: Req) {
val result = underlying.send(req)
if (result.isFailure) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ package com.connectrpc.conformance.client.adapt
import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.Headers
import com.connectrpc.ServerOnlyStreamInterface
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.ChannelIterator
import kotlinx.coroutines.channels.ChannelResult
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.selects.SelectClause1

/**
* ResponseStream is a stream that allows a client to download
Expand All @@ -29,17 +34,27 @@ import kotlinx.coroutines.channels.ReceiveChannel
*
* @param Resp The response message type
*/
interface ResponseStream<Resp : MessageLite> : SuspendCloseable {
val messages: ReceiveChannel<Resp>
interface ResponseStream<Resp> :
ReceiveChannel<Resp>,
SuspendCloseable {

suspend fun headers(): Headers
suspend fun trailers(): Headers

@Deprecated(
"""Prefer close() instead. Since response streams are not buffered,
there will not be undelivered items to discard. The close() method
will result in a ConnectException with a CANCELED code being thrown
by calls to receive, whereas this method results in the given
CancellationException being thrown.""",
ReplaceWith("stream.close()"),
)
override fun cancel(cause: CancellationException?)

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): ResponseStream<Resp> {
fun <Req, Resp> new(underlying: BidirectionalStreamInterface<Req, Resp>): ResponseStream<Resp> {
val channel = underlying.responseChannel()
return object : ResponseStream<Resp> {
override val messages: ReceiveChannel<Resp>
get() = underlying.responseChannel()

override suspend fun headers(): Headers {
return underlying.responseHeaders().await()
}
Expand All @@ -51,14 +66,53 @@ interface ResponseStream<Resp : MessageLite> : SuspendCloseable {
override suspend fun close() {
underlying.receiveClose()
}

@OptIn(DelicateCoroutinesApi::class)
override val isClosedForReceive: Boolean
get() = channel.isClosedForReceive

@ExperimentalCoroutinesApi
override val isEmpty: Boolean
get() = channel.isEmpty

override val onReceive: SelectClause1<Resp>
get() = channel.onReceive

override val onReceiveCatching: SelectClause1<ChannelResult<Resp>>
get() = channel.onReceiveCatching

@Deprecated("Since 1.2.0, binary compatibility with versions <= 1.1.x", level = DeprecationLevel.HIDDEN)
override fun cancel(cause: Throwable?): Boolean {
channel.cancel(CancellationException())
return false
}

@Deprecated("Prefer close() instead.", ReplaceWith("stream.close()"))
override fun cancel(cause: CancellationException?) {
channel.cancel(cause)
}

override fun iterator(): ChannelIterator<Resp> {
return channel.iterator()
}

override suspend fun receive(): Resp {
return channel.receive()
}

override suspend fun receiveCatching(): ChannelResult<Resp> {
return channel.receiveCatching()
}

override fun tryReceive(): ChannelResult<Resp> {
return channel.tryReceive()
}
}
}

fun <Req : MessageLite, Resp : MessageLite> new(underlying: ServerOnlyStreamInterface<Req, Resp>): ResponseStream<Resp> {
fun <Req, Resp> new(underlying: ServerOnlyStreamInterface<Req, Resp>): ResponseStream<Resp> {
val channel = underlying.responseChannel()
return object : ResponseStream<Resp> {
override val messages: ReceiveChannel<Resp>
get() = underlying.responseChannel()

override suspend fun headers(): Headers {
return underlying.responseHeaders().await()
}
Expand All @@ -70,6 +124,47 @@ interface ResponseStream<Resp : MessageLite> : SuspendCloseable {
override suspend fun close() {
underlying.receiveClose()
}

@OptIn(DelicateCoroutinesApi::class)
override val isClosedForReceive: Boolean
get() = channel.isClosedForReceive

@ExperimentalCoroutinesApi
override val isEmpty: Boolean
get() = channel.isEmpty

override val onReceive: SelectClause1<Resp>
get() = channel.onReceive

override val onReceiveCatching: SelectClause1<ChannelResult<Resp>>
get() = channel.onReceiveCatching

@Deprecated("Since 1.2.0, binary compatibility with versions <= 1.1.x", level = DeprecationLevel.HIDDEN)
override fun cancel(cause: Throwable?): Boolean {
channel.cancel(CancellationException())
return false
}

@Deprecated("Prefer close() instead.", ReplaceWith("stream.close()"))
override fun cancel(cause: CancellationException?) {
channel.cancel(cause)
}

override fun iterator(): ChannelIterator<Resp> {
return channel.iterator()
}

override suspend fun receive(): Resp {
return channel.receive()
}

override suspend fun receiveCatching(): ChannelResult<Resp> {
return channel.receiveCatching()
}

override fun tryReceive(): ChannelResult<Resp> {
return channel.tryReceive()
}
}
}
}
Expand Down
Loading

0 comments on commit c6e55d3

Please sign in to comment.