diff --git a/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java b/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java index e3347fab58..4e75e78023 100644 --- a/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java +++ b/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java @@ -15,6 +15,7 @@ package com.google.cloud.spanner.connection; import com.google.api.core.InternalApi; +import com.google.api.gax.grpc.GrpcInterceptorProvider; import com.google.auth.Credentials; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.connection.ConnectionOptions.Builder; @@ -33,4 +34,10 @@ public static Builder setCredentials(Builder connectionOptionsBuilder, Credentia public static Spanner getSpanner(Connection connection) { return ((ConnectionImpl) connection).getSpanner(); } + + public static Builder setInterceptorProvider( + Builder connectionOptionsBuilder, GrpcInterceptorProvider provider) { + connectionOptionsBuilder.setConfigurator(options -> options.setInterceptorProvider(provider)); + return connectionOptionsBuilder; + } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java index aab0934358..5fa4e6a4b0 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java @@ -65,6 +65,7 @@ import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.spanner.admin.database.v1.InstanceName; import com.google.spanner.v1.DatabaseName; import java.io.DataOutputStream; @@ -144,6 +145,8 @@ public class ConnectionHandler implements Runnable { */ private final LinkedList skippedAutoDetectParseMessages = new LinkedList<>(); + private Map> extraHeaders = ImmutableMap.of(); + private ExtendedQueryProtocolHandler extendedQueryProtocolHandler; private CopyStatement activeCopyStatement; @@ -788,6 +791,14 @@ public synchronized void setStatus(ConnectionStatus status) { this.status = status; } + boolean hasExtraHeaders() { + return !this.extraHeaders.isEmpty(); + } + + Map> getExtraHeaders() { + return this.extraHeaders; + } + public WellKnownClient getWellKnownClient() { return wellKnownClient; } @@ -795,6 +806,9 @@ public WellKnownClient getWellKnownClient() { public void setWellKnownClient(WellKnownClient wellKnownClient) { this.wellKnownClient = wellKnownClient; if (this.wellKnownClient != WellKnownClient.UNSPECIFIED) { + // Include the well-known client in a header that we send to Spanner. + this.extraHeaders = + ImmutableMap.of("pgadapter-well-known-client", ImmutableList.of(wellKnownClient.name())); logger.log( Level.INFO, Logging.format( diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java index 9dbf2f1c6c..cb1f92dd5a 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java @@ -16,9 +16,11 @@ import com.google.api.core.AbstractApiService; import com.google.api.core.InternalApi; +import com.google.api.gax.rpc.ApiCallContext; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; +import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.ThreadFactoryUtil; import com.google.cloud.spanner.connection.SpannerPool; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; @@ -28,6 +30,10 @@ import com.google.cloud.spanner.pgadapter.utils.Metrics; import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage; import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.Context; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Tracer; import java.io.Closeable; @@ -59,6 +65,11 @@ public class ProxyServer extends AbstractApiService { private static final Logger logger = Logger.getLogger(ProxyServer.class.getName()); + private static final CallOptions.Key, String>> + DYNAMIC_HEADERS_CALL_OPTION_KEY = + CallOptions.Key.createWithDefault( + "gax_dynamic_headers", Collections., String>emptyMap()); + private final OptionsMetadata options; private final OpenTelemetry openTelemetry; private final Metrics metrics; @@ -343,7 +354,24 @@ void createConnectionHandler(Socket socket) throws SocketException { socket.setTcpNoDelay(true); ConnectionHandler handler = new ConnectionHandler(this, socket); register(handler); - Thread thread = threadFactory.newThread(handler); + // Create a gRPC context for the connection, so we can assign custom headers specifically for + // that connection. This allows us to for example include the name of the client that is + // connected in the user-agent header. + Context context = + Context.current() + .withValue( + SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY, + new SpannerOptions.CallContextConfigurator() { + @Override + public ApiCallContext configure( + ApiCallContext context, ReqT request, MethodDescriptor method) { + if (handler.hasExtraHeaders()) { + return context.withExtraHeaders(handler.getExtraHeaders()); + } + return context; + } + }); + Thread thread = threadFactory.newThread(context.wrap(handler)); handler.setThread(thread); handler.start(); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java index 8cfcdd441e..6ff3966e07 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/AbstractMockServerTest.java @@ -75,7 +75,10 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.logging.Logger; @@ -1040,6 +1043,8 @@ protected static ResultSetMetadata createAllArrayTypesResultSetMetadata(String c protected static MockInstanceAdminImpl mockInstanceAdmin; protected static Server spannerServer; protected static ProxyServer pgServer; + protected static Set WELL_KNOWN_CLIENT_HEADERS = + Collections.synchronizedSet(new HashSet<>()); protected List getWireMessages() { return new ArrayList<>(pgServer.getDebugMessages()); @@ -1280,6 +1285,14 @@ public Listener interceptCall( assertTrue(userAgent.contains("pg-adapter")); assertTrue( userAgent.contains(ServiceOptions.getGoogApiClientLibName() + "/")); + + String pgAdapterWellKnownClient = + metadata.get( + Metadata.Key.of( + "pgadapter-well-known-client", Metadata.ASCII_STRING_MARSHALLER)); + if (pgAdapterWellKnownClient != null) { + WELL_KNOWN_CLIENT_HEADERS.add(pgAdapterWellKnownClient); + } } return Contexts.interceptCall( Context.current(), serverCall, metadata, serverCallHandler); @@ -1374,6 +1387,7 @@ public void clearRequests() { if (pgServer != null) { pgServer.clearDebugMessages(); } + WELL_KNOWN_CLIENT_HEADERS.clear(); } protected void addDdlResponseToSpannerAdmin() { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java index e02bbe9237..bcd0be45aa 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java @@ -48,6 +48,7 @@ import com.google.cloud.spanner.pgadapter.statements.PgCatalog.PgConstraint; import com.google.cloud.spanner.pgadapter.statements.PgCatalog.PgExtension; import com.google.cloud.spanner.pgadapter.statements.PgCatalog.PgIndex; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.PreparedType; import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage; import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage; @@ -420,6 +421,11 @@ public void testQuery() throws SQLException { assertTrue(request.getTransaction().hasSingleUse()); assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); } + + // Verify that a header was sent to Spanner to indicate which client was connected to PGAdapter. + assertTrue( + WELL_KNOWN_CLIENT_HEADERS.toString(), + WELL_KNOWN_CLIENT_HEADERS.contains(WellKnownClient.JDBC.name())); } @Test diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java index 7748b5f930..6c3bb1fcb5 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/golang/PgxMockServerTest.java @@ -32,6 +32,7 @@ import com.google.cloud.spanner.pgadapter.AbstractMockServerTest; import com.google.cloud.spanner.pgadapter.CopyInMockServerTest; import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata; +import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient; import com.google.common.collect.ImmutableList; import com.google.protobuf.AbstractMessage; import com.google.protobuf.ByteString; @@ -158,6 +159,11 @@ public void testSelect1() { ExecuteSqlRequest executeRequest = requests.get(1); assertEquals(sql, executeRequest.getSql()); assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode()); + + // Verify that a header was sent to Spanner to indicate which client was connected to PGAdapter. + assertTrue( + WELL_KNOWN_CLIENT_HEADERS.toString(), + WELL_KNOWN_CLIENT_HEADERS.contains(WellKnownClient.PGX.name())); } @Test diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java index af7fb0ef67..9e7f234fae 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/nodejs/NodePostgresMockServerTest.java @@ -90,6 +90,10 @@ public void testSelect1() throws Exception { ExecuteSqlRequest request = executeSqlRequests.get(0); assertTrue(request.getTransaction().hasSingleUse()); assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + + // Verify that no header was sent to Spanner to indicate which client was connected to + // PGAdapter, as this client is not auto-detected. + assertTrue(WELL_KNOWN_CLIENT_HEADERS.isEmpty()); } @Test diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg3/Psycopg3MockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg3/Psycopg3MockServerTest.java index 0158a11f3b..eff5a6609f 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg3/Psycopg3MockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/python/psycopg3/Psycopg3MockServerTest.java @@ -122,6 +122,10 @@ public void testSelect1() throws Exception { assertEquals(sql, request.getSql()); assertTrue(request.getTransaction().hasSingleUse()); assertTrue(request.getTransaction().getSingleUse().hasReadOnly()); + + // Verify that no header was sent to Spanner to indicate which client was connected to + // PGAdapter, as this client is not auto-detected. + assertTrue(WELL_KNOWN_CLIENT_HEADERS.isEmpty()); } @Test