Skip to content

Commit

Permalink
Auto-dispatch to I/O context (#218)
Browse files Browse the repository at this point in the history
This adds a new property to `ProtocolClientConfig` called
`ioCoroutineContext`. It defaults to null, which means there is no
change in behavior from the current defaults -- operations happen in the
calling coroutine context, so callers of RPC methods must handle the
dispatch.

But when it is configured as non-null (most likely using a value like
`Dispatchers.IO`), then I/O and blocking operations for an RPC are
automatically dispatched to the given coroutine context. This allows for
RPC client code to be used from any coroutine context without having to
worry about explicitly dispatching.
  • Loading branch information
jhump authored Feb 9, 2024
1 parent c6e55d3 commit d921d9b
Show file tree
Hide file tree
Showing 14 changed files with 235 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ internal class TracingHTTPClient(
return res
}

override fun sendClose() {
override suspend fun sendClose() {
printer.printlnWithStackTrace("Half-closing stream")
delegate.sendClose()
}

override fun receiveClose() {
override suspend fun receiveClose() {
printer.printlnWithStackTrace("Closing stream")
delegate.receiveClose()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class ElizaChatActivity : AppCompatActivity() {
host = host,
serializationStrategy = GoogleJavaLiteProtobufStrategy(),
networkProtocol = selectedNetworkProtocolOption,
// RPC operations that involve network I/O will
// use this coroutine context.
ioCoroutineContext = Dispatchers.IO,
),
)
// Create the Eliza service client.
Expand Down Expand Up @@ -113,7 +116,7 @@ class ElizaChatActivity : AppCompatActivity() {
adapter.add(MessageData(sentence, false))
editTextView.setText("")
// Ensure IO context for unary requests.
lifecycleScope.launch(Dispatchers.IO) {
lifecycleScope.launch {
// Make a unary request to Eliza.
val response = elizaServiceClient.say(SayRequest.newBuilder().setSentence(sentence).build())
response.success { success ->
Expand All @@ -133,7 +136,7 @@ class ElizaChatActivity : AppCompatActivity() {

private fun setupStreamingChat(elizaServiceClient: ElizaServiceClient) {
// On stream result, this callback can be called multiple times.
lifecycleScope.launch(Dispatchers.IO) {
lifecycleScope.launch {
// Initialize a bidi stream with Eliza.
val stream = elizaServiceClient.converse()
try {
Expand All @@ -156,15 +159,13 @@ class ElizaChatActivity : AppCompatActivity() {
} catch (e: ConnectException) {
adapter.add(MessageData("Session failed with code ${e.code}", true))
}
lifecycleScope.launch(Dispatchers.Main) {
buttonView.setOnClickListener {
val sentence = editTextView.text.toString()
adapter.add(MessageData(sentence, false))
editTextView.setText("")
// Send will be streaming a message to Eliza.
lifecycleScope.launch(Dispatchers.IO) {
stream.send(ConverseRequest.newBuilder().setSentence(sentence).build())
}
buttonView.setOnClickListener {
val sentence = editTextView.text.toString()
adapter.add(MessageData(sentence, false))
editTextView.setText("")
// Send will be streaming a message to Eliza.
lifecycleScope.launch {
stream.send(ConverseRequest.newBuilder().setSentence(sentence).build())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
import java.time.Duration

Expand All @@ -44,6 +43,9 @@ class Main {
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
// RPC operations that involve network I/O will
// use this coroutine context.
ioCoroutineContext = Dispatchers.IO,
),
)
val elizaServiceClient = ElizaServiceClient(client)
Expand All @@ -57,13 +59,11 @@ class Main {

private suspend fun connectStreaming(elizaServiceClient: ElizaServiceClient) {
val stream = elizaServiceClient.converse()
withContext(Dispatchers.IO) {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (response in stream.responseChannel()) {
println(response.sentence)
}
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (response in stream.responseChannel()) {
println(response.sentence)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
import java.time.Duration

Expand All @@ -44,6 +43,9 @@ class Main {
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaLiteProtobufStrategy(),
// RPC operations that involve network I/O will
// use this coroutine context.
ioCoroutineContext = Dispatchers.IO,
),
)
val elizaServiceClient = ElizaServiceClient(client)
Expand All @@ -57,13 +59,11 @@ class Main {

private suspend fun connectStreaming(elizaServiceClient: ElizaServiceClient) {
val stream = elizaServiceClient.converse()
withContext(Dispatchers.IO) {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (response in stream.responseChannel()) {
println(response.sentence)
}
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (response in stream.responseChannel()) {
println(response.sentence)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ interface BidirectionalStreamInterface<Input, Output> {
/**
* Close the send stream. No calls to [send] are valid after calling [sendClose].
*/
fun sendClose()
suspend fun sendClose()

/**
* Close the receive stream.
*/
fun receiveClose()
suspend fun receiveClose()

/**
* Determine if the underlying client send stream is closed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ interface ClientOnlyStreamInterface<Input, Output> {
/**
* Close the stream. No calls to [send] are valid after calling [sendClose].
*/
fun sendClose()
suspend fun sendClose()

/**
* Cancels the stream. This closes both send and receive sides of the stream
* without awaiting any server reply.
*/
fun cancel()
suspend fun cancel()

/**
* Determine if the underlying client send stream is closed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.connectrpc.protocols.GRPCInterceptor
import com.connectrpc.protocols.GRPCWebInterceptor
import com.connectrpc.protocols.NetworkProtocol
import java.net.URI
import kotlin.coroutines.CoroutineContext

/**
* Set of configuration used to set up clients.
Expand All @@ -45,6 +46,14 @@ class ProtocolClientConfig @JvmOverloads constructor(
// Compression pools that provide support for the provided `compressionName`, as well as any
// other compression methods that need to be supported for inbound responses.
compressionPools: List<CompressionPool> = listOf(GzipCompressionPool),
// The coroutine context to use for I/O, such as sending RPC messages.
// If null, the current/calling coroutine context is used. So the caller
// may need to explicitly dispatch send calls using contexts where I/O
// is appropriate (using the withContext extension function). If non-null
// (such as Dispatchers.IO), operations that involve I/O or other
// blocking will automatically be dispatched using the given context,
// so the caller does not need to worry about it.
val ioCoroutineContext: CoroutineContext? = null,
) {
private val internalInterceptorFactoryList = mutableListOf<(ProtocolClientConfig) -> Interceptor>()
private val compressionPools = mutableMapOf<String, CompressionPool>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ interface ServerOnlyStreamInterface<Input, Output> {
/**
* Close the receive stream.
*/
fun receiveClose()
suspend fun receiveClose()

/**
* Determine if the underlying client receive stream is closed.
Expand Down
90 changes: 0 additions & 90 deletions library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package com.connectrpc.http

import com.connectrpc.StreamResult
import okio.Buffer
import java.util.concurrent.atomic.AtomicBoolean

typealias Cancelable = () -> Unit

Expand Down Expand Up @@ -46,92 +45,3 @@ interface HTTPClientInterface {
*/
fun stream(request: HTTPRequest, duplex: Boolean, onResult: suspend (StreamResult<Buffer>) -> Unit): Stream
}

interface Stream {
suspend fun send(buffer: Buffer): Result<Unit>

fun sendClose()

fun receiveClose()

fun isSendClosed(): Boolean

fun isReceiveClosed(): Boolean
}

fun Stream(
onSend: suspend (Buffer) -> Result<Unit>,
onSendClose: () -> Unit = {},
onReceiveClose: () -> Unit = {},
): Stream {
val isSendClosed = AtomicBoolean()
val isReceiveClosed = AtomicBoolean()
return object : Stream {
override suspend fun send(buffer: Buffer): Result<Unit> {
if (isSendClosed()) {
return Result.failure(IllegalStateException("cannot send. underlying stream is closed"))
}
return try {
onSend(buffer)
} catch (e: Throwable) {
Result.failure(e)
}
}

override fun sendClose() {
if (isSendClosed.compareAndSet(false, true)) {
onSendClose()
}
}

override fun receiveClose() {
if (isReceiveClosed.compareAndSet(false, true)) {
try {
onReceiveClose()
} finally {
// When receive side is closed, the send side is
// implicitly closed as well.
// We don't use sendClose() because we don't want to
// invoke onSendClose() since that will try to actually
// half-close the HTTP stream, which will fail since
// closing the receive side cancels the entire thing.
isSendClosed.set(true)
}
}
}

override fun isSendClosed(): Boolean {
return isSendClosed.get()
}

override fun isReceiveClosed(): Boolean {
return isReceiveClosed.get()
}
}
}

/**
* Returns a new stream that applies the given function to each
* buffer when send is called. The result of that function is
* what is passed along to the original stream.
*/
fun Stream.transform(apply: (Buffer) -> Buffer): Stream {
val delegate = this
return object : Stream {
override suspend fun send(buffer: Buffer): Result<Unit> {
return delegate.send(apply(buffer))
}
override fun sendClose() {
delegate.sendClose()
}
override fun receiveClose() {
delegate.receiveClose()
}
override fun isSendClosed(): Boolean {
return delegate.isSendClosed()
}
override fun isReceiveClosed(): Boolean {
return delegate.isReceiveClosed()
}
}
}
Loading

0 comments on commit d921d9b

Please sign in to comment.