Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: capture call site coroutine context into call options #592

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.squareup.kotlinpoet.AnnotationSpec
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.TypeName
Expand All @@ -48,6 +49,7 @@ import io.grpc.kotlin.generator.protoc.methodName
import io.grpc.kotlin.generator.protoc.of
import io.grpc.kotlin.generator.protoc.serviceName
import kotlinx.coroutines.flow.Flow
import kotlin.coroutines.CoroutineContext
import io.grpc.Channel as GrpcChannel
import io.grpc.Metadata as GrpcMetadata

Expand All @@ -62,6 +64,10 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
private val STREAMING_PARAMETER_NAME = MemberSimpleName("requests")
private val GRPC_CHANNEL_PARAMETER_NAME = MemberSimpleName("channel")
private val CALL_OPTIONS_PARAMETER_NAME = MemberSimpleName("callOptions")
private val WITH_COROUTINE_CONTEXT_FUN_NAME = MemberName(ClientCalls::class.java.`package`.name, "withCoroutineContext")
private val COROUTINE_CONTEXT_VAL_NAME = MemberName(CoroutineContext::class.java.`package`.name, "coroutineContext")
private val FLOW_FUN_NAME = MemberName(Flow::class.java.`package`.name, "flow")
private val EMIT_ALL_FUN_NAME = MemberName(Flow::class.java.`package`.name, "emitAll")

private val HEADERS_PARAMETER: ParameterSpec = ParameterSpec
.builder("headers", GrpcMetadata::class)
Expand Down Expand Up @@ -94,6 +100,9 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
} else {
if (isServerStreaming) MethodType.SERVER_STREAMING else MethodType.UNARY
}

private val MethodDescriptor.isSuspendable: Boolean
get() = !isServerStreaming
}

override fun generate(service: ServiceDescriptor): Declarations = declarations {
Expand Down Expand Up @@ -189,28 +198,39 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
)
}

val codeBlockMap = mapOf(
"helperMethod" to helperMethod,
"methodDescriptor" to method.descriptorCode,
"parameter" to parameter,
"headers" to HEADERS_PARAMETER
)
val codeBlockMap = buildMap {
this["helperMethod"] = helperMethod
this["methodDescriptor"] = method.descriptorCode
this["parameter"] = parameter
this["headers"] = HEADERS_PARAMETER
this["withContext"] = WITH_COROUTINE_CONTEXT_FUN_NAME
this["coroutineContext"] = COROUTINE_CONTEXT_VAL_NAME
if (!method.isSuspendable) {
this["flow"] = FLOW_FUN_NAME
this["emitAll"] = EMIT_ALL_FUN_NAME
}
}

if (!method.isServerStreaming) {
if (method.isSuspendable) {
funSpecBuilder.addModifiers(KModifier.SUSPEND)
}

funSpecBuilder.addNamedCode(
"""
return %helperMethod:M(
val helperCall = """
%helperMethod:M(
channel,
%methodDescriptor:L,
%parameter:N,
callOptions,
callOptions.%withContext:M(%coroutineContext:M),
%headers:N
)
""".trimIndent(),
codeBlockMap
""".trimIndent()
funSpecBuilder.addNamedCode(
if (method.isSuspendable) {
"return $helperCall"
} else {
"return \n%flow:M {\n⇥%emitAll:M(\n⇥$helperCall\n⇤)\n⇤}"
},
codeBlockMap,
)
return funSpecBuilder.build()
}
Expand Down
25 changes: 25 additions & 0 deletions stub/src/main/java/io/grpc/kotlin/CallOptionsCoroutineContext.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.grpc.kotlin

import io.grpc.CallOptions
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext

private val COROUTINE_CONTEXT_OPTION: CallOptions.Key<CoroutineContext> =
CallOptions.Key.createWithDefault("Coroutine context", EmptyCoroutineContext)

/**
* Sets a coroutine context.
*
* @param context coroutine context to put into the call options
* @return [CallOptions] instance with coroutine context
*/
fun CallOptions.withCoroutineContext(context: CoroutineContext): CallOptions =
withOption(COROUTINE_CONTEXT_OPTION, context)

/**
* Gets a coroutine context from the call options.
*
* Default: [EmptyCoroutineContext]
*/
val CallOptions.coroutineContext: CoroutineContext
get() = getOption(COROUTINE_CONTEXT_OPTION)
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package io.grpc.kotlin

import com.google.common.truth.Truth.assertThat
import com.google.common.truth.extensions.proto.ProtoTruth
import io.grpc.CallOptions
import io.grpc.Channel
import io.grpc.ClientCall
import io.grpc.ClientInterceptor
import io.grpc.ClientInterceptors
import io.grpc.MethodDescriptor
import io.grpc.examples.helloworld.GreeterGrpcKt
import io.grpc.examples.helloworld.HelloRequest
import io.grpc.examples.helloworld.MultiHelloRequest
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import java.util.UUID
import kotlin.coroutines.CoroutineContext

@RunWith(JUnit4::class)
class ClientCallOptionsCoroutineContextPropagationTest : AbstractCallsTest() {

@Test
fun `should capture coroutine context with unary call`() {
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
override suspend fun sayHello(request: HelloRequest) = helloReply("Hello, ${request.name}!")
}
val interceptor = CoroutineContextCapturingInterceptor()
val contextElement = DummyCoroutineContextElement()
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)

runBlocking {
withContext(contextElement) {
ProtoTruth.assertThat(stub.sayHello(helloRequest("Steven")))
.isEqualTo(helloReply("Hello, Steven!"))
}
}
assertThat(interceptor.coroutineContext).isNotNull()
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
}

