From 097bc93c39d10d8ff08eb8fb3a378bbe6d96c1ac Mon Sep 17 00:00:00 2001 From: Lawrence Qiu Date: Fri, 12 Jan 2024 20:13:26 +0000 Subject: [PATCH] feat: Validate the Universe Domain (#2330) * feat: Validate the universe domain * chore: Merge in from origin/main * chore: Add comments for ApiCallContext * chore: Add comments * chore: Address PR comments * chore: Merge endpoint context in both transports * chore: Use @throws for the exceptions * chore: Provide a default EndpointContext * chore: Address PR comments * chore: Update error message * chore: Address PR comments * chore: Address PR comments * chore: Address PR comments --- .../google/api/gax/grpc/GrpcCallContext.java | 152 ++++++++++++++++-- .../google/api/gax/grpc/GrpcClientCalls.java | 4 + .../google/api/gax/grpc/ChannelPoolTest.java | 21 ++- .../api/gax/grpc/GrpcCallContextTest.java | 3 +- .../api/gax/grpc/GrpcCallableFactoryTest.java | 20 ++- .../api/gax/grpc/GrpcClientCallsTest.java | 127 +++++++++++++-- ...GrpcDirectServerStreamingCallableTest.java | 12 +- .../grpc/GrpcDirectStreamingCallableTest.java | 11 +- .../api/gax/grpc/GrpcLongRunningTest.java | 22 ++- .../gax/grpc/GrpcResponseMetadataTest.java | 12 +- .../com/google/api/gax/grpc/TimeoutTest.java | 15 +- .../api/gax/httpjson/HttpJsonCallContext.java | 97 +++++++++-- .../api/gax/httpjson/HttpJsonClientCalls.java | 4 + .../gax/httpjson/HttpJsonClientCallsTest.java | 145 +++++++++++++++++ .../HttpJsonClientInterceptorTest.java | 14 +- .../httpjson/HttpJsonDirectCallableTest.java | 73 ++++----- ...JsonDirectServerStreamingCallableTest.java | 16 +- gax-java/gax/clirr-ignored-differences.xml | 11 ++ .../google/api/gax/rpc/ApiCallContext.java | 3 + .../com/google/api/gax/rpc/ClientContext.java | 1 + .../google/api/gax/rpc/EndpointContext.java | 47 +++++- .../api/gax/rpc/EndpointContextTest.java | 75 +++++++++ .../api/gax/rpc/testing/FakeCallContext.java | 57 +++++-- 23 files changed, 822 insertions(+), 120 deletions(-) create mode 100644 gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientCallsTest.java diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 5f33c36448..823d8c56b7 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -30,15 +30,21 @@ package com.google.api.gax.grpc; import com.google.api.core.BetaApi; +import com.google.api.core.InternalApi; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ApiExceptionFactory; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; +import com.google.api.gax.rpc.UnauthenticatedException; +import com.google.api.gax.rpc.UnavailableException; import com.google.api.gax.rpc.internal.ApiCallContextOptions; import com.google.api.gax.rpc.internal.Headers; import com.google.api.gax.tracing.ApiTracer; import com.google.api.gax.tracing.BaseApiTracer; import com.google.auth.Credentials; +import com.google.auth.Retryable; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -47,7 +53,9 @@ import io.grpc.Channel; import io.grpc.Deadline; import io.grpc.Metadata; +import io.grpc.Status; import io.grpc.auth.MoreCallCredentials; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Objects; @@ -66,9 +74,13 @@ */ @BetaApi("Reference ApiCallContext instead - this class is likely to experience breaking changes") public final class GrpcCallContext implements ApiCallContext { + private static final GrpcStatusCode UNAUTHENTICATED_STATUS_CODE = + GrpcStatusCode.of(Status.Code.UNAUTHENTICATED); + static final CallOptions.Key TRACER_KEY = CallOptions.Key.create("gax.tracer"); private final Channel channel; + @Nullable private final Credentials credentials; private final CallOptions callOptions; @Nullable private final Duration timeout; @Nullable private final Duration streamWaitTimeout; @@ -78,10 +90,12 @@ public final class GrpcCallContext implements ApiCallContext { @Nullable private final ImmutableSet retryableCodes; private final ImmutableMap> extraHeaders; private final ApiCallContextOptions options; + private final EndpointContext endpointContext; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { return new GrpcCallContext( + null, null, CallOptions.DEFAULT, null, @@ -91,6 +105,7 @@ public static GrpcCallContext createDefault() { ImmutableMap.>of(), ApiCallContextOptions.getDefaultOptions(), null, + null, null); } @@ -98,6 +113,7 @@ public static GrpcCallContext createDefault() { public static GrpcCallContext of(Channel channel, CallOptions callOptions) { return new GrpcCallContext( channel, + null, callOptions, null, null, @@ -106,11 +122,13 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { ImmutableMap.>of(), ApiCallContextOptions.getDefaultOptions(), null, + null, null); } private GrpcCallContext( Channel channel, + @Nullable Credentials credentials, CallOptions callOptions, @Nullable Duration timeout, @Nullable Duration streamWaitTimeout, @@ -119,8 +137,10 @@ private GrpcCallContext( ImmutableMap> extraHeaders, ApiCallContextOptions options, @Nullable RetrySettings retrySettings, - @Nullable Set retryableCodes) { + @Nullable Set retryableCodes, + EndpointContext endpointContext) { this.channel = channel; + this.credentials = credentials; this.callOptions = Preconditions.checkNotNull(callOptions); this.timeout = timeout; this.streamWaitTimeout = streamWaitTimeout; @@ -130,6 +150,7 @@ private GrpcCallContext( this.options = Preconditions.checkNotNull(options); this.retrySettings = retrySettings; this.retryableCodes = retryableCodes == null ? null : ImmutableSet.copyOf(retryableCodes); + this.endpointContext = endpointContext; } /** @@ -158,7 +179,19 @@ public GrpcCallContext nullToSelf(ApiCallContext inputContext) { public GrpcCallContext withCredentials(Credentials newCredentials) { Preconditions.checkNotNull(newCredentials); CallCredentials callCredentials = MoreCallCredentials.from(newCredentials); - return withCallOptions(callOptions.withCallCredentials(callCredentials)); + return new GrpcCallContext( + channel, + newCredentials, + callOptions.withCallCredentials(callCredentials), + timeout, + streamWaitTimeout, + streamIdleTimeout, + channelAffinity, + extraHeaders, + options, + retrySettings, + retryableCodes, + endpointContext); } @Override @@ -172,6 +205,24 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { return withChannel(transportChannel.getChannel()); } + @Override + public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { + Preconditions.checkNotNull(endpointContext); + return new GrpcCallContext( + channel, + credentials, + callOptions, + timeout, + streamWaitTimeout, + streamIdleTimeout, + channelAffinity, + extraHeaders, + options, + retrySettings, + retryableCodes, + endpointContext); + } + @Override public GrpcCallContext withTimeout(@Nullable Duration timeout) { // Default RetrySettings use 0 for RPC timeout. Treat that as disabled timeouts. @@ -186,6 +237,7 @@ public GrpcCallContext withTimeout(@Nullable Duration timeout) { return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -194,7 +246,8 @@ public GrpcCallContext withTimeout(@Nullable Duration timeout) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Nullable @@ -212,6 +265,7 @@ public GrpcCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTimeou return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -220,7 +274,8 @@ public GrpcCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTimeou extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Override @@ -232,6 +287,7 @@ public GrpcCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTimeou return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -240,13 +296,15 @@ public GrpcCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTimeou extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -255,7 +313,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -266,6 +325,7 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) Headers.mergeHeaders(this.extraHeaders, extraHeaders); return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -274,7 +334,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) newExtraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Override @@ -286,6 +347,7 @@ public RetrySettings getRetrySettings() { public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -294,7 +356,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Override @@ -306,6 +369,7 @@ public Set getRetryableCodes() { public GrpcCallContext withRetryableCodes(Set retryableCodes) { return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -314,7 +378,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Override @@ -329,6 +394,11 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { } GrpcCallContext grpcCallContext = (GrpcCallContext) inputCallContext; + Credentials newCredentials = grpcCallContext.credentials; + if (newCredentials == null) { + newCredentials = credentials; + } + Channel newChannel = grpcCallContext.channel; if (newChannel == null) { newChannel = channel; @@ -394,8 +464,11 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newCallOptions = newCallOptions.withOption(TRACER_KEY, newTracer); } + // The EndpointContext is not updated as there should be no reason for a user + // to update this. return new GrpcCallContext( newChannel, + newCredentials, newCallOptions, newTimeout, newStreamWaitTimeout, @@ -404,7 +477,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newExtraHeaders, newOptions, newRetrySettings, - newRetryableCodes); + newRetryableCodes, + endpointContext); } /** The {@link Channel} set on this context. */ @@ -456,6 +530,7 @@ public Map> getExtraHeaders() { public GrpcCallContext withChannel(Channel newChannel) { return new GrpcCallContext( newChannel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -464,13 +539,15 @@ public GrpcCallContext withChannel(Channel newChannel) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } /** Returns a new instance with the call options set to the given call options. */ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { return new GrpcCallContext( channel, + credentials, newCallOptions, timeout, streamWaitTimeout, @@ -479,7 +556,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -513,6 +591,7 @@ public GrpcCallContext withOption(Key key, T value) { ApiCallContextOptions newOptions = options.withOption(key, value); return new GrpcCallContext( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -521,7 +600,8 @@ public GrpcCallContext withOption(Key key, T value) { extraHeaders, newOptions, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } /** {@inheritDoc} */ @@ -530,10 +610,49 @@ public T getOption(Key key) { return options.getOption(key); } + /** + * Validate the Universe Domain to ensure that the user configured Universe Domain and the + * Credentials' Universe Domain match. An exception will be raised if there are any issues when + * trying to validate (i.e. unable to access the universe domain). + * + * @throws UnauthenticatedException Thrown if the universe domain that the user configured does + * not match the Credential's universe domain. + * @throws UnavailableException If client library is unable to retrieve the universe domain from + * the Credentials and the RPC is configured to retry Unavailable exceptions, the client + * library will attempt to retry with the RPC's defined retry bounds. If the retry bounds have + * been exceeded and the library is still unable to retrieve the universe domain, the + * exception will be thrown back to the user. + */ + @InternalApi + public void validateUniverseDomain() { + try { + endpointContext.validateUniverseDomain(credentials, UNAUTHENTICATED_STATUS_CODE); + } catch (IOException e) { + // Check if it is an Auth Exception (All instances of IOException from endpointContext's + // `validateUniverseDomain()` call should be an Auth Exception). + if (e instanceof Retryable) { + Retryable retryable = (Retryable) e; + // Keep the behavior the same as gRPC-Java. Mark as Auth Exceptions as Unavailable + throw ApiExceptionFactory.createException( + EndpointContext.UNABLE_TO_RETRIEVE_CREDENTIALS_ERROR_MESSAGE, + e, + GrpcStatusCode.of(Status.Code.UNAVAILABLE), + retryable.isRetryable()); + } + // This exception below should never be raised as all IOExceptions should be caught above. + throw ApiExceptionFactory.createException( + EndpointContext.UNABLE_TO_RETRIEVE_CREDENTIALS_ERROR_MESSAGE, + e, + UNAUTHENTICATED_STATUS_CODE, + false); + } + } + @Override public int hashCode() { return Objects.hash( channel, + credentials, callOptions, timeout, streamWaitTimeout, @@ -542,7 +661,8 @@ public int hashCode() { extraHeaders, options, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } @Override @@ -556,6 +676,7 @@ public boolean equals(Object o) { GrpcCallContext that = (GrpcCallContext) o; return Objects.equals(channel, that.channel) + && Objects.equals(credentials, that.credentials) && Objects.equals(callOptions, that.callOptions) && Objects.equals(timeout, that.timeout) && Objects.equals(streamWaitTimeout, that.streamWaitTimeout) @@ -564,7 +685,8 @@ public boolean equals(Object o) { && Objects.equals(extraHeaders, that.extraHeaders) && Objects.equals(options, that.options) && Objects.equals(retrySettings, that.retrySettings) - && Objects.equals(retryableCodes, that.retryableCodes); + && Objects.equals(retryableCodes, that.retryableCodes) + && Objects.equals(endpointContext, that.endpointContext); } Metadata getMetadata() { diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcClientCalls.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcClientCalls.java index bc72f6f1f1..80e8797f01 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcClientCalls.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcClientCalls.java @@ -95,6 +95,10 @@ public static ClientCall newCall( channel = ClientInterceptors.intercept(channel, interceptor); } + // Validate the Universe Domain prior to the call. Only allow the call to go through + // if the Universe Domain is valid. + grpcContext.validateUniverseDomain(); + try (Scope ignored = grpcContext.getTracer().inScope()) { return channel.newCall(descriptor, callOptions); } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index 7131873333..ebc941ec0a 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -37,6 +37,7 @@ import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeMethodDescriptor; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; @@ -44,6 +45,7 @@ import com.google.api.gax.rpc.UnaryCallSettings; import com.google.api.gax.rpc.UnaryCallable; import com.google.api.gax.util.FakeLogHandler; +import com.google.auth.Credentials; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -628,10 +630,17 @@ public void testReleasingClientCallCancelEarly() throws IOException { ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1); ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel)); pool = ChannelPool.create(channelPoolSettings, factory); + + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + ClientContext context = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(pool)) - .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(pool, CallOptions.DEFAULT).withEndpointContext(endpointContext)) .build(); ServerStreamingCallSettings settings = ServerStreamingCallSettings.newBuilder().build(); @@ -680,11 +689,19 @@ public void testDoubleRelease() throws Exception { pool = ChannelPool.create(channelPoolSettings, factory); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + // Construct a fake callable to use the channel pool ClientContext context = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(pool)) - .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(pool, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .build(); UnaryCallSettings settings = diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java index 4e563225e5..e67c4c13c2 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java @@ -46,6 +46,7 @@ import io.grpc.CallOptions; import io.grpc.ManagedChannel; import io.grpc.Metadata.Key; +import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -373,7 +374,7 @@ public void testWithOptions() { } @Test - public void testMergeOptions() { + public void testMergeOptions() throws IOException { GrpcCallContext emptyCallContext = GrpcCallContext.createDefault(); ApiCallContext.Key contextKey1 = ApiCallContext.Key.create("testKey1"); ApiCallContext.Key contextKey2 = ApiCallContext.Key.create("testKey2"); diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallableFactoryTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallableFactoryTest.java index 2ebe93b7f7..a274512e14 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallableFactoryTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallableFactoryTest.java @@ -35,12 +35,15 @@ import com.google.api.gax.grpc.testing.FakeServiceImpl; import com.google.api.gax.grpc.testing.InProcessServer; import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.InvalidArgumentException; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.tracing.SpanName; +import com.google.auth.Credentials; import com.google.common.collect.ImmutableList; import com.google.common.truth.Truth; import com.google.type.Color; @@ -74,10 +77,16 @@ public void setUp() throws Exception { inprocessServer.start(); channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build(); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); clientContext = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(channel)) - .setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(channel, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .build(); } @@ -106,11 +115,10 @@ public void createServerStreamingCallableRetryableExceptions() { GrpcCallableFactory.createServerStreamingCallable( grpcCallSettings, nonRetryableSettings, clientContext); + ApiCallContext defaultCallContext = clientContext.getDefaultCallContext(); Throwable actualError = null; try { - nonRetryableCallable - .first() - .call(Color.getDefaultInstance(), clientContext.getDefaultCallContext()); + nonRetryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext); } catch (Throwable e) { actualError = e; } @@ -134,9 +142,7 @@ public void createServerStreamingCallableRetryableExceptions() { Throwable actualError2 = null; try { - retryableCallable - .first() - .call(Color.getDefaultInstance(), clientContext.getDefaultCallContext()); + retryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext); } catch (Throwable e) { actualError2 = e; } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java index fcdea5afe9..eb9277b2e1 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java @@ -30,10 +30,16 @@ package com.google.api.gax.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.verify; import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeServiceGrpc; +import com.google.api.gax.rpc.EndpointContext; +import com.google.api.gax.rpc.UnauthenticatedException; +import com.google.api.gax.rpc.UnavailableException; +import com.google.auth.Credentials; +import com.google.auth.Retryable; import com.google.common.collect.ImmutableList; import com.google.common.truth.Truth; import com.google.type.Color; @@ -45,18 +51,58 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.io.IOException; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.threeten.bp.Duration; public class GrpcClientCallsTest { + + // Auth Library's GoogleAuthException is package-private. Copy basic functionality for tests + private static class GoogleAuthException extends IOException implements Retryable { + + private final boolean isRetryable; + + private GoogleAuthException(boolean isRetryable) { + this.isRetryable = isRetryable; + } + + @Override + public boolean isRetryable() { + return isRetryable; + } + + @Override + public int getRetryCount() { + return 0; + } + } + + private GrpcCallContext defaultCallContext; + private EndpointContext endpointContext; + private Credentials credentials; + private Channel mockChannel; + + @Before + public void setUp() throws IOException { + credentials = Mockito.mock(Credentials.class); + endpointContext = Mockito.mock(EndpointContext.class); + mockChannel = Mockito.mock(Channel.class); + + defaultCallContext = GrpcCallContext.createDefault().withEndpointContext(endpointContext); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + } + @Test public void testAffinity() throws IOException { MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; @@ -78,7 +124,7 @@ public void testAffinity() throws IOException { ChannelPool.create( ChannelPoolSettings.staticallySized(2), new FakeChannelFactory(Arrays.asList(channel0, channel1))); - GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool); + GrpcCallContext context = defaultCallContext.withChannel(pool); ClientCall gotCallA = GrpcClientCalls.newCall(descriptor, context.withChannelAffinity(0)); @@ -92,7 +138,7 @@ public void testAffinity() throws IOException { } @Test - public void testExtraHeaders() { + public void testExtraHeaders() throws IOException { Metadata emptyHeaders = new Metadata(); final Map> extraHeaders = new HashMap<>(); extraHeaders.put( @@ -128,12 +174,12 @@ public void testExtraHeaders() { .thenReturn(mockClientCall); GrpcCallContext context = - GrpcCallContext.createDefault().withChannel(mockChannel).withExtraHeaders(extraHeaders); + defaultCallContext.withChannel(mockChannel).withExtraHeaders(extraHeaders); GrpcClientCalls.newCall(descriptor, context).start(mockListener, emptyHeaders); } @Test - public void testTimeoutToDeadlineConversion() { + public void testTimeoutToDeadlineConversion() throws IOException { MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; @SuppressWarnings("unchecked") @@ -152,8 +198,7 @@ public void testTimeoutToDeadlineConversion() { Duration timeout = Duration.ofSeconds(10); Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS); - GrpcCallContext context = - GrpcCallContext.createDefault().withChannel(mockChannel).withTimeout(timeout); + GrpcCallContext context = defaultCallContext.withChannel(mockChannel).withTimeout(timeout); GrpcClientCalls.newCall(descriptor, context).start(mockListener, new Metadata()); @@ -164,7 +209,7 @@ public void testTimeoutToDeadlineConversion() { } @Test - public void testTimeoutAfterDeadline() { + public void testTimeoutAfterDeadline() throws IOException { MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; @SuppressWarnings("unchecked") @@ -185,7 +230,7 @@ public void testTimeoutAfterDeadline() { Duration timeout = Duration.ofSeconds(10); GrpcCallContext context = - GrpcCallContext.createDefault() + defaultCallContext .withChannel(mockChannel) .withCallOptions(CallOptions.DEFAULT.withDeadline(priorDeadline)) .withTimeout(timeout); @@ -197,7 +242,7 @@ public void testTimeoutAfterDeadline() { } @Test - public void testTimeoutBeforeDeadline() { + public void testTimeoutBeforeDeadline() throws IOException { MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; @SuppressWarnings("unchecked") @@ -219,7 +264,7 @@ public void testTimeoutBeforeDeadline() { Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS); GrpcCallContext context = - GrpcCallContext.createDefault() + defaultCallContext .withChannel(mockChannel) .withCallOptions(CallOptions.DEFAULT.withDeadline(subsequentDeadline)) .withTimeout(timeout); @@ -232,4 +277,66 @@ public void testTimeoutBeforeDeadline() { Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtLeast(minExpectedDeadline); Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtMost(maxExpectedDeadline); } + + @Test + public void testValidUniverseDomain() throws IOException { + GrpcCallContext context = + GrpcCallContext.createDefault() + .withChannel(mockChannel) + .withCredentials(credentials) + .withEndpointContext(endpointContext); + + CallOptions callOptions = context.getCallOptions(); + + MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; + GrpcClientCalls.newCall(descriptor, context); + Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions); + } + + // This test is when the universe domain does not match + @Test + public void testInvalidUniverseDomain() throws IOException { + Mockito.doThrow( + new UnauthenticatedException( + null, GrpcStatusCode.of(Status.Code.UNAUTHENTICATED), false)) + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + GrpcCallContext context = + GrpcCallContext.createDefault() + .withChannel(mockChannel) + .withCredentials(credentials) + .withEndpointContext(endpointContext); + + CallOptions callOptions = context.getCallOptions(); + + MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; + UnauthenticatedException exception = + assertThrows( + UnauthenticatedException.class, () -> GrpcClientCalls.newCall(descriptor, context)); + assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAUTHENTICATED); + Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions); + } + + // This test is when the MDS is unable to return a valid universe domain + @Test + public void testUniverseDomainNotReady_shouldRetry() throws IOException { + Mockito.doThrow(new GoogleAuthException(true)) + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + GrpcCallContext context = + GrpcCallContext.createDefault() + .withChannel(mockChannel) + .withCredentials(credentials) + .withEndpointContext(endpointContext); + + CallOptions callOptions = context.getCallOptions(); + + MethodDescriptor descriptor = FakeServiceGrpc.METHOD_RECOGNIZE; + UnavailableException exception = + assertThrows( + UnavailableException.class, () -> GrpcClientCalls.newCall(descriptor, context)); + assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAVAILABLE); + Truth.assertThat(exception.isRetryable()).isTrue(); + Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions); + } } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectServerStreamingCallableTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectServerStreamingCallableTest.java index e5084b753b..5935d1d786 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectServerStreamingCallableTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectServerStreamingCallableTest.java @@ -36,6 +36,7 @@ import com.google.api.gax.grpc.testing.InProcessServer; import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.ServerStreamingCallSettings; @@ -44,6 +45,7 @@ import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.StreamController; import com.google.api.gax.rpc.testing.FakeCallContext; +import com.google.auth.Credentials; import com.google.common.collect.Lists; import com.google.common.truth.Truth; import com.google.type.Color; @@ -63,6 +65,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; @RunWith(JUnit4.class) public class GrpcDirectServerStreamingCallableTest { @@ -85,11 +88,18 @@ public void setUp() throws InstantiationException, IllegalAccessException, IOExc inprocessServer = new InProcessServer<>(serviceImpl, serverName); inprocessServer.start(); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build(); clientContext = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(channel)) - .setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(channel, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .build(); streamingCallSettings = ServerStreamingCallSettings.newBuilder().build(); streamingCallable = diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamingCallableTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamingCallableTest.java index 95d4550b03..0d39f8704d 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamingCallableTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcDirectStreamingCallableTest.java @@ -42,7 +42,9 @@ import com.google.api.gax.rpc.ClientContext; import com.google.api.gax.rpc.ClientStream; import com.google.api.gax.rpc.ClientStreamingCallable; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.StatusCode.Code; +import com.google.auth.Credentials; import com.google.type.Color; import com.google.type.Money; import io.grpc.CallOptions; @@ -58,6 +60,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; @RunWith(JUnit4.class) public class GrpcDirectStreamingCallableTest { @@ -73,10 +76,16 @@ public void setUp() throws InstantiationException, IllegalAccessException, IOExc inprocessServer = new InProcessServer<>(serviceImpl, serverName); inprocessServer.start(); channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build(); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); clientContext = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(channel)) - .setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(channel, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .build(); } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java index 7149f23ba7..20bceeae2c 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcLongRunningTest.java @@ -42,11 +42,13 @@ import com.google.api.gax.longrunning.OperationTimedPollAlgorithm; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.OperationCallSettings; import com.google.api.gax.rpc.OperationCallable; import com.google.api.gax.rpc.TransportChannel; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.auth.Credentials; import com.google.longrunning.Operation; import com.google.longrunning.OperationsSettings; import com.google.longrunning.stub.GrpcOperationsStub; @@ -68,6 +70,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -133,18 +136,25 @@ public void setUp() throws IOException { .setPollingAlgorithm(pollingAlgorithm) .build(); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + initialContext = ClientContext.newBuilder() .setTransportChannel( GrpcTransportChannel.newBuilder().setManagedChannel(channel).build()) .setExecutor(executor) - .setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(channel, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .setClock(clock) .build(); } @Test - public void testCall() { + public void testCall() throws IOException { Color resp = getColor(1.0f); Money meta = getMoney("UAH"); Operation resultOperation = getOperation("testCall", resp, meta, true); @@ -154,7 +164,13 @@ public void testCall() { GrpcCallableFactory.createOperationCallable( createGrpcSettings(), callSettings, initialContext, operationsStub); - Color response = callable.call(2, GrpcCallContext.createDefault()); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + + Color response = + callable.call(2, GrpcCallContext.createDefault().withEndpointContext(endpointContext)); assertThat(response).isEqualTo(resp); assertThat(executor.getIterationsCount()).isEqualTo(0); } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcResponseMetadataTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcResponseMetadataTest.java index 80b041a376..fba3f3f61e 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcResponseMetadataTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcResponseMetadataTest.java @@ -33,8 +33,10 @@ import com.google.api.gax.grpc.testing.FakeServiceGrpc.FakeServiceImplBase; import com.google.api.gax.grpc.testing.InProcessServer; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.UnaryCallSettings; import com.google.api.gax.rpc.UnaryCallable; +import com.google.auth.Credentials; import com.google.type.Color; import com.google.type.Money; import io.grpc.CallOptions; @@ -133,10 +135,18 @@ public void close(Status status, Metadata trailers) { .usePlaintext() .intercept(new GrpcMetadataHandlerInterceptor()) .build(); + + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + clientContext = ClientContext.newBuilder() .setTransportChannel(GrpcTransportChannel.create(channel)) - .setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT)) + .setDefaultCallContext( + GrpcCallContext.of(channel, CallOptions.DEFAULT) + .withEndpointContext(endpointContext)) .build(); } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/TimeoutTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/TimeoutTest.java index 9de95c1752..d40153eff8 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/TimeoutTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/TimeoutTest.java @@ -35,6 +35,7 @@ import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.RequestParamsExtractor; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; @@ -43,6 +44,7 @@ import com.google.api.gax.rpc.UnaryCallSettings; import com.google.api.gax.rpc.UnaryCallable; import com.google.api.gax.rpc.testing.FakeStatusCode; +import com.google.auth.Credentials; import com.google.common.collect.ImmutableSet; import io.grpc.CallOptions; import io.grpc.ClientCall; @@ -51,7 +53,9 @@ import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; +import java.io.IOException; import java.util.concurrent.TimeUnit; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -77,13 +81,22 @@ public class TimeoutTest { private static final Duration totalTimeout = Duration.ofDays(DEADLINE_IN_DAYS); private static final Duration maxRpcTimeout = Duration.ofMinutes(DEADLINE_IN_MINUTES); private static final Duration initialRpcTimeout = Duration.ofSeconds(DEADLINE_IN_SECONDS); - private static final GrpcCallContext defaultCallContext = GrpcCallContext.createDefault(); + private static GrpcCallContext defaultCallContext; @Rule public MockitoRule mockitoRule = MockitoJUnit.rule().strictness(Strictness.STRICT_STUBS); @Mock private Marshaller stringMarshaller; @Mock private RequestParamsExtractor paramsExtractor; @Mock private ManagedChannel managedChannel; + @BeforeClass + public static void setUp() throws IOException { + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class)); + defaultCallContext = GrpcCallContext.createDefault().withEndpointContext(endpointContext); + } + @Test public void testNonRetryUnarySettings() { RetrySettings retrySettings = diff --git a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java index 2d06244bf1..890c205a61 100644 --- a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java +++ b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java @@ -30,10 +30,14 @@ package com.google.api.gax.httpjson; import com.google.api.core.BetaApi; +import com.google.api.core.InternalApi; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ApiExceptionFactory; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; +import com.google.api.gax.rpc.UnauthenticatedException; import com.google.api.gax.rpc.internal.ApiCallContextOptions; import com.google.api.gax.rpc.internal.Headers; import com.google.api.gax.tracing.ApiTracer; @@ -42,6 +46,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Objects; @@ -60,6 +65,8 @@ * arguments solely depends on the arguments themselves. */ public final class HttpJsonCallContext implements ApiCallContext { + private static final HttpJsonStatusCode UNAUTHENTICATED_STATUS_CODE = + HttpJsonStatusCode.of(StatusCode.Code.UNAUTHENTICATED); private final HttpJsonChannel channel; private final HttpJsonCallOptions callOptions; @Nullable private final Duration timeout; @@ -70,6 +77,7 @@ public final class HttpJsonCallContext implements ApiCallContext { private final ApiTracer tracer; @Nullable private final RetrySettings retrySettings; @Nullable private final ImmutableSet retryableCodes; + private final EndpointContext endpointContext; /** Returns an empty instance. */ public static HttpJsonCallContext createDefault() { @@ -83,6 +91,7 @@ public static HttpJsonCallContext createDefault() { ApiCallContextOptions.getDefaultOptions(), null, null, + null, null); } @@ -97,6 +106,7 @@ public static HttpJsonCallContext of(HttpJsonChannel channel, HttpJsonCallOption ApiCallContextOptions.getDefaultOptions(), null, null, + null, null); } @@ -110,7 +120,8 @@ private HttpJsonCallContext( ApiCallContextOptions options, ApiTracer tracer, RetrySettings defaultRetrySettings, - Set defaultRetryableCodes) { + Set defaultRetryableCodes, + EndpointContext endpointContext) { this.channel = channel; this.callOptions = callOptions; this.timeout = timeout; @@ -122,6 +133,7 @@ private HttpJsonCallContext( this.retrySettings = defaultRetrySettings; this.retryableCodes = defaultRetryableCodes == null ? null : ImmutableSet.copyOf(defaultRetryableCodes); + this.endpointContext = endpointContext; } /** @@ -201,6 +213,8 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { newRetryableCodes = this.retryableCodes; } + // The EndpointContext is not updated as there should be no reason for a user + // to update this. return new HttpJsonCallContext( newChannel, newCallOptions, @@ -211,7 +225,8 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { newOptions, newTracer, newRetrySettings, - newRetryableCodes); + newRetryableCodes, + endpointContext); } @Override @@ -232,6 +247,23 @@ public HttpJsonCallContext withTransportChannel(TransportChannel inputChannel) { return withChannel(transportChannel.getChannel()); } + @Override + public HttpJsonCallContext withEndpointContext(EndpointContext endpointContext) { + Preconditions.checkNotNull(endpointContext); + return new HttpJsonCallContext( + this.channel, + this.callOptions, + this.timeout, + this.streamWaitTimeout, + this.streamIdleTimeout, + this.extraHeaders, + this.options, + this.tracer, + this.retrySettings, + this.retryableCodes, + endpointContext); + } + @Override public HttpJsonCallContext withTimeout(Duration timeout) { // Default RetrySettings use 0 for RPC timeout. Treat that as disabled timeouts. @@ -254,7 +286,8 @@ public HttpJsonCallContext withTimeout(Duration timeout) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Nullable @@ -280,7 +313,8 @@ public HttpJsonCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTi this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } /** @@ -311,7 +345,8 @@ public HttpJsonCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTi this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } /** @@ -341,7 +376,8 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -364,7 +400,8 @@ public ApiCallContext withOption(Key key, T value) { newOptions, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } /** {@inheritDoc} */ @@ -373,6 +410,31 @@ public T getOption(Key key) { return options.getOption(key); } + /** + * Validate the Universe Domain to ensure that the user configured Universe Domain and the + * Credentials' Universe Domain match. An exception will be raised if there are any issues when + * trying to validate (i.e. unable to access the universe domain). + * + * @throws UnauthenticatedException Thrown if the universe domain that the user configured does + * not match the Credential's universe domain or if the client library is unable to retrieve + * the Universe Domain from the Credentials. + */ + @InternalApi + public void validateUniverseDomain() { + try { + endpointContext.validateUniverseDomain( + getCallOptions().getCredentials(), UNAUTHENTICATED_STATUS_CODE); + } catch (IOException e) { + // All instances of IOException from endpointContext's `validateUniverseDomain()` + // call should be an Auth Exception + throw ApiExceptionFactory.createException( + EndpointContext.UNABLE_TO_RETRIEVE_CREDENTIALS_ERROR_MESSAGE, + e, + UNAUTHENTICATED_STATUS_CODE, + false); + } + } + public HttpJsonChannel getChannel() { return channel; } @@ -410,7 +472,8 @@ public HttpJsonCallContext withRetrySettings(RetrySettings retrySettings) { this.options, this.tracer, retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -430,7 +493,8 @@ public HttpJsonCallContext withRetryableCodes(Set retryableCode this.options, this.tracer, this.retrySettings, - retryableCodes); + retryableCodes, + this.endpointContext); } public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { @@ -444,7 +508,8 @@ public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { @@ -458,7 +523,8 @@ public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Deprecated @@ -492,7 +558,8 @@ public HttpJsonCallContext withTracer(@Nonnull ApiTracer newTracer) { this.options, newTracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -511,7 +578,8 @@ public boolean equals(Object o) { && Objects.equals(this.options, that.options) && Objects.equals(this.tracer, that.tracer) && Objects.equals(this.retrySettings, that.retrySettings) - && Objects.equals(this.retryableCodes, that.retryableCodes); + && Objects.equals(this.retryableCodes, that.retryableCodes) + && Objects.equals(this.endpointContext, that.endpointContext); } @Override @@ -524,6 +592,7 @@ public int hashCode() { options, tracer, retrySettings, - retryableCodes); + retryableCodes, + endpointContext); } } diff --git a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonClientCalls.java b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonClientCalls.java index ae1ae3ca84..880a38d56a 100644 --- a/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonClientCalls.java +++ b/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonClientCalls.java @@ -72,6 +72,10 @@ public static HttpJsonClientCall newC httpJsonContext = httpJsonContext.withCallOptions(callOptions); } + // Validate the Universe Domain prior to the call. Only allow the call to go through + // if the Universe Domain is valid. + httpJsonContext.validateUniverseDomain(); + // TODO: add headers interceptor logic return httpJsonContext.getChannel().newCall(methodDescriptor, httpJsonContext.getCallOptions()); } diff --git a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientCallsTest.java b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientCallsTest.java new file mode 100644 index 0000000000..c4cdcf1390 --- /dev/null +++ b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientCallsTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2024 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.httpjson; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.api.gax.rpc.EndpointContext; +import com.google.api.gax.rpc.StatusCode; +import com.google.api.gax.rpc.UnauthenticatedException; +import com.google.auth.Credentials; +import com.google.auth.Retryable; +import java.io.IOException; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class HttpJsonClientCallsTest { + + // Auth Library's GoogleAuthException is package-private. Copy basic functionality for tests + private static class GoogleAuthException extends IOException implements Retryable { + + private final boolean isRetryable; + + private GoogleAuthException(boolean isRetryable) { + this.isRetryable = isRetryable; + } + + @Override + public boolean isRetryable() { + return isRetryable; + } + + @Override + public int getRetryCount() { + return 0; + } + } + + private Credentials credentials; + private EndpointContext endpointContext; + private HttpJsonChannel mockChannel; + private ApiMethodDescriptor descriptor; + private HttpJsonCallOptions callOptions; + private HttpJsonCallContext callContext; + + @Before + public void setUp() throws IOException { + credentials = Mockito.mock(Credentials.class); + endpointContext = Mockito.mock(EndpointContext.class); + mockChannel = Mockito.mock(HttpJsonChannel.class); + descriptor = Mockito.mock(ApiMethodDescriptor.class); + callOptions = Mockito.mock(HttpJsonCallOptions.class); + + callContext = + HttpJsonCallContext.of(mockChannel, callOptions) + .withEndpointContext(endpointContext) + .withChannel(mockChannel); + + Mockito.when(callOptions.getCredentials()).thenReturn(credentials); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); + } + + @Test + public void testValidUniverseDomain() { + HttpJsonClientCalls.newCall(descriptor, callContext); + Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions); + } + + // This test is when the universe domain does not match + @Test + public void testInvalidUniverseDomain() throws IOException { + Mockito.doThrow( + new UnauthenticatedException( + null, HttpJsonStatusCode.of(StatusCode.Code.UNAUTHENTICATED), false)) + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); + + UnauthenticatedException exception = + assertThrows( + UnauthenticatedException.class, + () -> HttpJsonClientCalls.newCall(descriptor, callContext)); + assertThat(exception.getStatusCode().getCode()) + .isEqualTo(HttpJsonStatusCode.Code.UNAUTHENTICATED); + Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions); + } + + // This test is when the MDS is unable to return a valid universe domain + @Test + public void testUniverseDomainNotReady_shouldRetry() throws IOException { + Mockito.doThrow(new GoogleAuthException(true)) + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); + HttpJsonCallContext context = + HttpJsonCallContext.createDefault() + .withChannel(mockChannel) + .withCredentials(credentials) + .withEndpointContext(endpointContext); + + HttpJsonCallOptions callOptions = context.getCallOptions(); + + UnauthenticatedException exception = + assertThrows( + UnauthenticatedException.class, + () -> HttpJsonClientCalls.newCall(descriptor, callContext)); + assertThat(exception.getStatusCode().getCode()) + .isEqualTo(HttpJsonStatusCode.Code.UNAUTHENTICATED); + Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions); + } +} diff --git a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientInterceptorTest.java b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientInterceptorTest.java index 6e081fb75f..463b76112b 100644 --- a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientInterceptorTest.java +++ b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonClientInterceptorTest.java @@ -34,6 +34,8 @@ import com.google.api.gax.httpjson.ForwardingHttpJsonClientCall.SimpleForwardingHttpJsonClientCall; import com.google.api.gax.httpjson.ForwardingHttpJsonClientCallListener.SimpleForwardingHttpJsonClientCallListener; import com.google.api.gax.httpjson.testing.MockHttpService; +import com.google.api.gax.rpc.EndpointContext; +import com.google.auth.Credentials; import com.google.protobuf.Field; import com.google.protobuf.Field.Cardinality; import java.io.IOException; @@ -51,6 +53,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -178,14 +181,21 @@ public void tearDown() { } @Test - public void testCustomInterceptor() throws ExecutionException, InterruptedException { + public void testCustomInterceptor() throws ExecutionException, InterruptedException, IOException { HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); + HttpJsonCallContext callContext = HttpJsonCallContext.createDefault() .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); + .withTimeout(Duration.ofSeconds(30)) + .withEndpointContext(endpointContext); Field request; Field expectedResponse; diff --git a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectCallableTest.java b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectCallableTest.java index fa666dc69c..619052744a 100644 --- a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectCallableTest.java +++ b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectCallableTest.java @@ -35,10 +35,13 @@ import com.google.api.gax.httpjson.testing.MockHttpService; import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ApiExceptionFactory; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.testing.FakeStatusCode; +import com.google.auth.Credentials; import com.google.protobuf.Field; import com.google.protobuf.Field.Cardinality; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -53,6 +56,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -94,7 +98,9 @@ public class HttpJsonDirectCallableTest { private static final MockHttpService MOCK_SERVICE = new MockHttpService(Collections.singletonList(FAKE_METHOD_DESCRIPTOR), "google.com:443"); - private final ManagedHttpJsonChannel channel = + private static ExecutorService executorService; + + private static final ManagedHttpJsonChannel channel = new ManagedHttpJsonInterceptorChannel( ManagedHttpJsonChannel.newBuilder() .setEndpoint("google.com:443") @@ -103,10 +109,10 @@ public class HttpJsonDirectCallableTest { .build(), new HttpJsonHeaderInterceptor(Collections.singletonMap("header-key", "headerValue"))); - private static ExecutorService executorService; + private static HttpJsonCallContext defaultCallContext; @BeforeClass - public static void initialize() { + public static void initialize() throws IOException { executorService = Executors.newFixedThreadPool( 2, @@ -115,6 +121,16 @@ public static void initialize() { t.setDaemon(true); return t; }); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); + defaultCallContext = + HttpJsonCallContext.createDefault() + .withChannel(channel) + .withTimeout(Duration.ofSeconds(30)) + .withEndpointContext(endpointContext); } @AfterClass @@ -132,18 +148,13 @@ public void testSuccessfulUnaryResponse() throws ExecutionException, Interrupted HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - Field request; Field expectedResponse; request = expectedResponse = createTestMessage(2); MOCK_SERVICE.addResponse(expectedResponse); - Field actualResponse = callable.futureCall(request, callContext).get(); + Field actualResponse = callable.futureCall(request, defaultCallContext).get(); assertThat(actualResponse).isEqualTo(expectedResponse); assertThat(MOCK_SERVICE.getRequestPaths().size()).isEqualTo(1); @@ -167,11 +178,6 @@ public void testSuccessfulMultipleResponsesForUnaryCall() HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - Field request = createTestMessage(2); Field expectedResponse = createTestMessage(2); Field otherResponse = createTestMessage(10); @@ -179,7 +185,7 @@ public void testSuccessfulMultipleResponsesForUnaryCall() MOCK_SERVICE.addResponse(otherResponse); MOCK_SERVICE.addResponse(otherResponse); - Field actualResponse = callable.futureCall(request, callContext).get(); + Field actualResponse = callable.futureCall(request, defaultCallContext).get(); assertThat(actualResponse).isEqualTo(expectedResponse); assertThat(MOCK_SERVICE.getRequestPaths().size()).isEqualTo(1); String headerValue = MOCK_SERVICE.getRequestHeaders().get("header-key").iterator().next(); @@ -202,11 +208,6 @@ public void testErrorMultipleResponsesForUnaryCall() HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - Field request = createTestMessage(2); Field expectedResponse = createTestMessage(2); Field randomResponse1 = createTestMessage(10); @@ -215,7 +216,7 @@ public void testErrorMultipleResponsesForUnaryCall() MOCK_SERVICE.addResponse(expectedResponse); MOCK_SERVICE.addResponse(randomResponse2); - Field actualResponse = callable.futureCall(request, callContext).get(); + Field actualResponse = callable.futureCall(request, defaultCallContext).get(); // Gax returns the first response for Unary Call assertThat(actualResponse).isEqualTo(randomResponse1); assertThat(actualResponse).isNotEqualTo(expectedResponse); @@ -234,18 +235,13 @@ public void testErrorUnaryResponse() throws InterruptedException { HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - ApiException exception = ApiExceptionFactory.createException( new Exception(), FakeStatusCode.of(Code.NOT_FOUND), false); MOCK_SERVICE.addException(exception); try { - callable.futureCall(createTestMessage(2), callContext).get(); + callable.futureCall(createTestMessage(2), defaultCallContext).get(); Assert.fail("No exception raised"); } catch (ExecutionException e) { HttpResponseException respExp = (HttpResponseException) e.getCause(); @@ -266,15 +262,10 @@ public void testErrorNullContentSuccessfulResponse() throws InterruptedException HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - MOCK_SERVICE.addNullResponse(); try { - callable.futureCall(createTestMessage(2), callContext).get(); + callable.futureCall(createTestMessage(2), defaultCallContext).get(); Assert.fail("No exception raised"); } catch (ExecutionException e) { HttpJsonStatusRuntimeException respExp = (HttpJsonStatusRuntimeException) e.getCause(); @@ -295,14 +286,10 @@ public void testErrorNullContentFailedResponse() throws InterruptedException { HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); MOCK_SERVICE.addNullResponse(400); try { - callable.futureCall(createTestMessage(2), callContext).get(); + callable.futureCall(createTestMessage(2), defaultCallContext).get(); Assert.fail("No exception raised"); } catch (ExecutionException e) { HttpResponseException respExp = (HttpResponseException) e.getCause(); @@ -321,18 +308,13 @@ public void testErrorNon2xxOr4xxResponse() throws InterruptedException { HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault() - .withChannel(channel) - .withTimeout(Duration.ofSeconds(30)); - ApiException exception = ApiExceptionFactory.createException( new Exception(), FakeStatusCode.of(Code.INTERNAL), false); MOCK_SERVICE.addException(500, exception); try { - callable.futureCall(createTestMessage(2), callContext).get(); + callable.futureCall(createTestMessage(2), defaultCallContext).get(); Assert.fail("No exception raised"); } catch (ExecutionException e) { HttpResponseException respExp = (HttpResponseException) e.getCause(); @@ -353,8 +335,7 @@ public void testDeadlineExceededResponse() throws InterruptedException { HttpJsonDirectCallable callable = new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); - HttpJsonCallContext callContext = - HttpJsonCallContext.createDefault().withChannel(channel).withTimeout(Duration.ofSeconds(3)); + HttpJsonCallContext callContext = defaultCallContext.withTimeout(Duration.ofSeconds(3)); Field response = createTestMessage(10); MOCK_SERVICE.addResponse(response, java.time.Duration.ofSeconds(5)); diff --git a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectServerStreamingCallableTest.java b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectServerStreamingCallableTest.java index 82c8e3e0c2..3898b8e908 100644 --- a/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectServerStreamingCallableTest.java +++ b/gax-java/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonDirectServerStreamingCallableTest.java @@ -35,6 +35,7 @@ import com.google.api.gax.rpc.ApiException; import com.google.api.gax.rpc.ClientContext; import com.google.api.gax.rpc.DeadlineExceededException; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.ServerStreamingCallSettings; @@ -44,11 +45,13 @@ import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.StreamController; import com.google.api.gax.rpc.testing.FakeCallContext; +import com.google.auth.Credentials; import com.google.common.collect.Lists; import com.google.common.truth.Truth; import com.google.protobuf.Field; import com.google.type.Color; import com.google.type.Money; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -65,6 +68,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.threeten.bp.Duration; @RunWith(JUnit4.class) @@ -124,7 +128,7 @@ public class HttpJsonDirectServerStreamingCallableTest { private static ExecutorService executorService; @BeforeClass - public static void initialize() { + public static void initialize() throws IOException { executorService = Executors.newFixedThreadPool(2); channel = new ManagedHttpJsonInterceptorChannel( @@ -134,12 +138,18 @@ public static void initialize() { .setHttpTransport(MOCK_SERVICE) .build(), new HttpJsonHeaderInterceptor(Collections.singletonMap("header-key", "headerValue"))); + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); + Mockito.doNothing() + .when(endpointContext) + .validateUniverseDomain( + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); clientContext = ClientContext.newBuilder() .setTransportChannel(HttpJsonTransportChannel.create(channel)) .setDefaultCallContext( HttpJsonCallContext.of(channel, HttpJsonCallOptions.DEFAULT) - .withTimeout(Duration.ofSeconds(3))) + .withTimeout(Duration.ofSeconds(3)) + .withEndpointContext(endpointContext)) .build(); streamingCallSettings = ServerStreamingCallSettings.newBuilder().build(); @@ -202,7 +212,7 @@ public void testServerStreamingStart() throws InterruptedException { // wait for the task to complete, otherwise it may interfere with other tests, since they share // the same MockService and unfinished request in this test may start reading messages // designated for other tests. - Truth.assertThat(latch.await(60, TimeUnit.SECONDS)).isTrue(); + Truth.assertThat(latch.await(2, TimeUnit.SECONDS)).isTrue(); } @Test diff --git a/gax-java/gax/clirr-ignored-differences.xml b/gax-java/gax/clirr-ignored-differences.xml index 1c4dd28a3c..c8e68444ff 100644 --- a/gax-java/gax/clirr-ignored-differences.xml +++ b/gax-java/gax/clirr-ignored-differences.xml @@ -31,4 +31,15 @@ com/google/api/gax/rpc/ClientContext* * *UniverseDomain*(*) + + + 7012 + com/google/api/gax/rpc/ApiCallContext + * withEndpointContext(*) + + + 7012 + com/google/api/gax/rpc/ApiCallContext + * validateUniverseDomain() + diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index 5a34b58e39..e650564826 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -63,6 +63,9 @@ public interface ApiCallContext extends RetryingContext { /** Returns a new ApiCallContext with the given channel set. */ ApiCallContext withTransportChannel(TransportChannel channel); + /** Returns a new ApiCallContext with the given Endpoint Context. */ + ApiCallContext withEndpointContext(EndpointContext endpointContext); + /** * Returns a new ApiCallContext with the given timeout set. * diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java index e7fac9d0c6..03a0cf86d9 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java @@ -230,6 +230,7 @@ public static ClientContext create(StubSettings settings) throws IOException { if (credentials != null) { defaultCallContext = defaultCallContext.withCredentials(credentials); } + defaultCallContext = defaultCallContext.withEndpointContext(endpointContext); WatchdogProvider watchdogProvider = settings.getStreamWatchdogProvider(); @Nullable Watchdog watchdog = null; diff --git a/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java b/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java index bbd0702500..b33a97cd0a 100644 --- a/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java +++ b/gax-java/gax/src/main/java/com/google/api/gax/rpc/EndpointContext.java @@ -38,10 +38,19 @@ import java.io.IOException; import javax.annotation.Nullable; -/** Contains the fields required to resolve the endpoint and Universe Domain */ +/** + * EndpointContext is an internal class used by the client library to resolve the endpoint. It is + * created once the library is initialized should not be updated manually. + * + *

Contains the fields required to resolve the endpoint and Universe Domain + */ @InternalApi @AutoValue public abstract class EndpointContext { + private static final String INVALID_UNIVERSE_DOMAIN_ERROR_TEMPLATE = + "The configured universe domain (%s) does not match the universe domain found in the credentials (%s). If you haven't configured the universe domain explicitly, `googleapis.com` is the default."; + public static final String UNABLE_TO_RETRIEVE_CREDENTIALS_ERROR_MESSAGE = + "Unable to retrieve the Universe Domain from the Credentials."; /** * ServiceName is host URI for Google Cloud Services. It follows the format of @@ -95,6 +104,42 @@ public static Builder newBuilder() { .setUsingGDCH(false); } + /** + * Check that the User configured universe domain matches the Credentials' universe domain. The + * status code parameter is passed in to this method as it's a limitation of Gax's modules. The + * transport-neutral module does have access the transport-specific modules (which contain the + * implementation of the StatusCode). This method is scoped to be internal and should be not be + * accessed by users. + * + * @param credentials Auth Library Credentials + * @param invalidUniverseDomainStatusCode Transport-specific Status Code to be returned if the + * Universe Domain is invalid. For both transports, this is defined to be Unauthorized. + * @throws IOException Implementation of Auth's Retryable interface which tells the client library + * whether the RPC should be retried or not. + */ + public void validateUniverseDomain( + Credentials credentials, StatusCode invalidUniverseDomainStatusCode) throws IOException { + if (usingGDCH()) { + // GDC-H has no universe domain, return + return; + } + String credentialsUniverseDomain = Credentials.GOOGLE_DEFAULT_UNIVERSE; + // If credentials is not NoCredentialsProvider, use the Universe Domain inside Credentials + if (credentials != null) { + credentialsUniverseDomain = credentials.getUniverseDomain(); + } + if (!resolvedUniverseDomain().equals(credentialsUniverseDomain)) { + throw ApiExceptionFactory.createException( + new Throwable( + String.format( + EndpointContext.INVALID_UNIVERSE_DOMAIN_ERROR_TEMPLATE, + resolvedUniverseDomain(), + credentialsUniverseDomain)), + invalidUniverseDomainStatusCode, + false); + } + } + @AutoValue.Builder public abstract static class Builder { /** diff --git a/gax-java/gax/src/test/java/com/google/api/gax/rpc/EndpointContextTest.java b/gax-java/gax/src/test/java/com/google/api/gax/rpc/EndpointContextTest.java index 811a6bd92b..f0dbae60f2 100644 --- a/gax-java/gax/src/test/java/com/google/api/gax/rpc/EndpointContextTest.java +++ b/gax-java/gax/src/test/java/com/google/api/gax/rpc/EndpointContextTest.java @@ -31,21 +31,25 @@ import static org.junit.Assert.assertThrows; +import com.google.api.gax.core.NoCredentialsProvider; import com.google.api.gax.rpc.mtls.MtlsProvider; import com.google.api.gax.rpc.testing.FakeMtlsProvider; import com.google.auth.Credentials; import com.google.common.truth.Truth; +import io.grpc.Status; import java.io.IOException; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; @RunWith(JUnit4.class) public class EndpointContextTest { private static final String DEFAULT_ENDPOINT = "test.googleapis.com:443"; private static final String DEFAULT_MTLS_ENDPOINT = "test.mtls.googleapis.com:443"; private EndpointContext.Builder defaultEndpointContextBuilder; + private StatusCode statusCode; @Before public void setUp() throws IOException { @@ -55,6 +59,9 @@ public void setUp() throws IOException { .setUniverseDomain(Credentials.GOOGLE_DEFAULT_UNIVERSE) .setClientSettingsEndpoint(DEFAULT_ENDPOINT) .setMtlsEndpoint(DEFAULT_MTLS_ENDPOINT); + statusCode = Mockito.mock(StatusCode.class); + Mockito.when(statusCode.getCode()).thenReturn(StatusCode.Code.UNAUTHENTICATED); + Mockito.when(statusCode.getTransportCode()).thenReturn(Status.Code.UNAUTHENTICATED); } @Test @@ -332,4 +339,72 @@ public void endpointContextBuild_gdchFlow_noUniverseDomain_customEndpoint() thro Truth.assertThat(endpointContext.resolvedUniverseDomain()) .isEqualTo(Credentials.GOOGLE_DEFAULT_UNIVERSE); } + + @Test + public void hasValidUniverseDomain_gdchFlow_anyCredentials() throws IOException { + Credentials noCredentials = NoCredentialsProvider.create().getCredentials(); + Credentials validCredentials = Mockito.mock(Credentials.class); + EndpointContext endpointContext = + defaultEndpointContextBuilder.setUniverseDomain(null).setUsingGDCH(true).build(); + endpointContext.validateUniverseDomain(noCredentials, statusCode); + endpointContext.validateUniverseDomain(validCredentials, statusCode); + } + + @Test + public void hasValidUniverseDomain_noCredentials_inGDU() throws IOException { + Credentials noCredentials = NoCredentialsProvider.create().getCredentials(); + EndpointContext endpointContext = defaultEndpointContextBuilder.build(); + endpointContext.validateUniverseDomain(noCredentials, statusCode); + } + + @Test + public void hasValidUniverseDomain_noCredentials_nonGDU() throws IOException { + Credentials noCredentials = NoCredentialsProvider.create().getCredentials(); + EndpointContext endpointContext = + defaultEndpointContextBuilder.setUniverseDomain("test.com").build(); + assertThrows( + UnauthenticatedException.class, + () -> endpointContext.validateUniverseDomain(noCredentials, statusCode)); + } + + @Test + public void hasValidUniverseDomain_credentialsInGDU_configInGDU() throws IOException { + Credentials credentials = Mockito.mock(Credentials.class); + Mockito.when(credentials.getUniverseDomain()).thenReturn(Credentials.GOOGLE_DEFAULT_UNIVERSE); + EndpointContext endpointContext = defaultEndpointContextBuilder.build(); + endpointContext.validateUniverseDomain(credentials, statusCode); + } + + // Non-GDU Universe Domain could be any domain, but this test refers uses `test.com` + @Test + public void hasValidUniverseDomain_credentialsNonGDU_configInGDU() throws IOException { + Credentials credentials = Mockito.mock(Credentials.class); + Mockito.when(credentials.getUniverseDomain()).thenReturn("test.com"); + EndpointContext endpointContext = defaultEndpointContextBuilder.build(); + assertThrows( + UnauthenticatedException.class, + () -> endpointContext.validateUniverseDomain(credentials, statusCode)); + } + + // Non-GDU Universe Domain could be any domain, but this test refers uses `test.com` + @Test + public void hasValidUniverseDomain_credentialsNonGDU_configNonGDU() throws IOException { + Credentials credentials = Mockito.mock(Credentials.class); + Mockito.when(credentials.getUniverseDomain()).thenReturn("test.com"); + EndpointContext endpointContext = + defaultEndpointContextBuilder.setUniverseDomain("test.com").build(); + endpointContext.validateUniverseDomain(credentials, statusCode); + } + + // Non-GDU Universe Domain could be any domain, but this test refers uses `test.com` + @Test + public void hasValidUniverseDomain_credentialsInGDU_configNonGDU() throws IOException { + Credentials credentials = Mockito.mock(Credentials.class); + Mockito.when(credentials.getUniverseDomain()).thenReturn(Credentials.GOOGLE_DEFAULT_UNIVERSE); + EndpointContext endpointContext = + defaultEndpointContextBuilder.setUniverseDomain("test.com").build(); + assertThrows( + UnauthenticatedException.class, + () -> endpointContext.validateUniverseDomain(credentials, statusCode)); + } } diff --git a/gax-java/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java b/gax-java/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java index 84d40e11cc..e7c6c90b1e 100644 --- a/gax-java/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java +++ b/gax-java/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java @@ -33,6 +33,7 @@ import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.ClientContext; +import com.google.api.gax.rpc.EndpointContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; import com.google.api.gax.rpc.internal.ApiCallContextOptions; @@ -62,6 +63,7 @@ public class FakeCallContext implements ApiCallContext { private final ApiTracer tracer; private final RetrySettings retrySettings; private final ImmutableSet retryableCodes; + private final EndpointContext endpointContext; private FakeCallContext( Credentials credentials, @@ -73,7 +75,8 @@ private FakeCallContext( ApiCallContextOptions options, ApiTracer tracer, RetrySettings retrySettings, - Set retryableCodes) { + Set retryableCodes, + EndpointContext endpointContext) { this.credentials = credentials; this.channel = channel; this.timeout = timeout; @@ -84,6 +87,7 @@ private FakeCallContext( this.tracer = tracer; this.retrySettings = retrySettings; this.retryableCodes = retryableCodes == null ? null : ImmutableSet.copyOf(retryableCodes); + this.endpointContext = endpointContext; } public static FakeCallContext createDefault() { @@ -97,6 +101,7 @@ public static FakeCallContext createDefault() { ApiCallContextOptions.getDefaultOptions(), null, null, + null, null); } @@ -183,7 +188,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newOptions, newTracer, newRetrySettings, - newRetryableCodes); + newRetryableCodes, + endpointContext); } public RetrySettings getRetrySettings() { @@ -201,7 +207,8 @@ public FakeCallContext withRetrySettings(RetrySettings retrySettings) { this.options, this.tracer, retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } public Set getRetryableCodes() { @@ -219,7 +226,8 @@ public FakeCallContext withRetryableCodes(Set retryableCodes) { this.options, this.tracer, this.retrySettings, - retryableCodes); + retryableCodes, + this.endpointContext); } public Credentials getCredentials() { @@ -259,7 +267,8 @@ public FakeCallContext withCredentials(Credentials credentials) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -273,6 +282,23 @@ public FakeCallContext withTransportChannel(TransportChannel inputChannel) { return withChannel(transportChannel.getChannel()); } + @Override + public FakeCallContext withEndpointContext(EndpointContext endpointContext) { + Preconditions.checkNotNull(endpointContext); + return new FakeCallContext( + this.credentials, + this.channel, + this.timeout, + this.streamWaitTimeout, + this.streamIdleTimeout, + this.extraHeaders, + this.options, + this.tracer, + this.retrySettings, + this.retryableCodes, + endpointContext); + } + public FakeCallContext withChannel(FakeChannel channel) { return new FakeCallContext( this.credentials, @@ -284,7 +310,8 @@ public FakeCallContext withChannel(FakeChannel channel) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -309,7 +336,8 @@ public FakeCallContext withTimeout(Duration timeout) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -324,7 +352,8 @@ public ApiCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTimeout this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -340,7 +369,8 @@ public ApiCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTimeout this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -358,7 +388,8 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { this.options, this.tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } @Override @@ -380,7 +411,8 @@ public ApiCallContext withOption(Key key, T value) { newOptions, tracer, retrySettings, - retryableCodes); + retryableCodes, + this.endpointContext); } @Override @@ -414,7 +446,8 @@ public ApiCallContext withTracer(@Nonnull ApiTracer tracer) { this.options, tracer, this.retrySettings, - this.retryableCodes); + this.retryableCodes, + this.endpointContext); } public static FakeCallContext create(ClientContext clientContext) {