diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt index 2ff6265b..7dfa0397 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt @@ -23,6 +23,7 @@ import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier +import com.squareup.kotlinpoet.MemberName import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.TypeName @@ -48,6 +49,7 @@ import io.grpc.kotlin.generator.protoc.methodName import io.grpc.kotlin.generator.protoc.of import io.grpc.kotlin.generator.protoc.serviceName import kotlinx.coroutines.flow.Flow +import kotlin.coroutines.CoroutineContext import io.grpc.Channel as GrpcChannel import io.grpc.Metadata as GrpcMetadata @@ -62,6 +64,10 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co private val STREAMING_PARAMETER_NAME = MemberSimpleName("requests") private val GRPC_CHANNEL_PARAMETER_NAME = MemberSimpleName("channel") private val CALL_OPTIONS_PARAMETER_NAME = MemberSimpleName("callOptions") + private val WITH_COROUTINE_CONTEXT_FUN_NAME = MemberName(ClientCalls::class.java.`package`.name, "withCoroutineContext") + private val COROUTINE_CONTEXT_VAL_NAME = MemberName(CoroutineContext::class.java.`package`.name, "coroutineContext") + private val FLOW_FUN_NAME = MemberName(Flow::class.java.`package`.name, "flow") + private val EMIT_ALL_FUN_NAME = MemberName(Flow::class.java.`package`.name, "emitAll") private val HEADERS_PARAMETER: ParameterSpec = ParameterSpec .builder("headers", GrpcMetadata::class) @@ -94,6 +100,9 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co } else { if (isServerStreaming) MethodType.SERVER_STREAMING else MethodType.UNARY } + + private val MethodDescriptor.isSuspendable: Boolean + get() = !isServerStreaming } override fun generate(service: ServiceDescriptor): Declarations = declarations { @@ -189,28 +198,39 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co ) } - val codeBlockMap = mapOf( - "helperMethod" to helperMethod, - "methodDescriptor" to method.descriptorCode, - "parameter" to parameter, - "headers" to HEADERS_PARAMETER - ) + val codeBlockMap = buildMap { + this["helperMethod"] = helperMethod + this["methodDescriptor"] = method.descriptorCode + this["parameter"] = parameter + this["headers"] = HEADERS_PARAMETER + this["withContext"] = WITH_COROUTINE_CONTEXT_FUN_NAME + this["coroutineContext"] = COROUTINE_CONTEXT_VAL_NAME + if (!method.isSuspendable) { + this["flow"] = FLOW_FUN_NAME + this["emitAll"] = EMIT_ALL_FUN_NAME + } + } - if (!method.isServerStreaming) { + if (method.isSuspendable) { funSpecBuilder.addModifiers(KModifier.SUSPEND) } - funSpecBuilder.addNamedCode( - """ - return %helperMethod:M( + val helperCall = """ + %helperMethod:M( channel, %methodDescriptor:L, %parameter:N, - callOptions, + callOptions.%withContext:M(%coroutineContext:M), %headers:N ) - """.trimIndent(), - codeBlockMap + """.trimIndent() + funSpecBuilder.addNamedCode( + if (method.isSuspendable) { + "return $helperCall" + } else { + "return \n%flow:M {\n⇥%emitAll:M(\n⇥$helperCall\n⇤)\n⇤}" + }, + codeBlockMap, ) return funSpecBuilder.build() } diff --git a/stub/src/main/java/io/grpc/kotlin/CallOptionsCoroutineContext.kt b/stub/src/main/java/io/grpc/kotlin/CallOptionsCoroutineContext.kt new file mode 100644 index 00000000..3b0b6cdb --- /dev/null +++ b/stub/src/main/java/io/grpc/kotlin/CallOptionsCoroutineContext.kt @@ -0,0 +1,25 @@ +package io.grpc.kotlin + +import io.grpc.CallOptions +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext + +private val COROUTINE_CONTEXT_OPTION: CallOptions.Key = + CallOptions.Key.createWithDefault("Coroutine context", EmptyCoroutineContext) + +/** + * Sets a coroutine context. + * + * @param context coroutine context to put into the call options + * @return [CallOptions] instance with coroutine context + */ +fun CallOptions.withCoroutineContext(context: CoroutineContext): CallOptions = + withOption(COROUTINE_CONTEXT_OPTION, context) + +/** + * Gets a coroutine context from the call options. + * + * Default: [EmptyCoroutineContext] + */ +val CallOptions.coroutineContext: CoroutineContext + get() = getOption(COROUTINE_CONTEXT_OPTION) diff --git a/stub/src/test/java/io/grpc/kotlin/ClientCallOptionsCoroutineContextPropagationTest.kt b/stub/src/test/java/io/grpc/kotlin/ClientCallOptionsCoroutineContextPropagationTest.kt new file mode 100644 index 00000000..98013b27 --- /dev/null +++ b/stub/src/test/java/io/grpc/kotlin/ClientCallOptionsCoroutineContextPropagationTest.kt @@ -0,0 +1,134 @@ +package io.grpc.kotlin + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth +import io.grpc.CallOptions +import io.grpc.Channel +import io.grpc.ClientCall +import io.grpc.ClientInterceptor +import io.grpc.ClientInterceptors +import io.grpc.MethodDescriptor +import io.grpc.examples.helloworld.GreeterGrpcKt +import io.grpc.examples.helloworld.HelloRequest +import io.grpc.examples.helloworld.MultiHelloRequest +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.withContext +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import java.util.UUID +import kotlin.coroutines.CoroutineContext + +@RunWith(JUnit4::class) +class ClientCallOptionsCoroutineContextPropagationTest : AbstractCallsTest() { + + @Test + fun `should capture coroutine context with unary call`() { + val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest) = helloReply("Hello, ${request.name}!") + } + val interceptor = CoroutineContextCapturingInterceptor() + val contextElement = DummyCoroutineContextElement() + val channel = ClientInterceptors.intercept(makeChannel(server), interceptor) + val stub = GreeterGrpcKt.GreeterCoroutineStub(channel) + + runBlocking { + withContext(contextElement) { + ProtoTruth.assertThat(stub.sayHello(helloRequest("Steven"))) + .isEqualTo(helloReply("Hello, Steven!")) + } + } + assertThat(interceptor.coroutineContext).isNotNull() + assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement) + } + + @Test + fun `should capture coroutine context with client streaming`() { + val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override suspend fun clientStreamSayHello(requests: Flow) = requests.map { request -> + helloReply("Hello, ${request.name}!") + }.first() + } + val interceptor = CoroutineContextCapturingInterceptor() + val contextElement = DummyCoroutineContextElement() + val channel = ClientInterceptors.intercept(makeChannel(server), interceptor) + val stub = GreeterGrpcKt.GreeterCoroutineStub(channel) + + runBlocking { + withContext(contextElement) { + ProtoTruth.assertThat(stub.clientStreamSayHello(flowOf(helloRequest("Steven")))) + .isEqualTo(helloReply("Hello, Steven!")) + } + } + assertThat(interceptor.coroutineContext).isNotNull() + assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement) + } + + @Test + fun `should capture coroutine context with server streaming`() { + val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override fun serverStreamSayHello(request: MultiHelloRequest) = flowOf( + helloReply("Hello, ${request.nameList.joinToString()}!") + ) + } + val interceptor = CoroutineContextCapturingInterceptor() + val contextElement = DummyCoroutineContextElement() + val channel = ClientInterceptors.intercept(makeChannel(server), interceptor) + val stub = GreeterGrpcKt.GreeterCoroutineStub(channel) + + runBlocking { + withContext(contextElement) { + ProtoTruth.assertThat(stub.serverStreamSayHello(multiHelloRequest("Steven", "Andrew")).first()) + .isEqualTo(helloReply("Hello, Steven, Andrew!")) + } + } + assertThat(interceptor.coroutineContext).isNotNull() + assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement) + } + + @Test + fun `should capture coroutine context with bidi streaming`() { + val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override fun bidiStreamSayHello(requests: Flow) = requests.map { request -> + helloReply("Hello, ${request.name}!") + } + } + val interceptor = CoroutineContextCapturingInterceptor() + val contextElement = DummyCoroutineContextElement() + val channel = ClientInterceptors.intercept(makeChannel(server), interceptor) + val stub = GreeterGrpcKt.GreeterCoroutineStub(channel) + + runBlocking { + withContext(contextElement) { + ProtoTruth.assertThat(stub.bidiStreamSayHello(flowOf(helloRequest("Steven"))).first()) + .isEqualTo(helloReply("Hello, Steven!")) + } + } + assertThat(interceptor.coroutineContext).isNotNull() + assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement) + } +} + +private data class DummyCoroutineContextElement(val value: UUID = UUID.randomUUID()) : CoroutineContext.Element { + override val key: CoroutineContext.Key<*> = Key + + companion object Key : CoroutineContext.Key +} + +private class CoroutineContextCapturingInterceptor : ClientInterceptor { + + var coroutineContext: CoroutineContext? = null + + override fun interceptCall( + method: MethodDescriptor, + callOptions: CallOptions, + next: Channel, + ): ClientCall { + coroutineContext = callOptions.coroutineContext + + return next.newCall(method, callOptions) + } +}