@Test
fun `should capture coroutine context with client streaming`() {
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
override suspend fun clientStreamSayHello(requests: Flow<HelloRequest>) = requests.map { request ->
helloReply("Hello, ${request.name}!")
}.first()
}
val interceptor = CoroutineContextCapturingInterceptor()
val contextElement = DummyCoroutineContextElement()
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)

runBlocking {
withContext(contextElement) {
ProtoTruth.assertThat(stub.clientStreamSayHello(flowOf(helloRequest("Steven"))))
.isEqualTo(helloReply("Hello, Steven!"))
}
}
assertThat(interceptor.coroutineContext).isNotNull()
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
}

@Test
fun `should capture coroutine context with server streaming`() {
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
override fun serverStreamSayHello(request: MultiHelloRequest) = flowOf(
helloReply("Hello, ${request.nameList.joinToString()}!")
)
}
val interceptor = CoroutineContextCapturingInterceptor()
val contextElement = DummyCoroutineContextElement()
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)

runBlocking {
withContext(contextElement) {
ProtoTruth.assertThat(stub.serverStreamSayHello(multiHelloRequest("Steven", "Andrew")).first())
.isEqualTo(helloReply("Hello, Steven, Andrew!"))
}
}
assertThat(interceptor.coroutineContext).isNotNull()
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
}

@Test
fun `should capture coroutine context with bidi streaming`() {
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
override fun bidiStreamSayHello(requests: Flow<HelloRequest>) = requests.map { request ->
helloReply("Hello, ${request.name}!")
}
}
val interceptor = CoroutineContextCapturingInterceptor()
val contextElement = DummyCoroutineContextElement()
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)

runBlocking {
withContext(contextElement) {
ProtoTruth.assertThat(stub.bidiStreamSayHello(flowOf(helloRequest("Steven"))).first())
.isEqualTo(helloReply("Hello, Steven!"))
}
}
assertThat(interceptor.coroutineContext).isNotNull()
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
}
}

private data class DummyCoroutineContextElement(val value: UUID = UUID.randomUUID()) : CoroutineContext.Element {
override val key: CoroutineContext.Key<*> = Key

companion object Key : CoroutineContext.Key<DummyCoroutineContextElement>
}

private class CoroutineContextCapturingInterceptor : ClientInterceptor {

var coroutineContext: CoroutineContext? = null

override fun <ReqT : Any?, RespT : Any?> interceptCall(
method: MethodDescriptor<ReqT, RespT>,
callOptions: CallOptions,
next: Channel,
): ClientCall<ReqT, RespT> {
coroutineContext = callOptions.coroutineContext

return next.newCall(method, callOptions)
}
}