diff --git a/containers/test-apps/courier/pom.xml b/containers/test-apps/courier/pom.xml index e6675bed8..7aa76f242 100644 --- a/containers/test-apps/courier/pom.xml +++ b/containers/test-apps/courier/pom.xml @@ -6,7 +6,7 @@ twosigma courier - 1.2.1 + 1.4.6 courier https://github.com/twosigma/waiter/tree/master/test-apps/courier diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java index 2c3295750..330aebb3b 100644 --- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java +++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java @@ -36,10 +36,13 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -49,19 +52,57 @@ public class GrpcClient { private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName()); - private static Function logFunction = new Function() { - @Override - public Void apply(final String message) { - LOGGER.info(message); - return null; + public static final class RpcResult { + private final Result result; + private final Status status; + + private RpcResult(final Result result, final Status status) { + this.result = result; + this.status = status; + } + + public Result result() { + return result; + } + + public Status status() { + return status; + } + } + + private static Variant retrieveVariant(final String id) { + if (id.contains("SEND_ERROR")) { + return Variant.SEND_ERROR; + } else if (id.contains("EXIT_PRE_RESPONSE")) { + return Variant.EXIT_PRE_RESPONSE; + } else if (id.contains("EXIT_POST_RESPONSE")) { + return Variant.EXIT_POST_RESPONSE; + } else { + return Variant.NORMAL; } - }; + } + + private final Function logFunction; + private final String host; + private final int port; - public static void setLogFunction(final Function logFunction) { - GrpcClient.logFunction = logFunction; + public GrpcClient(final String host, final int port) { + this(host, port, new Function() { + @Override + public Void apply(final String message) { + LOGGER.info(message); + return null; + } + }); + } + + public GrpcClient(final String host, final int port, final Function logFunction) { + this.host = host; + this.port = port; + this.logFunction = logFunction; } - private static ManagedChannel initializeChannel(final String host, final int port) { + private ManagedChannel initializeChannel() { logFunction.apply("initializing plaintext client at " + host + ":" + port); return ManagedChannelBuilder .forAddress(host, port) @@ -69,17 +110,21 @@ private static ManagedChannel initializeChannel(final String host, final int por .build(); } - private static void shutdownChannel(final ManagedChannel channel) throws InterruptedException { + private void shutdownChannel(final ManagedChannel channel) { logFunction.apply("shutting down channel"); - channel.shutdown().awaitTermination(1, TimeUnit.SECONDS); - if (channel.isShutdown()) { - logFunction.apply("channel shutdown successfully"); - } else { - logFunction.apply("channel shutdown timed out!"); + try { + channel.shutdown().awaitTermination(1, TimeUnit.SECONDS); + if (channel.isShutdown()) { + logFunction.apply("channel shutdown successfully"); + } else { + logFunction.apply("channel shutdown timed out!"); + } + } catch (Exception ex) { + logFunction.apply("error in channel shutdown: " + ex.getMessage()); } } - private static Metadata createRequestHeadersMetadata(final Map headers) { + private Metadata createRequestHeadersMetadata(final Map headers) { final Metadata headerMetadata = new Metadata(); for (Map.Entry entry : headers.entrySet()) { final String key = entry.getKey(); @@ -89,7 +134,7 @@ private static Metadata createRequestHeadersMetadata(final Map h return headerMetadata; } - private static Channel wrapResponseLogger(final ManagedChannel channel) { + private Channel wrapResponseLogger(final ManagedChannel channel) { return ClientInterceptors.intercept(channel, new ClientInterceptor() { @Override public ClientCall interceptCall(final MethodDescriptor method, @@ -120,13 +165,11 @@ public void onClose(final Status status, final Metadata trailers) { }); } - public static CourierReply sendPackage(final String host, - final int port, - final Map headers, - final String id, - final String from, - final String message) throws InterruptedException { - final ManagedChannel channel = initializeChannel(host, port); + public RpcResult sendPackage(final Map headers, + final String id, + final String from, + final String message) { + final ManagedChannel channel = initializeChannel(); try { final Channel wrappedChannel = wrapResponseLogger(channel); @@ -141,38 +184,52 @@ public static CourierReply sendPackage(final String host, .setId(id) .setFrom(from) .setMessage(message) + .setVariant(retrieveVariant(id)) .build(); - final CourierReply response; + + final AtomicReference status = new AtomicReference<>(); + final AtomicReference response = new AtomicReference<>(); try { - response = futureStub.sendPackage(request).get(); - } catch (final StatusRuntimeException e) { - logFunction.apply("RPC failed, status: " + e.getStatus()); - return null; - } catch (final Exception e) { - logFunction.apply("RPC failed, message: " + e.getMessage()); - return null; + final CourierReply reply = futureStub.sendPackage(request).get(); + status.set(Status.OK); + response.set(reply); + } catch (final StatusRuntimeException ex) { + final Status errorStatus = ex.getStatus(); + logFunction.apply("RPC failed, status: " + errorStatus); + status.set(errorStatus); + } catch (final ExecutionException ex) { + final Status errorStatus = Status.fromThrowable(ex.getCause()); + logFunction.apply("RPC execution failed: " + errorStatus); + status.set(errorStatus); + } catch (final Throwable th) { + logFunction.apply("RPC failed, message: " + th.getMessage()); + status.set(Status.UNKNOWN.withDescription(th.getMessage())); } - logFunction.apply("received response CourierReply{" + - "id=" + response.getId() + ", " + - "response=" + response.getResponse() + ", " + - "message.length=" + response.getMessage().length() + "}"); - logFunction.apply("messages equal = " + message.equals(response.getMessage())); - return response; + + if (response.get() != null) { + final CourierReply reply = response.get(); + logFunction.apply("received response CourierReply{" + + "id=" + reply.getId() + ", " + + "response=" + reply.getResponse() + ", " + + "message.length=" + reply.getMessage().length() + "}"); + logFunction.apply("messages equal = " + message.equals(reply.getMessage())); + } + return new RpcResult<>(response.get(), status.get()); } finally { shutdownChannel(channel); } } - public static List collectPackages(final String host, - final int port, - final Map headers, - final String idPrefix, - final String from, - final List messages, - final int interMessageSleepMs, - final boolean lockStepMode) throws InterruptedException { - final ManagedChannel channel = initializeChannel(host, port); + public RpcResult> collectPackages(final Map headers, + final List ids, + final String from, + final List messages, + final int interMessageSleepMs, + final boolean lockStepMode, + final int cancelThreshold) { + final ManagedChannel channel = initializeChannel(); + final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true); try { final Semaphore lockStep = new Semaphore(1); @@ -186,6 +243,9 @@ public static List collectPackages(final String host, logFunction.apply("will try to send package from " + from + " ..."); + final AtomicReference status = new AtomicReference<>(); + final AtomicReference> response = new AtomicReference<>(); + final CompletableFuture> responsePromise = new CompletableFuture<>(); try { final StreamObserver collector = @@ -206,19 +266,26 @@ public void onNext(final CourierSummary response) { } @Override - public void onError(final Throwable throwable) { - logFunction.apply("error in collecting summaries " + throwable); + public void onError(final Throwable th) { + logFunction.apply("error in collecting summaries " + th); errorSignal.compareAndSet(false, true); - resolveResponsePromise(); if (lockStepMode) { logFunction.apply("releasing semaphore after receiving error"); lockStep.release(); } + if (th instanceof StatusRuntimeException) { + final StatusRuntimeException exception = (StatusRuntimeException) th; + status.set(exception.getStatus()); + } else { + status.set(Status.UNKNOWN.withDescription(th.getMessage())); + } + resolveResponsePromise(); } @Override public void onCompleted() { logFunction.apply("completed collecting summaries"); + status.set(Status.OK); resolveResponsePromise(); } @@ -229,11 +296,16 @@ private void resolveResponsePromise() { }); for (int i = 0; i < messages.size(); i++) { + if (i >= cancelThreshold) { + logFunction.apply("cancelling sending messages"); + awaitChannelTermination.set(false); + throw new CancellationException("Cancel threshold reached: " + cancelThreshold); + } if (errorSignal.get()) { logFunction.apply("aborting sending messages as error was discovered"); break; } - final String requestId = idPrefix + i; + final String requestId = ids.get(i); if (lockStepMode) { logFunction.apply("acquiring semaphore before sending request " + requestId); lockStep.acquire(); @@ -243,6 +315,7 @@ private void resolveResponsePromise() { .setId(requestId) .setFrom(from) .setMessage(messages.get(i)) + .setVariant(retrieveVariant(requestId)) .build(); logFunction.apply("sending message CourierRequest{" + "id=" + request.getId() + ", " + @@ -254,17 +327,134 @@ private void resolveResponsePromise() { logFunction.apply("completed sending packages"); collector.onCompleted(); - return responsePromise.get(); - } catch (final StatusRuntimeException e) { - logFunction.apply("RPC failed, status: " + e.getStatus()); - return null; - } catch (final Exception e) { - logFunction.apply("RPC failed, message: " + e.getMessage()); - return null; + response.set(responsePromise.get()); + } catch (final StatusRuntimeException ex) { + logFunction.apply("RPC failed, status: " + ex.getStatus()); + status.set(ex.getStatus()); + } catch (final Exception ex) { + logFunction.apply("RPC failed, message: " + ex.getMessage()); + status.set(Status.UNKNOWN.withDescription(ex.getMessage())); } + return new RpcResult<>(response.get(), status.get()); + } finally { - shutdownChannel(channel); + if (awaitChannelTermination.get()) { + shutdownChannel(channel); + } else { + channel.shutdownNow(); + } + } + } + + public RpcResult aggregatePackages(final Map headers, + final List ids, + final String from, + final List messages, + final int interMessageSleepMs, + final int cancelThreshold) { + final ManagedChannel channel = initializeChannel(); + final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true); + + try { + final AtomicBoolean errorSignal = new AtomicBoolean(false); + + final Channel wrappedChannel = wrapResponseLogger(channel); + final Metadata headerMetadata = createRequestHeadersMetadata(headers); + + final CourierGrpc.CourierStub rawStub = CourierGrpc.newStub(wrappedChannel); + final CourierGrpc.CourierStub futureStub = MetadataUtils.attachHeaders(rawStub, headerMetadata); + + logFunction.apply("will try to agggreate package from " + from + " ..."); + + final AtomicReference status = new AtomicReference<>(); + final AtomicReference response = new AtomicReference<>(); + + final CompletableFuture responsePromise = new CompletableFuture<>(); + try { + final StreamObserver collector = + futureStub.aggregatePackages(new StreamObserver() { + + @Override + public void onNext(final CourierSummary summary) { + logFunction.apply("received response CourierSummary{" + + "count=" + summary.getNumMessages() + ", " + + "length=" + summary.getTotalLength() + "}"); + response.set(summary); + } + + @Override + public void onError(final Throwable th) { + logFunction.apply("error in aggregating summaries " + th); + errorSignal.compareAndSet(false, true); + if (th instanceof StatusRuntimeException) { + final StatusRuntimeException exception = (StatusRuntimeException) th; + status.set(exception.getStatus()); + } else { + status.set(Status.UNKNOWN.withDescription(th.getMessage())); + } + resolveResponsePromise(); + } + + @Override + public void onCompleted() { + logFunction.apply("completed aggregating summaries"); + status.set(Status.OK); + resolveResponsePromise(); + } + + private void resolveResponsePromise() { + final CourierSummary courierSummary = response.get(); + logFunction.apply("client result: " + courierSummary); + responsePromise.complete(courierSummary); + } + }); + + for (int i = 0; i < messages.size(); i++) { + if (i >= cancelThreshold) { + logFunction.apply("cancelling sending messages"); + awaitChannelTermination.set(false); + throw new CancellationException("Cancel threshold reached: " + cancelThreshold); + } + if (errorSignal.get()) { + logFunction.apply("aborting sending messages as error was discovered"); + break; + } + final String requestId = ids.get(i); + final CourierRequest request = CourierRequest + .newBuilder() + .setId(requestId) + .setFrom(from) + .setMessage(messages.get(i)) + .setVariant(retrieveVariant(requestId)) + .build(); + logFunction.apply("sending message CourierRequest{" + + "id=" + request.getId() + ", " + + "from=" + request.getFrom() + ", " + + "message.length=" + request.getMessage().length() + "}"); + collector.onNext(request); + Thread.sleep(interMessageSleepMs); + } + logFunction.apply("completed sending packages"); + collector.onCompleted(); + + responsePromise.get(); + } catch (final StatusRuntimeException ex) { + logFunction.apply("RPC failed, status: " + ex.getStatus()); + status.set(ex.getStatus()); + } catch (final Exception ex) { + logFunction.apply("RPC failed, message: " + ex.getMessage()); + status.set(Status.UNKNOWN.withDescription(ex.getMessage())); + } + + return new RpcResult<>(response.get(), status.get()); + + } finally { + if (awaitChannelTermination.get()) { + shutdownChannel(channel); + } else { + channel.shutdownNow(); + } } } @@ -276,8 +466,20 @@ public static void main(final String... args) throws Exception { /* Access a service running on the local machine on port 8080 */ final String host = "localhost"; final int port = 8080; - final HashMap headers = new HashMap<>(); + final GrpcClient client = new GrpcClient(host, port); + + // runSendPackageSuccess(client); + // runSendPackageSendError(client); + // runCollectPackagesSuccess(client); + // runCollectPackagesSendError(client); + // runCollectPackagesExitPreResponse(client); + // runCollectPackagesExitPostResponse(client); + // runAggregatePackagesSuccess(client); + // runAggregatePackagesSendError(client); + // runAggregatePackagesExitPreResponse(client); + } + private static void runSendPackageSuccess(final GrpcClient client) { final String id = UUID.randomUUID().toString(); final String user = "Jim"; final StringBuilder sb = new StringBuilder(); @@ -287,16 +489,129 @@ public static void main(final String... args) throws Exception { sb.append("."); } } - final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString()); - logFunction.apply("sendPackage response = " + courierReply); - - final List messages = IntStream.range(0, 100).mapToObj(i -> "message-" + i).collect(Collectors.toList()); - final List courierSummaries = - collectPackages(host, port, headers, "id-", "User", messages, 10, true); - logFunction.apply("collectPackages response size = " + courierSummaries.size()); - if (!courierSummaries.isEmpty()) { - final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1); - logFunction.apply("collectPackages summary = " + courierSummary.toString()); + + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-send-package." + System.currentTimeMillis()); + final RpcResult rpcResult = client.sendPackage(headers, id, user, sb.toString()); + final CourierReply courierReply = rpcResult.result(); + client.logFunction.apply("sendPackage response = " + courierReply); + final Status status = rpcResult.status(); + client.logFunction.apply("sendPackage status = " + status); + } + + private static void runSendPackageSendError(final GrpcClient client) { + final String id = UUID.randomUUID().toString() + ".SEND_ERROR"; + final String user = "Jim"; + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 100_000; i++) { + sb.append("a"); + if (i % 1000 == 0) { + sb.append("."); + } } + + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-send-package." + System.currentTimeMillis()); + final RpcResult rpcResult = client.sendPackage(headers, id, user, sb.toString()); + final CourierReply courierReply = rpcResult.result(); + client.logFunction.apply("sendPackage response = " + courierReply); + final Status status = rpcResult.status(); + client.logFunction.apply("sendPackage status = " + status); + } + + private static void runCollectPackagesSuccess(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-collect-packages-success." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult> rpcResult = + client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1); + final List courierSummaries = rpcResult.result(); + client.logFunction.apply("collectPackages[success] summary = " + courierSummaries); + final Status status = rpcResult.status(); + client.logFunction.apply("collectPackages[success] status = " + status); + } + + private static void runCollectPackagesSendError(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-collect-packages-server-error." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".SEND_ERROR"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult> rpcResult = + client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1); + final List courierSummaries = rpcResult.result(); + client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + final Status status = rpcResult.status(); + client.logFunction.apply("collectPackages[cancel] status = " + status); + } + + private static void runCollectPackagesExitPreResponse(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-collect-packages-server-pre-cancel." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult> rpcResult = + client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1); + final List courierSummaries = rpcResult.result(); + client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + final Status status = rpcResult.status(); + client.logFunction.apply("collectPackages[cancel] status = " + status); + } + + private static void runCollectPackagesExitPostResponse(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-collect-packages-server-post-cancel." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".EXIT_POST_RESPONSE"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult> rpcResult = + client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1); + final List courierSummaries = rpcResult.result(); + client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + final Status status = rpcResult.status(); + client.logFunction.apply("collectPackages[cancel] status = " + status); + } + + private static void runAggregatePackagesSuccess(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-aggregate-packages-success." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult rpcResult = + client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1); + final CourierSummary courierSummary = rpcResult.result(); + client.logFunction.apply("aggregatePackages[success] summary = " + courierSummary); + final Status status = rpcResult.status(); + client.logFunction.apply("aggregatePackages[success] status = " + status); + } + + private static void runAggregatePackagesSendError(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-aggregate-packages-server-error." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".SEND_ERROR"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult rpcResult = + client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1); + final CourierSummary courierSummary = rpcResult.result(); + client.logFunction.apply("aggregatePackages[cancel] summary = " + courierSummary); + final Status status = rpcResult.status(); + client.logFunction.apply("aggregatePackages[cancel] status = " + status); + } + + private static void runAggregatePackagesExitPreResponse(final GrpcClient client) { + final HashMap headers = new HashMap<>(); + headers.put("x-cid", "cid-aggregate-packages-server-pre-cancel." + System.currentTimeMillis()); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final RpcResult rpcResult = + client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1); + final CourierSummary courierSummary = rpcResult.result(); + client.logFunction.apply("aggregatePackages[cancel] summary = " + courierSummary); + final Status status = rpcResult.status(); + client.logFunction.apply("aggregatePackages[cancel] status = " + status); } } diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java index 9b64f3205..bad02109a 100644 --- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java +++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java @@ -22,6 +22,9 @@ import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.io.IOException; @@ -34,6 +37,14 @@ public class GrpcServer { private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName()); + private static void sleep(final int durationMillis) { + try { + Thread.sleep(durationMillis); + } catch (final Exception ex) { + ex.printStackTrace(); + } + } + private Server server; void start(final int port) throws IOException { @@ -76,44 +87,89 @@ public void sendPackage(final CourierRequest request, final StreamObserver() { private long numMessages = 0; private long totalLength = 0; @Override - public void onNext(final CourierRequest courierRequest) { - LOGGER.info("Received CourierRequest id=" + courierRequest.getId()); + public void onNext(final CourierRequest request) { + LOGGER.info("Received CourierRequest id=" + request.getId()); numMessages += 1; - totalLength += courierRequest.getMessage().length(); + totalLength += request.getMessage().length(); + LOGGER.info("Summary of collected packages: numMessages=" + numMessages + + " with totalLength=" + totalLength); - final CourierSummary courierSummary = CourierSummary - .newBuilder() - .setNumMessages(numMessages) - .setTotalLength(totalLength) - .build(); - LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId()); - responseObserver.onNext(courierSummary); + if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) { + sleep(1000); + LOGGER.info("Exiting server abruptly"); + System.exit(1); + } else if (Variant.SEND_ERROR.equals(request.getVariant())) { + final StatusRuntimeException error = Status.CANCELLED + .withCause(new RuntimeException(request.getId())) + .withDescription("Cancelled by server") + .asRuntimeException(); + LOGGER.info("Sending cancelled by server error"); + responseObserver.onError(error); + } else { + final CourierSummary courierSummary = CourierSummary + .newBuilder() + .setNumMessages(numMessages) + .setTotalLength(totalLength) + .build(); + LOGGER.info("Sending CourierSummary for id=" + request.getId()); + responseObserver.onNext(courierSummary); + } + + if (Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) { + sleep(1000); + LOGGER.info("Exiting server abruptly"); + System.exit(1); + } } @Override - public void onError(final Throwable throwable) { - LOGGER.severe("Error in collecting packages" + throwable.getMessage()); - responseObserver.onError(throwable); + public void onError(final Throwable th) { + LOGGER.severe("Error in collecting packages: " + th.getMessage()); + responseObserver.onError(th); } @Override @@ -123,6 +179,63 @@ public void onCompleted() { } }; } + + @Override + public StreamObserver aggregatePackages(final StreamObserver responseObserver) { + + if (responseObserver instanceof ServerCallStreamObserver) { + ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> { + LOGGER.info("CancelHandler:collectPackages() was cancelled"); + }); + } + return new StreamObserver() { + + private long numMessages = 0; + private long totalLength = 0; + + @Override + public void onNext(final CourierRequest request) { + LOGGER.info("Received CourierRequest id=" + request.getId()); + + numMessages += 1; + totalLength += request.getMessage().length(); + LOGGER.info("Summary of collected packages: numMessages=" + numMessages + + " with totalLength=" + totalLength); + + if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) { + sleep(1000); + LOGGER.info("Exiting server abruptly"); + System.exit(1); + } else if (Variant.SEND_ERROR.equals(request.getVariant())) { + final StatusRuntimeException error = Status.CANCELLED + .withCause(new RuntimeException(request.getId())) + .withDescription("Cancelled by server") + .asRuntimeException(); + LOGGER.info("Sending cancelled by server error"); + responseObserver.onError(error); + } + } + + @Override + public void onError(final Throwable th) { + LOGGER.severe("Error in aggregating packages: " + th.getMessage()); + responseObserver.onError(th); + } + + @Override + public void onCompleted() { + LOGGER.severe("Completed aggregating packages"); + final CourierSummary courierSummary = CourierSummary + .newBuilder() + .setNumMessages(numMessages) + .setTotalLength(totalLength) + .build(); + LOGGER.info("Sending aggregated CourierSummary"); + responseObserver.onNext(courierSummary); + responseObserver.onCompleted(); + } + }; + } } private static class GrpcServerInterceptor implements ServerInterceptor { @@ -142,6 +255,7 @@ public ServerCall.Listener interceptCall( new ForwardingServerCall.SimpleForwardingServerCall(serverCall) { @Override public void sendHeaders(final Metadata responseHeaders) { + LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]"); logMetadata(requestMetadata, "response"); if (correlationId != null) { LOGGER.info("response linked to cid: " + correlationId); @@ -149,13 +263,51 @@ public void sendHeaders(final Metadata responseHeaders) { } super.sendHeaders(responseHeaders); } + + @Override + public void sendMessage(final RespT response) { + LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]"); + super.sendMessage(response); + } + + @Override + public void close(final Status status, final Metadata trailers) { + LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers); + super.close(status, trailers); + } }; - return serverCallHandler.startCall(wrapperCall, requestMetadata); + final ServerCall.Listener listener = serverCallHandler.startCall(wrapperCall, requestMetadata); + return new ServerCall.Listener() { + public void onMessage(final ReqT message) { + LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]"); + listener.onMessage(message); + } + + public void onHalfClose() { + LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]"); + listener.onHalfClose(); + } + + public void onCancel() { + LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]"); + listener.onCancel(); + } + + public void onComplete() { + LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]"); + listener.onComplete(); + } + + public void onReady() { + LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]"); + listener.onReady(); + } + }; } private void logMetadata(final Metadata metadata, final String label) { final Set metadataKeys = metadata.keys(); - LOGGER.info(label + " metadata keys = " + metadataKeys); + LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys); for (final String key : metadataKeys) { final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER)); LOGGER.info(label + " metadata " + key + " = " + value); diff --git a/containers/test-apps/courier/src/main/proto/courier.proto b/containers/test-apps/courier/src/main/proto/courier.proto index a9b960b75..dcb2585ac 100644 --- a/containers/test-apps/courier/src/main/proto/courier.proto +++ b/containers/test-apps/courier/src/main/proto/courier.proto @@ -22,29 +22,35 @@ option objc_class_prefix = "TSCP"; package courier; -// The greeting service definition. service Courier { // Sends a package. rpc SendPackage (CourierRequest) returns (CourierReply); - // Processes a stream of messages + // Processes a stream of messages, returns a stream of messages rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary); + // Processes a stream of messages, returns a unary message + rpc AggregatePackages (stream CourierRequest) returns (CourierSummary); +} + +enum Variant { + NORMAL = 0; + SEND_ERROR = 1; + EXIT_PRE_RESPONSE = 2; + EXIT_POST_RESPONSE = 3; } -// The request message containing the package's name. message CourierRequest { string id = 1; string from = 2; string message = 3; + Variant variant = 4; } -// The response message containing the package response. message CourierReply { string id = 1; string message = 2; string response = 3; } -// The response message containing the package response. message CourierSummary { int64 num_messages = 1; int64 total_length = 2; diff --git a/waiter/integration/waiter/grpc_test.clj b/waiter/integration/waiter/grpc_test.clj index 5965a5535..457599e77 100644 --- a/waiter/integration/waiter/grpc_test.clj +++ b/waiter/integration/waiter/grpc_test.clj @@ -18,10 +18,21 @@ [clojure.test :refer :all] [clojure.tools.logging :as log] [clojure.walk :as walk] + [waiter.correlation-id :as cid] [waiter.util.client-tools :refer :all]) - (:import (com.twosigma.waiter.courier GrpcClient) + (:import (com.twosigma.waiter.courier CourierReply CourierSummary GrpcClient) + (io.grpc Status) (java.util.function Function))) +(defn- initialize-grpc-client + "Initializes grpc client logging to specific correlation id" + [correlation-id host port] + (GrpcClient. host port (reify Function + (apply [_ message] + (cid/with-correlation-id + correlation-id + (log/info message)))))) + (defn- basic-grpc-service-parameters [] (let [courier-command (courier-server-command "${PORT0} ${PORT1}")] @@ -29,12 +40,15 @@ {:x-waiter-backend-proto "h2c" :x-waiter-cmd courier-command :x-waiter-cmd-type "shell" + :x-waiter-concurrency-level 32 :x-waiter-cpus 0.2 :x-waiter-debug true :x-waiter-grace-period-secs 120 :x-waiter-health-check-port-index 1 :x-waiter-health-check-proto "http" :x-waiter-idle-timeout-mins 10 + :x-waiter-max-instances 1 + :x-waiter-min-instances 1 :x-waiter-mem 512 :x-waiter-name (rand-name) :x-waiter-ports 2 @@ -45,65 +59,403 @@ [length] (apply str (take length (repeatedly #(char (+ (rand 26) 65)))))) -(deftest ^:parallel ^:integration-fast test-basic-grpc-server - (testing-using-waiter-url - (GrpcClient/setLogFunction (reify Function - (apply [_ message] - (log/info message)))) - (let [[host _] (str/split waiter-url #":") - h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url)) - request-headers (basic-grpc-service-parameters) - {:keys [cookies headers] :as response} (make-request waiter-url "/waiter-ping" :headers request-headers) - cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies)) - service-id (get headers "x-waiter-service-id")] - - (assert-response-status response 200) - (is service-id) +(defn ping-courier-service + [waiter-url request-headers] + (make-request waiter-url "/waiter-ping" :headers request-headers)) - (with-service-cleanup - service-id +(defn start-courier-instance + [waiter-url] + (let [[host _] (str/split waiter-url #":") + h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url)) + request-headers (basic-grpc-service-parameters) + {:keys [cookies headers] :as response} (ping-courier-service waiter-url request-headers) + cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies)) + service-id (get headers "x-waiter-service-id") + request-headers (assoc request-headers + "cookie" cookie-header + "x-waiter-timeout" "60000")] + (assert-response-status response 200) + (is service-id) + (log/info "ping cid:" (get headers "x-cid")) + (log/info "service-id:" service-id) + (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)] + (is (= "received-response" (:result ping-response)) (str ping-response)) + (is (= "OK" (some-> ping-response :body)) (str ping-response)) + (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check") (str ping-response)) + (assert-response-status ping-response 200) + (is (true? (:exists? service-state)) (str service-state)) + (is (= service-id (:service-id service-state)) (str service-state)) + (is (contains? #{"Running" "Starting"} (:status service-state)) (str service-state))) + (assert-service-on-all-routers waiter-url service-id cookies) - (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)] - (is (= "received-response" (:result ping-response)) (str ping-response)) - (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check")) - (assert-response-status ping-response 200) - (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state) - (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state)) - (str service-state))) + {:h2c-port h2c-port + :host host + :request-headers request-headers + :service-id service-id})) +(deftest ^:parallel ^:integration-fast test-grpc-unary-call + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id (testing "small request and reply" (log/info "starting small request and reply test") (let [id (rand-name "m") from (rand-name "f") content (rand-str 1000) - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) - reply (GrpcClient/sendPackage host h2c-port request-headers id from content)] - (is reply) + correlation-id (rand-name) + request-headers (assoc request-headers "x-cid" correlation-id) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.sendPackage grpc-client request-headers id from content) + ^CourierReply reply (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id} + reply (assoc :reply {:id (.getId reply) + :response (.getResponse reply)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (is reply assertion-message) (when reply - (is (= id (.getId reply))) - (is (= content (.getMessage reply))) - (is (= "received" (.getResponse reply)))))) - - (testing "streaming to and from server" - (doseq [max-message-length [100 1000 10000 100000]] - (let [messages (doall (repeatedly 200 #(rand-str (inc (rand-int max-message-length)))))] - - (testing (str "independent mode " max-message-length " messages") - (log/info "starting streaming to and from server - independent mode test") - (let [from (rand-name "f") - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) - summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 false)] - (is (= (count messages) (count summaries))) - (when (seq summaries) - (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) - (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))) - - (testing (str "lock-step mode " max-message-length " messages") - (log/info "starting streaming to and from server - lock-step mode test") - (let [from (rand-name "f") - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) - summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 true)] - (is (= (count messages) (count summaries))) - (when (seq summaries) - (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) - (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))))))))) + (is (= id (.getId reply)) assertion-message) + (is (= content (.getMessage reply)) assertion-message) + (is (= "received" (.getResponse reply)) assertion-message)) + (is status assertion-message) + (when status + (is (= "OK" (-> status .getCode str)) assertion-message) + (is (str/blank? (.getDescription status)) assertion-message)))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-cancellation + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (testing "small request and reply" + (log/info "starting small request and reply test") + (let [id (str (rand-name "m") ".SEND_ERROR") + from (rand-name "f") + content (rand-str 1000) + correlation-id (rand-name) + _ (log/info "cid:" correlation-id) + request-headers (assoc request-headers "x-cid" correlation-id) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.sendPackage grpc-client request-headers id from content) + ^CourierReply reply (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id} + reply (assoc :reply {:id (.getId reply) + :response (.getResponse reply)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (is (nil? reply) assertion-message) + (is status assertion-message) + (when status + (is (= "CANCELLED" (-> status .getCode str)) assertion-message) + (is (= "Cancelled by server" (.getDescription status)) assertion-message)))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-exit + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (testing "small request and reply" + (log/info "starting small request and reply test") + (let [id (str (rand-name "m") ".EXIT_PRE_RESPONSE") + from (rand-name "f") + content (rand-str 1000) + correlation-id (rand-name) + _ (log/info "cid:" correlation-id) + request-headers (assoc request-headers "x-cid" correlation-id) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.sendPackage grpc-client request-headers id from content) + ^CourierReply reply (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id} + reply (assoc :reply {:id (.getId reply) + :response (.getResponse reply)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (is (nil? reply) assertion-message) + (is status assertion-message) + (when status + (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message)))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-bidi-streaming-successful + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (doseq [max-message-length [1000 100000]] + (let [num-messages 120 + messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + + (testing (str "independent mode " max-message-length " messages completion") + (log/info "starting streaming to and from server - independent mode test") + (let [cancel-threshold (inc num-messages) + from (rand-name "f") + correlation-id (str (rand-name) "-independent-complete") + request-headers (assoc request-headers "x-cid" correlation-id) + ids (map #(str "id-inde-" %) (range num-messages)) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.collectPackages grpc-client request-headers ids from messages 10 false cancel-threshold) + summaries (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id + :summaries (map (fn [^CourierSummary s] + {:num-messages (.getNumMessages s) + :total-length (.getTotalLength s)}) + summaries)} + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (log/info correlation-id "collecting independent packages...") + (is (= (count messages) (count summaries)) assertion-message) + (when (seq summaries) + (is (= (range 1 (inc (count messages))) (map #(.getNumMessages ^CourierSummary %) summaries)) + assertion-message) + (is (= (reductions + (map count messages)) (map #(.getTotalLength ^CourierSummary %) summaries)) + assertion-message)) + (is status assertion-message) + (when status + (is (= "OK" (-> status .getCode str)) assertion-message) + (is (str/blank? (.getDescription status)) assertion-message)))) + + (testing (str "lock-step mode " max-message-length " messages completion") + (log/info "starting streaming to and from server - lock-step mode test") + (let [cancel-threshold (inc num-messages) + from (rand-name "f") + correlation-id (str (rand-name) "-lock-step-complete") + request-headers (assoc request-headers "x-cid" correlation-id) + ids (map #(str "id-lock-" %) (range num-messages)) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true cancel-threshold) + summaries (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id + :summaries (map (fn [^CourierSummary s] + {:num-messages (.getNumMessages s) + :total-length (.getTotalLength s)}) + summaries)} + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (log/info correlation-id "collecting lock-step packages...") + (is (= (count messages) (count summaries)) assertion-message) + (when (seq summaries) + (is (= (range 1 (inc (count messages))) (map #(.getNumMessages ^CourierSummary %) summaries)) + assertion-message) + (is (= (reductions + (map count messages)) (map #(.getTotalLength ^CourierSummary %) summaries)) + assertion-message)) + (is status assertion-message) + (when status + (is (= "OK" (-> status .getCode str)) assertion-message) + (is (str/blank? (.getDescription status)) assertion-message)))))))))) + +(deftest ^:parallel ^:integration-slow test-grpc-bidi-streaming-server-exit + (testing-using-waiter-url + (let [num-messages 120 + num-iterations 3] + (doseq [max-message-length [1000 100000]] + (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + (dotimes [iteration num-iterations] + (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]] + (testing (str "lock-step mode " max-message-length " messages " mode) + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [exit-index (* iteration (/ num-messages num-iterations)) + correlation-id (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length) + _ (log/info "collect packages cid" correlation-id "for" + {:iteration iteration :max-message-length max-message-length}) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages)) + request-headers (assoc request-headers "x-cid" correlation-id) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true (inc num-messages)) + message-summaries (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :exit-index exit-index + :iteration iteration + :service-id service-id + :summaries (map (fn [^CourierSummary s] + {:num-messages (.getNumMessages s) + :total-length (.getTotalLength s)}) + message-summaries)} + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)}))) + expected-summary-count (cond-> exit-index + (= "EXIT_POST_RESPONSE" mode) inc)] + (log/info "result" assertion-message) + (is (= expected-summary-count (count message-summaries)) assertion-message) + (when (seq message-summaries) + (is (= (range 1 (inc expected-summary-count)) + (map #(.getNumMessages %) message-summaries)) + assertion-message) + (is (= (reductions + (map count (take expected-summary-count messages))) + (map #(.getTotalLength %) message-summaries)) + assertion-message)) + (is status assertion-message) + (when status + (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message))))))))))))) + +(deftest ^:parallel ^:integration-slow test-grpc-bidi-streaming-server-cancellation + (testing-using-waiter-url + (let [num-messages 120 + num-iterations 3] + (doseq [max-message-length [1000 100000]] + (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + (dotimes [iteration num-iterations] + (testing (str "lock-step mode " max-message-length " messages server error") + (let [mode "SEND_ERROR" + {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [error-index (* iteration (/ num-messages num-iterations)) + correlation-id (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages)) + request-headers (assoc request-headers "x-cid" correlation-id) + _ (log/info "collect packages cid" correlation-id "for" + {:iteration iteration :max-message-length max-message-length}) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true (inc num-messages)) + message-summaries (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :error-index error-index + :iteration iteration + :service-id service-id + :summaries (map (fn [s] + {:num-messages (.getNumMessages s) + :total-length (.getTotalLength s)}) + message-summaries)} + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)}))) + expected-summary-count error-index] + (log/info "result" assertion-message) + (is (= expected-summary-count (count message-summaries)) assertion-message) + (when (seq message-summaries) + (is (= (range 1 (inc expected-summary-count)) + (map #(.getNumMessages %) message-summaries)) + assertion-message) + (is (= (reductions + (map count (take expected-summary-count messages))) + (map #(.getTotalLength %) message-summaries)) + assertion-message)) + (is status assertion-message) + (when status + (is (= "CANCELLED" (-> status .getCode str)) assertion-message) + (is (= "Cancelled by server" (.getDescription status)) assertion-message)))))))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-client-streaming-successful + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (doseq [max-message-length [1000 100000]] + (let [num-messages 120 + messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + + (testing (str max-message-length " messages completion") + (log/info "starting streaming to and from server - independent mode test") + (let [cancel-threshold (inc num-messages) + from (rand-name "f") + correlation-id (rand-name) + request-headers (assoc request-headers "x-cid" correlation-id) + ids (map #(str "id-" %) (range num-messages)) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.aggregatePackages grpc-client request-headers ids from messages 10 cancel-threshold) + ^CourierSummary summary (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :service-id service-id} + summary (assoc :summary {:num-messages (.getNumMessages summary) + :total-length (.getTotalLength summary)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (log/info correlation-id "aggregated packages...") + (is summary) + (when summary + (is (= (count messages) (.getNumMessages summary)) assertion-message) + (is (= (reduce + (map count messages)) (.getTotalLength summary)) assertion-message)) + (is status) + (when status + (is (= "OK" (-> status .getCode str)) assertion-message) + (is (str/blank? (.getDescription status)) assertion-message)))))))))) + +(deftest ^:parallel ^:integration-slow test-grpc-client-streaming-server-exit + (testing-using-waiter-url + (let [num-messages 120 + num-iterations 3] + (doseq [max-message-length [1000 100000]] + (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + (dotimes [iteration num-iterations] + (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]] + (testing (str "lock-step mode " max-message-length " messages " mode) + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [exit-index (* iteration (/ num-messages num-iterations)) + correlation-id (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length) + _ (log/info "aggregate packages cid" correlation-id "for" + {:iteration iteration :max-message-length max-message-length}) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages)) + request-headers (assoc request-headers "x-cid" correlation-id) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.aggregatePackages grpc-client request-headers ids from messages 1 (inc num-messages)) + ^CourierSummary message-summary (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :exit-index exit-index + :iteration iteration + :service-id service-id} + message-summary (assoc :summary {:num-messages (.getNumMessages message-summary) + :total-length (.getTotalLength message-summary)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (log/info "result" assertion-message) + (is (nil? message-summary) assertion-message) + (is status assertion-message) + (when status + (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message))))))))))))) + +(deftest ^:parallel ^:integration-slow test-grpc-client-streaming-server-cancellation + (testing-using-waiter-url + (let [num-messages 120 + num-iterations 3] + (doseq [max-message-length [1000 100000]] + (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] + (dotimes [iteration num-iterations] + (testing (str max-message-length " messages server error") + (let [mode "SEND_ERROR" + {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [error-index (* iteration (/ num-messages num-iterations)) + correlation-id (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages)) + request-headers (assoc request-headers "x-cid" correlation-id) + _ (log/info "aggregate packages cid" correlation-id "for" + {:iteration iteration :max-message-length max-message-length}) + grpc-client (initialize-grpc-client correlation-id host h2c-port) + rpc-result (.aggregatePackages grpc-client request-headers ids from messages 1 (inc num-messages)) + ^CourierSummary message-summary (.result rpc-result) + ^Status status (.status rpc-result) + assertion-message (str (cond-> {:correlation-id correlation-id + :error-index error-index + :iteration iteration + :service-id service-id} + message-summary (assoc :summary {:num-messages (.getNumMessages message-summary) + :total-length (.getTotalLength message-summary)}) + status (assoc :status {:code (-> status .getCode str) + :description (.getDescription status)})))] + (log/info "result" assertion-message) + (is (nil? message-summary) assertion-message) + (is status assertion-message) + (when status + (is (= "CANCELLED" (-> status .getCode str)) assertion-message) + (is (= "Cancelled by server" (.getDescription status)) assertion-message)))))))))))) diff --git a/waiter/project.clj b/waiter/project.clj index 79fd92a10..c9c40e898 100644 --- a/waiter/project.clj +++ b/waiter/project.clj @@ -32,7 +32,7 @@ :dependencies [[bidi "2.1.5" :exclusions [prismatic/schema ring/ring-core]] - [twosigma/courier "1.2.1" + [twosigma/courier "1.4.6" :exclusions [com.google.guava/guava io.grpc/grpc-core] :scope "test"] ;; avoids the following: @@ -41,7 +41,7 @@ [io.grpc/grpc-core "1.20.0" :exclusions [com.google.guava/guava] :scope "test"] - [twosigma/jet "0.7.10-20190701_144314-ga425faa" + [twosigma/jet "0.7.10-20190709_215542-ge7c9dbb" :exclusions [org.mortbay.jetty.alpn/alpn-boot]] [twosigma/clj-http "1.0.2-20180124_201819-gcdf23e5" :exclusions [commons-codec commons-io org.clojure/tools.reader potemkin slingshot]] diff --git a/waiter/src/waiter/process_request.clj b/waiter/src/waiter/process_request.clj index 819a17871..a67f7f7e4 100644 --- a/waiter/src/waiter/process_request.clj +++ b/waiter/src/waiter/process_request.clj @@ -469,6 +469,61 @@ (log/warn "unable to abort as request not found inside response!"))] (log/info "aborted backend request:" aborted)))) +(defn- forward-grpc-status-headers-in-trailers + "Adds logging for tracking response trailers for requests. + Since we always send some trailers, we need to repeat the grpc status headers in the trailers + to ensure client evaluates the request to the same grpc error. + Please see https://github.com/eclipse/jetty.project/issues/3829 for details." + [{:keys [headers trailers] :as response}] + (if trailers + (let [correlation-id (cid/get-correlation-id) + trailers-copy-ch (async/chan 1) + grpc-headers (select-keys headers ["grpc-message" "grpc-status"])] + (if (seq grpc-headers) + (do + (async/go + (cid/with-correlation-id + correlation-id + (try + (let [trailers-map (async/! trailers-copy-ch modified-trailers)) + (catch Throwable th + (log/error th "error in parsing response trailers"))) + (log/info "closing response trailers channel") + (async/close! trailers-copy-ch))) + (assoc response :trailers trailers-copy-ch)) + response)) + response)) + +(defn- handle-grpc-error-response + "Eagerly terminates grpc requests with error status headers. + We cannot rely on jetty to close the request for us in a timely manner, + please see https://github.com/eclipse/jetty.project/issues/3842 for details." + [{:keys [abort-ch body error-chan trailers] :as response} request backend-proto reservation-status-promise] + (let [request-headers (:headers request) + {:strs [grpc-status]} (:headers response) + proto-version (hu/backend-protocol->http-version backend-proto)] + (when (and (hu/grpc? request-headers proto-version) + (not (str/blank? grpc-status)) + (not= "0" grpc-status) + (au/chan? body)) + (log/info "eagerly closing response body as grpc status is" grpc-status) + ;; mark the request as successful, grpc failures are reported in the headers + (deliver reservation-status-promise :success) + (when abort-ch + ;; disallow aborting the request as we deem the request a success and will trigger normal + ;; request completion by closing the body channel + (async/close! abort-ch)) + ;; stop writing any content in the body from stream-http-response and trigger request completion + (async/close! body) + ;; do not expect any trailers either in the response + (async/close! trailers) + ;; eagerly close channel as request is deemed a success and avoid blocking in stream-http-response + (async/close! error-chan)) + response)) + (defn process-http-response "Processes a response resulting from a http request. It includes book-keeping for async requests and asynchronously streaming the content." @@ -502,6 +557,8 @@ location (post-process-async-request-response-fn service-id metric-group backend-proto instance (handler/make-auth-user-map request) reason-map instance-request-properties location query-string)) + (handle-grpc-error-response request backend-proto reservation-status-promise) + (forward-grpc-status-headers-in-trailers) (assoc :body resp-chan) (update-in [:headers] (fn update-response-headers [headers] (utils/filterm #(not= "connection" (str/lower-case (str (key %)))) headers)))))))