diff --git a/containers/test-apps/courier/pom.xml b/containers/test-apps/courier/pom.xml
index e6675bed8..7aa76f242 100644
--- a/containers/test-apps/courier/pom.xml
+++ b/containers/test-apps/courier/pom.xml
@@ -6,7 +6,7 @@
twosigma
courier
- 1.2.1
+ 1.4.6
courier
https://github.com/twosigma/waiter/tree/master/test-apps/courier
diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
index 2c3295750..330aebb3b 100644
--- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
+++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
@@ -36,10 +36,13 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -49,19 +52,57 @@ public class GrpcClient {
private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());
- private static Function logFunction = new Function() {
- @Override
- public Void apply(final String message) {
- LOGGER.info(message);
- return null;
+ public static final class RpcResult {
+ private final Result result;
+ private final Status status;
+
+ private RpcResult(final Result result, final Status status) {
+ this.result = result;
+ this.status = status;
+ }
+
+ public Result result() {
+ return result;
+ }
+
+ public Status status() {
+ return status;
+ }
+ }
+
+ private static Variant retrieveVariant(final String id) {
+ if (id.contains("SEND_ERROR")) {
+ return Variant.SEND_ERROR;
+ } else if (id.contains("EXIT_PRE_RESPONSE")) {
+ return Variant.EXIT_PRE_RESPONSE;
+ } else if (id.contains("EXIT_POST_RESPONSE")) {
+ return Variant.EXIT_POST_RESPONSE;
+ } else {
+ return Variant.NORMAL;
}
- };
+ }
+
+ private final Function logFunction;
+ private final String host;
+ private final int port;
- public static void setLogFunction(final Function logFunction) {
- GrpcClient.logFunction = logFunction;
+ public GrpcClient(final String host, final int port) {
+ this(host, port, new Function() {
+ @Override
+ public Void apply(final String message) {
+ LOGGER.info(message);
+ return null;
+ }
+ });
+ }
+
+ public GrpcClient(final String host, final int port, final Function logFunction) {
+ this.host = host;
+ this.port = port;
+ this.logFunction = logFunction;
}
- private static ManagedChannel initializeChannel(final String host, final int port) {
+ private ManagedChannel initializeChannel() {
logFunction.apply("initializing plaintext client at " + host + ":" + port);
return ManagedChannelBuilder
.forAddress(host, port)
@@ -69,17 +110,21 @@ private static ManagedChannel initializeChannel(final String host, final int por
.build();
}
- private static void shutdownChannel(final ManagedChannel channel) throws InterruptedException {
+ /**
+ * Shuts the channel down and waits up to one second for it to terminate.
+ * Failures are reported through {@code logFunction} instead of being thrown.
+ *
+ * @param channel the channel to shut down
+ */
+ private void shutdownChannel(final ManagedChannel channel) {
logFunction.apply("shutting down channel");
- channel.shutdown().awaitTermination(1, TimeUnit.SECONDS);
- if (channel.isShutdown()) {
- logFunction.apply("channel shutdown successfully");
- } else {
- logFunction.apply("channel shutdown timed out!");
+ try {
+ // awaitTermination reports whether the channel terminated within the timeout;
+ // isShutdown() is already true right after shutdown() and cannot detect a timeout
+ if (channel.shutdown().awaitTermination(1, TimeUnit.SECONDS)) {
+ logFunction.apply("channel shutdown successfully");
+ } else {
+ logFunction.apply("channel shutdown timed out!");
+ }
+ } catch (Exception ex) {
+ if (ex instanceof InterruptedException) {
+ // restore the interrupt flag so callers can still observe the interruption
+ Thread.currentThread().interrupt();
+ }
+ logFunction.apply("error in channel shutdown: " + ex.getMessage());
}
}
- private static Metadata createRequestHeadersMetadata(final Map headers) {
+ private Metadata createRequestHeadersMetadata(final Map headers) {
final Metadata headerMetadata = new Metadata();
for (Map.Entry entry : headers.entrySet()) {
final String key = entry.getKey();
@@ -89,7 +134,7 @@ private static Metadata createRequestHeadersMetadata(final Map h
return headerMetadata;
}
- private static Channel wrapResponseLogger(final ManagedChannel channel) {
+ private Channel wrapResponseLogger(final ManagedChannel channel) {
return ClientInterceptors.intercept(channel, new ClientInterceptor() {
@Override
public ClientCall interceptCall(final MethodDescriptor method,
@@ -120,13 +165,11 @@ public void onClose(final Status status, final Metadata trailers) {
});
}
- public static CourierReply sendPackage(final String host,
- final int port,
- final Map headers,
- final String id,
- final String from,
- final String message) throws InterruptedException {
- final ManagedChannel channel = initializeChannel(host, port);
+ public RpcResult sendPackage(final Map headers,
+ final String id,
+ final String from,
+ final String message) {
+ final ManagedChannel channel = initializeChannel();
try {
final Channel wrappedChannel = wrapResponseLogger(channel);
@@ -141,38 +184,52 @@ public static CourierReply sendPackage(final String host,
.setId(id)
.setFrom(from)
.setMessage(message)
+ .setVariant(retrieveVariant(id))
.build();
- final CourierReply response;
+
+ final AtomicReference status = new AtomicReference<>();
+ final AtomicReference response = new AtomicReference<>();
try {
- response = futureStub.sendPackage(request).get();
- } catch (final StatusRuntimeException e) {
- logFunction.apply("RPC failed, status: " + e.getStatus());
- return null;
- } catch (final Exception e) {
- logFunction.apply("RPC failed, message: " + e.getMessage());
- return null;
+ final CourierReply reply = futureStub.sendPackage(request).get();
+ status.set(Status.OK);
+ response.set(reply);
+ } catch (final StatusRuntimeException ex) {
+ final Status errorStatus = ex.getStatus();
+ logFunction.apply("RPC failed, status: " + errorStatus);
+ status.set(errorStatus);
+ } catch (final ExecutionException ex) {
+ final Status errorStatus = Status.fromThrowable(ex.getCause());
+ logFunction.apply("RPC execution failed: " + errorStatus);
+ status.set(errorStatus);
+ } catch (final Throwable th) {
+ logFunction.apply("RPC failed, message: " + th.getMessage());
+ status.set(Status.UNKNOWN.withDescription(th.getMessage()));
}
- logFunction.apply("received response CourierReply{" +
- "id=" + response.getId() + ", " +
- "response=" + response.getResponse() + ", " +
- "message.length=" + response.getMessage().length() + "}");
- logFunction.apply("messages equal = " + message.equals(response.getMessage()));
- return response;
+
+ if (response.get() != null) {
+ final CourierReply reply = response.get();
+ logFunction.apply("received response CourierReply{" +
+ "id=" + reply.getId() + ", " +
+ "response=" + reply.getResponse() + ", " +
+ "message.length=" + reply.getMessage().length() + "}");
+ logFunction.apply("messages equal = " + message.equals(reply.getMessage()));
+ }
+ return new RpcResult<>(response.get(), status.get());
} finally {
shutdownChannel(channel);
}
}
- public static List collectPackages(final String host,
- final int port,
- final Map headers,
- final String idPrefix,
- final String from,
- final List messages,
- final int interMessageSleepMs,
- final boolean lockStepMode) throws InterruptedException {
- final ManagedChannel channel = initializeChannel(host, port);
+ public RpcResult> collectPackages(final Map headers,
+ final List ids,
+ final String from,
+ final List messages,
+ final int interMessageSleepMs,
+ final boolean lockStepMode,
+ final int cancelThreshold) {
+ final ManagedChannel channel = initializeChannel();
+ final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true);
try {
final Semaphore lockStep = new Semaphore(1);
@@ -186,6 +243,9 @@ public static List collectPackages(final String host,
logFunction.apply("will try to send package from " + from + " ...");
+ final AtomicReference status = new AtomicReference<>();
+ final AtomicReference> response = new AtomicReference<>();
+
final CompletableFuture> responsePromise = new CompletableFuture<>();
try {
final StreamObserver collector =
@@ -206,19 +266,26 @@ public void onNext(final CourierSummary response) {
}
@Override
- public void onError(final Throwable throwable) {
- logFunction.apply("error in collecting summaries " + throwable);
+ public void onError(final Throwable th) {
+ logFunction.apply("error in collecting summaries " + th);
errorSignal.compareAndSet(false, true);
- resolveResponsePromise();
if (lockStepMode) {
logFunction.apply("releasing semaphore after receiving error");
lockStep.release();
}
+ if (th instanceof StatusRuntimeException) {
+ final StatusRuntimeException exception = (StatusRuntimeException) th;
+ status.set(exception.getStatus());
+ } else {
+ status.set(Status.UNKNOWN.withDescription(th.getMessage()));
+ }
+ resolveResponsePromise();
}
@Override
public void onCompleted() {
logFunction.apply("completed collecting summaries");
+ status.set(Status.OK);
resolveResponsePromise();
}
@@ -229,11 +296,16 @@ private void resolveResponsePromise() {
});
for (int i = 0; i < messages.size(); i++) {
+ if (i >= cancelThreshold) {
+ logFunction.apply("cancelling sending messages");
+ awaitChannelTermination.set(false);
+ throw new CancellationException("Cancel threshold reached: " + cancelThreshold);
+ }
if (errorSignal.get()) {
logFunction.apply("aborting sending messages as error was discovered");
break;
}
- final String requestId = idPrefix + i;
+ final String requestId = ids.get(i);
if (lockStepMode) {
logFunction.apply("acquiring semaphore before sending request " + requestId);
lockStep.acquire();
@@ -243,6 +315,7 @@ private void resolveResponsePromise() {
.setId(requestId)
.setFrom(from)
.setMessage(messages.get(i))
+ .setVariant(retrieveVariant(requestId))
.build();
logFunction.apply("sending message CourierRequest{" +
"id=" + request.getId() + ", " +
@@ -254,17 +327,134 @@ private void resolveResponsePromise() {
logFunction.apply("completed sending packages");
collector.onCompleted();
- return responsePromise.get();
- } catch (final StatusRuntimeException e) {
- logFunction.apply("RPC failed, status: " + e.getStatus());
- return null;
- } catch (final Exception e) {
- logFunction.apply("RPC failed, message: " + e.getMessage());
- return null;
+ response.set(responsePromise.get());
+ } catch (final StatusRuntimeException ex) {
+ logFunction.apply("RPC failed, status: " + ex.getStatus());
+ status.set(ex.getStatus());
+ } catch (final Exception ex) {
+ logFunction.apply("RPC failed, message: " + ex.getMessage());
+ status.set(Status.UNKNOWN.withDescription(ex.getMessage()));
}
+ return new RpcResult<>(response.get(), status.get());
+
} finally {
- shutdownChannel(channel);
+ if (awaitChannelTermination.get()) {
+ shutdownChannel(channel);
+ } else {
+ channel.shutdownNow();
+ }
+ }
+ }
+
+ /**
+ * Streams the given messages to the server's AggregatePackages RPC over a freshly
+ * initialized channel and waits for the call to finish.
+ *
+ * @param headers request headers attached to the call as metadata
+ * @param ids request ids; {@code ids.get(i)} is used for {@code messages.get(i)} and also
+ * selects the request variant via {@code retrieveVariant}
+ * @param from sender name placed in every request
+ * @param messages message payloads, sent in order
+ * @param interMessageSleepMs pause after each sent message, in milliseconds
+ * @param cancelThreshold message index at which sending is abandoned; a
+ * {@link CancellationException} is raised internally, recorded as an
+ * UNKNOWN status, and the channel is closed with {@code shutdownNow()}
+ * @return the last summary received (may be null) together with the final call status
+ */
+ public RpcResult aggregatePackages(final Map headers,
+ final List ids,
+ final String from,
+ final List messages,
+ final int interMessageSleepMs,
+ final int cancelThreshold) {
+ final ManagedChannel channel = initializeChannel();
+ final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true);
+
+ try {
+ final AtomicBoolean errorSignal = new AtomicBoolean(false);
+
+ final Channel wrappedChannel = wrapResponseLogger(channel);
+ final Metadata headerMetadata = createRequestHeadersMetadata(headers);
+
+ final CourierGrpc.CourierStub rawStub = CourierGrpc.newStub(wrappedChannel);
+ final CourierGrpc.CourierStub futureStub = MetadataUtils.attachHeaders(rawStub, headerMetadata);
+
+ logFunction.apply("will try to aggregate packages from " + from + " ...");
+
+ final AtomicReference status = new AtomicReference<>();
+ final AtomicReference response = new AtomicReference<>();
+
+ final CompletableFuture responsePromise = new CompletableFuture<>();
+ try {
+ final StreamObserver collector =
+ futureStub.aggregatePackages(new StreamObserver() {
+
+ @Override
+ public void onNext(final CourierSummary summary) {
+ logFunction.apply("received response CourierSummary{" +
+ "count=" + summary.getNumMessages() + ", " +
+ "length=" + summary.getTotalLength() + "}");
+ response.set(summary);
+ }
+
+ @Override
+ public void onError(final Throwable th) {
+ logFunction.apply("error in aggregating summaries " + th);
+ errorSignal.compareAndSet(false, true);
+ if (th instanceof StatusRuntimeException) {
+ final StatusRuntimeException exception = (StatusRuntimeException) th;
+ status.set(exception.getStatus());
+ } else {
+ status.set(Status.UNKNOWN.withDescription(th.getMessage()));
+ }
+ resolveResponsePromise();
+ }
+
+ @Override
+ public void onCompleted() {
+ logFunction.apply("completed aggregating summaries");
+ status.set(Status.OK);
+ resolveResponsePromise();
+ }
+
+ private void resolveResponsePromise() {
+ final CourierSummary courierSummary = response.get();
+ logFunction.apply("client result: " + courierSummary);
+ responsePromise.complete(courierSummary);
+ }
+ });
+
+ for (int i = 0; i < messages.size(); i++) {
+ if (i >= cancelThreshold) {
+ logFunction.apply("cancelling sending messages");
+ // skip the graceful shutdown below; the channel is torn down with shutdownNow()
+ awaitChannelTermination.set(false);
+ throw new CancellationException("Cancel threshold reached: " + cancelThreshold);
+ }
+ if (errorSignal.get()) {
+ logFunction.apply("aborting sending messages as error was discovered");
+ break;
+ }
+ final String requestId = ids.get(i);
+ final CourierRequest request = CourierRequest
+ .newBuilder()
+ .setId(requestId)
+ .setFrom(from)
+ .setMessage(messages.get(i))
+ .setVariant(retrieveVariant(requestId))
+ .build();
+ logFunction.apply("sending message CourierRequest{" +
+ "id=" + request.getId() + ", " +
+ "from=" + request.getFrom() + ", " +
+ "message.length=" + request.getMessage().length() + "}");
+ collector.onNext(request);
+ Thread.sleep(interMessageSleepMs);
+ }
+ logFunction.apply("completed sending packages");
+ collector.onCompleted();
+
+ // block until the response observer has seen onCompleted or onError
+ responsePromise.get();
+ } catch (final StatusRuntimeException ex) {
+ logFunction.apply("RPC failed, status: " + ex.getStatus());
+ status.set(ex.getStatus());
+ } catch (final Exception ex) {
+ if (ex instanceof InterruptedException) {
+ // restore the interrupt flag so the caller can still observe the interruption
+ Thread.currentThread().interrupt();
+ }
+ logFunction.apply("RPC failed, message: " + ex.getMessage());
+ status.set(Status.UNKNOWN.withDescription(ex.getMessage()));
+ }
+
+ return new RpcResult<>(response.get(), status.get());
+
+ } finally {
+ if (awaitChannelTermination.get()) {
+ shutdownChannel(channel);
+ } else {
+ channel.shutdownNow();
+ }
}
}
@@ -276,8 +466,20 @@ public static void main(final String... args) throws Exception {
/* Access a service running on the local machine on port 8080 */
final String host = "localhost";
final int port = 8080;
- final HashMap headers = new HashMap<>();
+ final GrpcClient client = new GrpcClient(host, port);
+
+ // runSendPackageSuccess(client);
+ // runSendPackageSendError(client);
+ // runCollectPackagesSuccess(client);
+ // runCollectPackagesSendError(client);
+ // runCollectPackagesExitPreResponse(client);
+ // runCollectPackagesExitPostResponse(client);
+ // runAggregatePackagesSuccess(client);
+ // runAggregatePackagesSendError(client);
+ // runAggregatePackagesExitPreResponse(client);
+ }
+ private static void runSendPackageSuccess(final GrpcClient client) {
final String id = UUID.randomUUID().toString();
final String user = "Jim";
final StringBuilder sb = new StringBuilder();
@@ -287,16 +489,129 @@ public static void main(final String... args) throws Exception {
sb.append(".");
}
}
- final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString());
- logFunction.apply("sendPackage response = " + courierReply);
-
- final List messages = IntStream.range(0, 100).mapToObj(i -> "message-" + i).collect(Collectors.toList());
- final List courierSummaries =
- collectPackages(host, port, headers, "id-", "User", messages, 10, true);
- logFunction.apply("collectPackages response size = " + courierSummaries.size());
- if (!courierSummaries.isEmpty()) {
- final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1);
- logFunction.apply("collectPackages summary = " + courierSummary.toString());
+
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-send-package." + System.currentTimeMillis());
+ final RpcResult rpcResult = client.sendPackage(headers, id, user, sb.toString());
+ final CourierReply courierReply = rpcResult.result();
+ client.logFunction.apply("sendPackage response = " + courierReply);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("sendPackage status = " + status);
+ }
+
+ private static void runSendPackageSendError(final GrpcClient client) {
+ final String id = UUID.randomUUID().toString() + ".SEND_ERROR";
+ final String user = "Jim";
+ final StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 100_000; i++) {
+ sb.append("a");
+ if (i % 1000 == 0) {
+ sb.append(".");
+ }
}
+
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-send-package." + System.currentTimeMillis());
+ final RpcResult rpcResult = client.sendPackage(headers, id, user, sb.toString());
+ final CourierReply courierReply = rpcResult.result();
+ client.logFunction.apply("sendPackage response = " + courierReply);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("sendPackage status = " + status);
+ }
+
+ private static void runCollectPackagesSuccess(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-collect-packages-success." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult> rpcResult =
+ client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1);
+ final List courierSummaries = rpcResult.result();
+ client.logFunction.apply("collectPackages[success] summary = " + courierSummaries);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("collectPackages[success] status = " + status);
+ }
+
+ private static void runCollectPackagesSendError(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-collect-packages-server-error." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".SEND_ERROR");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult> rpcResult =
+ client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1);
+ final List courierSummaries = rpcResult.result();
+ client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("collectPackages[cancel] status = " + status);
+ }
+
+ private static void runCollectPackagesExitPreResponse(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-collect-packages-server-pre-cancel." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult> rpcResult =
+ client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1);
+ final List courierSummaries = rpcResult.result();
+ client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("collectPackages[cancel] status = " + status);
+ }
+
+ private static void runCollectPackagesExitPostResponse(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-collect-packages-server-post-cancel." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".EXIT_POST_RESPONSE");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult> rpcResult =
+ client.collectPackages(headers, ids, "User", messages, 100, true, messages.size() + 1);
+ final List courierSummaries = rpcResult.result();
+ client.logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("collectPackages[cancel] status = " + status);
+ }
+
+ private static void runAggregatePackagesSuccess(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-aggregate-packages-success." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult rpcResult =
+ client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1);
+ final CourierSummary courierSummary = rpcResult.result();
+ client.logFunction.apply("aggregatePackages[success] summary = " + courierSummary);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("aggregatePackages[success] status = " + status);
+ }
+
+ private static void runAggregatePackagesSendError(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-aggregate-packages-server-error." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".SEND_ERROR");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult rpcResult =
+ client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1);
+ final CourierSummary courierSummary = rpcResult.result();
+ client.logFunction.apply("aggregatePackages[cancel] summary = " + courierSummary);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("aggregatePackages[cancel] status = " + status);
+ }
+
+ private static void runAggregatePackagesExitPreResponse(final GrpcClient client) {
+ final HashMap headers = new HashMap<>();
+ headers.put("x-cid", "cid-aggregate-packages-server-pre-cancel." + System.currentTimeMillis());
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final RpcResult rpcResult =
+ client.aggregatePackages(headers, ids, "User", messages, 100, messages.size() + 1);
+ final CourierSummary courierSummary = rpcResult.result();
+ client.logFunction.apply("aggregatePackages[cancel] summary = " + courierSummary);
+ final Status status = rpcResult.status();
+ client.logFunction.apply("aggregatePackages[cancel] status = " + status);
}
}
diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
index 9b64f3205..bad02109a 100644
--- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
+++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
@@ -22,6 +22,9 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
@@ -34,6 +37,14 @@ public class GrpcServer {
private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());
+ /**
+ * Pauses the current thread for the given duration; failures are logged instead of thrown.
+ *
+ * @param durationMillis how long to sleep, in milliseconds
+ */
+ private static void sleep(final int durationMillis) {
+ try {
+ Thread.sleep(durationMillis);
+ } catch (final InterruptedException ex) {
+ // restore the interrupt flag rather than silently swallowing the interruption
+ Thread.currentThread().interrupt();
+ LOGGER.warning("sleep interrupted: " + ex.getMessage());
+ } catch (final Exception ex) {
+ LOGGER.warning("sleep failed: " + ex.getMessage());
+ }
+ }
+
private Server server;
void start(final int port) throws IOException {
@@ -76,44 +87,89 @@ public void sendPackage(final CourierRequest request, final StreamObserver() {
private long numMessages = 0;
private long totalLength = 0;
@Override
- public void onNext(final CourierRequest courierRequest) {
- LOGGER.info("Received CourierRequest id=" + courierRequest.getId());
+ public void onNext(final CourierRequest request) {
+ LOGGER.info("Received CourierRequest id=" + request.getId());
numMessages += 1;
- totalLength += courierRequest.getMessage().length();
+ totalLength += request.getMessage().length();
+ LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
+ " with totalLength=" + totalLength);
- final CourierSummary courierSummary = CourierSummary
- .newBuilder()
- .setNumMessages(numMessages)
- .setTotalLength(totalLength)
- .build();
- LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
- responseObserver.onNext(courierSummary);
+ if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
+ sleep(1000);
+ LOGGER.info("Exiting server abruptly");
+ System.exit(1);
+ } else if (Variant.SEND_ERROR.equals(request.getVariant())) {
+ final StatusRuntimeException error = Status.CANCELLED
+ .withCause(new RuntimeException(request.getId()))
+ .withDescription("Cancelled by server")
+ .asRuntimeException();
+ LOGGER.info("Sending cancelled by server error");
+ responseObserver.onError(error);
+ } else {
+ final CourierSummary courierSummary = CourierSummary
+ .newBuilder()
+ .setNumMessages(numMessages)
+ .setTotalLength(totalLength)
+ .build();
+ LOGGER.info("Sending CourierSummary for id=" + request.getId());
+ responseObserver.onNext(courierSummary);
+ }
+
+ if (Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
+ sleep(1000);
+ LOGGER.info("Exiting server abruptly");
+ System.exit(1);
+ }
}
@Override
- public void onError(final Throwable throwable) {
- LOGGER.severe("Error in collecting packages" + throwable.getMessage());
- responseObserver.onError(throwable);
+ public void onError(final Throwable th) {
+ LOGGER.severe("Error in collecting packages: " + th.getMessage());
+ responseObserver.onError(th);
}
@Override
@@ -123,6 +179,63 @@ public void onCompleted() {
}
};
}
+
+ /**
+ * Handles the AggregatePackages RPC: counts the incoming requests and their total message
+ * length, and replies with a single CourierSummary once the client completes its stream.
+ * A request whose variant is EXIT_PRE_RESPONSE or EXIT_POST_RESPONSE makes the server exit
+ * after a one second pause; a SEND_ERROR request is answered with a CANCELLED error.
+ *
+ * @param responseObserver observer used to send the summary or an error to the client
+ * @return the observer that receives the client's request stream
+ */
+ @Override
+ public StreamObserver aggregatePackages(final StreamObserver responseObserver) {
+
+ if (responseObserver instanceof ServerCallStreamObserver) {
+ ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
+ LOGGER.info("CancelHandler:aggregatePackages() was cancelled");
+ });
+ }
+ return new StreamObserver() {
+
+ private long numMessages = 0;
+ private long totalLength = 0;
+
+ @Override
+ public void onNext(final CourierRequest request) {
+ LOGGER.info("Received CourierRequest id=" + request.getId());
+
+ numMessages += 1;
+ totalLength += request.getMessage().length();
+ LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
+ " with totalLength=" + totalLength);
+
+ // no per-request response exists in this RPC, so both exit variants behave the same
+ if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
+ sleep(1000);
+ LOGGER.info("Exiting server abruptly");
+ System.exit(1);
+ } else if (Variant.SEND_ERROR.equals(request.getVariant())) {
+ final StatusRuntimeException error = Status.CANCELLED
+ .withCause(new RuntimeException(request.getId()))
+ .withDescription("Cancelled by server")
+ .asRuntimeException();
+ LOGGER.info("Sending cancelled by server error");
+ responseObserver.onError(error);
+ }
+ }
+
+ @Override
+ public void onError(final Throwable th) {
+ LOGGER.severe("Error in aggregating packages: " + th.getMessage());
+ responseObserver.onError(th);
+ }
+
+ @Override
+ public void onCompleted() {
+ // normal completion is not an error, so log it at info level
+ LOGGER.info("Completed aggregating packages");
+ final CourierSummary courierSummary = CourierSummary
+ .newBuilder()
+ .setNumMessages(numMessages)
+ .setTotalLength(totalLength)
+ .build();
+ LOGGER.info("Sending aggregated CourierSummary");
+ responseObserver.onNext(courierSummary);
+ responseObserver.onCompleted();
+ }
+ };
+ }
}
private static class GrpcServerInterceptor implements ServerInterceptor {
@@ -142,6 +255,7 @@ public ServerCall.Listener interceptCall(
new ForwardingServerCall.SimpleForwardingServerCall(serverCall) {
@Override
public void sendHeaders(final Metadata responseHeaders) {
+ LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
logMetadata(requestMetadata, "response");
if (correlationId != null) {
LOGGER.info("response linked to cid: " + correlationId);
@@ -149,13 +263,51 @@ public void sendHeaders(final Metadata responseHeaders) {
}
super.sendHeaders(responseHeaders);
}
+
+ @Override
+ public void sendMessage(final RespT response) {
+ LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]");
+ super.sendMessage(response);
+ }
+
+ @Override
+ public void close(final Status status, final Metadata trailers) {
+ LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers);
+ super.close(status, trailers);
+ }
};
- return serverCallHandler.startCall(wrapperCall, requestMetadata);
+ final ServerCall.Listener listener = serverCallHandler.startCall(wrapperCall, requestMetadata);
+ return new ServerCall.Listener() {
+ public void onMessage(final ReqT message) {
+ LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
+ listener.onMessage(message);
+ }
+
+ public void onHalfClose() {
+ LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
+ listener.onHalfClose();
+ }
+
+ public void onCancel() {
+ LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
+ listener.onCancel();
+ }
+
+ public void onComplete() {
+ LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
+ listener.onComplete();
+ }
+
+ public void onReady() {
+ LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
+ listener.onReady();
+ }
+ };
}
private void logMetadata(final Metadata metadata, final String label) {
final Set metadataKeys = metadata.keys();
- LOGGER.info(label + " metadata keys = " + metadataKeys);
+ LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
for (final String key : metadataKeys) {
final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER));
LOGGER.info(label + " metadata " + key + " = " + value);
diff --git a/containers/test-apps/courier/src/main/proto/courier.proto b/containers/test-apps/courier/src/main/proto/courier.proto
index a9b960b75..dcb2585ac 100644
--- a/containers/test-apps/courier/src/main/proto/courier.proto
+++ b/containers/test-apps/courier/src/main/proto/courier.proto
@@ -22,29 +22,35 @@ option objc_class_prefix = "TSCP";
package courier;
-// The greeting service definition.
service Courier {
// Sends a package.
rpc SendPackage (CourierRequest) returns (CourierReply);
- // Processes a stream of messages
+ // Processes a stream of messages, returns a stream of messages
rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary);
+ // Processes a stream of messages, returns a unary message
+ rpc AggregatePackages (stream CourierRequest) returns (CourierSummary);
+}
+
+enum Variant {
+ NORMAL = 0;
+ SEND_ERROR = 1;
+ EXIT_PRE_RESPONSE = 2;
+ EXIT_POST_RESPONSE = 3;
}
-// The request message containing the package's name.
message CourierRequest {
string id = 1;
string from = 2;
string message = 3;
+ Variant variant = 4;
}
-// The response message containing the package response.
message CourierReply {
string id = 1;
string message = 2;
string response = 3;
}
-// The response message containing the package response.
message CourierSummary {
int64 num_messages = 1;
int64 total_length = 2;
diff --git a/waiter/integration/waiter/grpc_test.clj b/waiter/integration/waiter/grpc_test.clj
index 5965a5535..457599e77 100644
--- a/waiter/integration/waiter/grpc_test.clj
+++ b/waiter/integration/waiter/grpc_test.clj
@@ -18,10 +18,21 @@
[clojure.test :refer :all]
[clojure.tools.logging :as log]
[clojure.walk :as walk]
+ [waiter.correlation-id :as cid]
[waiter.util.client-tools :refer :all])
- (:import (com.twosigma.waiter.courier GrpcClient)
+ (:import (com.twosigma.waiter.courier CourierReply CourierSummary GrpcClient)
+ (io.grpc Status)
(java.util.function Function)))
+(defn- initialize-grpc-client
+  "Creates a GrpcClient for host and port whose log messages are written under the provided correlation id"
+  [correlation-id host port]
+  (let [log-fn (reify Function
+                 (apply [_ message]
+                   (cid/with-correlation-id correlation-id
+                     (log/info message))))]
+    (GrpcClient. host port log-fn)))
+
(defn- basic-grpc-service-parameters
[]
(let [courier-command (courier-server-command "${PORT0} ${PORT1}")]
@@ -29,12 +40,15 @@
{:x-waiter-backend-proto "h2c"
:x-waiter-cmd courier-command
:x-waiter-cmd-type "shell"
+ :x-waiter-concurrency-level 32
:x-waiter-cpus 0.2
:x-waiter-debug true
:x-waiter-grace-period-secs 120
:x-waiter-health-check-port-index 1
:x-waiter-health-check-proto "http"
:x-waiter-idle-timeout-mins 10
+ :x-waiter-max-instances 1
+ :x-waiter-min-instances 1
:x-waiter-mem 512
:x-waiter-name (rand-name)
:x-waiter-ports 2
@@ -45,65 +59,403 @@
[length]
(apply str (take length (repeatedly #(char (+ (rand 26) 65))))))
-(deftest ^:parallel ^:integration-fast test-basic-grpc-server
- (testing-using-waiter-url
- (GrpcClient/setLogFunction (reify Function
- (apply [_ message]
- (log/info message))))
- (let [[host _] (str/split waiter-url #":")
- h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url))
- request-headers (basic-grpc-service-parameters)
- {:keys [cookies headers] :as response} (make-request waiter-url "/waiter-ping" :headers request-headers)
- cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies))
- service-id (get headers "x-waiter-service-id")]
-
- (assert-response-status response 200)
- (is service-id)
+(defn ping-courier-service ; issues a request to the /waiter-ping endpoint with the provided headers and returns the make-request result
+ [waiter-url request-headers]
+ (make-request waiter-url "/waiter-ping" :headers request-headers))
- (with-service-cleanup
- service-id
+(defn start-courier-instance ; pings the courier service, asserts on the ping response and returns the details needed to issue grpc calls
+ [waiter-url]
+ (let [[host _] (str/split waiter-url #":")
+ h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url))
+ request-headers (basic-grpc-service-parameters)
+ {:keys [cookies headers] :as response} (ping-courier-service waiter-url request-headers)
+ cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies))
+ service-id (get headers "x-waiter-service-id")
+ request-headers (assoc request-headers ; returned headers carry the ping cookies and an x-waiter-timeout value
+ "cookie" cookie-header
+ "x-waiter-timeout" "60000")]
+ (assert-response-status response 200)
+ (is service-id)
+ (log/info "ping cid:" (get headers "x-cid"))
+ (log/info "service-id:" service-id)
+ (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)]
+ (is (= "received-response" (:result ping-response)) (str ping-response))
+ (is (= "OK" (some-> ping-response :body)) (str ping-response))
+ (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check") (str ping-response))
+ (assert-response-status ping-response 200)
+ (is (true? (:exists? service-state)) (str service-state))
+ (is (= service-id (:service-id service-state)) (str service-state))
+ (is (contains? #{"Running" "Starting"} (:status service-state)) (str service-state))) ; either status is accepted
+ (assert-service-on-all-routers waiter-url service-id cookies)
- (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)]
- (is (= "received-response" (:result ping-response)) (str ping-response))
- (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check"))
- (assert-response-status ping-response 200)
- (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state)
- (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state))
- (str service-state)))
+ {:h2c-port h2c-port ; result map destructured by the grpc tests
+ :host host
+ :request-headers request-headers
+ :service-id service-id}))
+(deftest ^:parallel ^:integration-fast test-grpc-unary-call ; unary sendPackage call expected to return the echoed package and an OK status
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
(testing "small request and reply"
(log/info "starting small request and reply test")
(let [id (rand-name "m")
from (rand-name "f")
content (rand-str 1000)
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
- reply (GrpcClient/sendPackage host h2c-port request-headers id from content)]
- (is reply)
+ correlation-id (rand-name)
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ grpc-client (initialize-grpc-client correlation-id host h2c-port) ; client logs under the request correlation id
+ rpc-result (.sendPackage grpc-client request-headers id from content)
+ ^CourierReply reply (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :service-id service-id}
+ reply (assoc :reply {:id (.getId reply)
+ :response (.getResponse reply)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (is reply assertion-message)
(when reply
- (is (= id (.getId reply)))
- (is (= content (.getMessage reply)))
- (is (= "received" (.getResponse reply))))))
-
- (testing "streaming to and from server"
- (doseq [max-message-length [100 1000 10000 100000]]
- (let [messages (doall (repeatedly 200 #(rand-str (inc (rand-int max-message-length)))))]
-
- (testing (str "independent mode " max-message-length " messages")
- (log/info "starting streaming to and from server - independent mode test")
- (let [from (rand-name "f")
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
- summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 false)]
- (is (= (count messages) (count summaries)))
- (when (seq summaries)
- (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
- (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))
-
- (testing (str "lock-step mode " max-message-length " messages")
- (log/info "starting streaming to and from server - lock-step mode test")
- (let [from (rand-name "f")
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
- summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 true)]
- (is (= (count messages) (count summaries)))
- (when (seq summaries)
- (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
- (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))))))))))
+ (is (= id (.getId reply)) assertion-message)
+ (is (= content (.getMessage reply)) assertion-message)
+ (is (= "received" (.getResponse reply)) assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (= "OK" (-> status .getCode str)) assertion-message)
+ (is (str/blank? (.getDescription status)) assertion-message))))))))
+
+(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-cancellation ; unary call with a SEND_ERROR id is expected to end with a CANCELLED status and no reply
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
+ (testing "small request and reply"
+ (log/info "starting small request and reply test")
+ (let [id (str (rand-name "m") ".SEND_ERROR") ; the client maps ids containing SEND_ERROR to Variant.SEND_ERROR
+ from (rand-name "f")
+ content (rand-str 1000)
+ correlation-id (rand-name)
+ _ (log/info "cid:" correlation-id)
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.sendPackage grpc-client request-headers id from content)
+ ^CourierReply reply (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :service-id service-id}
+ reply (assoc :reply {:id (.getId reply)
+ :response (.getResponse reply)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (is (nil? reply) assertion-message) ; only a status is expected, not a reply
+ (is status assertion-message)
+ (when status
+ (is (= "CANCELLED" (-> status .getCode str)) assertion-message)
+ (is (= "Cancelled by server" (.getDescription status)) assertion-message))))))))
+
+(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-exit ; unary call with an EXIT_PRE_RESPONSE id is expected to fail with UNAVAILABLE or INTERNAL and no reply
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
+ (testing "small request and reply"
+ (log/info "starting small request and reply test")
+ (let [id (str (rand-name "m") ".EXIT_PRE_RESPONSE") ; marker text in the id selects the request variant
+ from (rand-name "f")
+ content (rand-str 1000)
+ correlation-id (rand-name)
+ _ (log/info "cid:" correlation-id)
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.sendPackage grpc-client request-headers id from content)
+ ^CourierReply reply (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :service-id service-id}
+ reply (assoc :reply {:id (.getId reply)
+ :response (.getResponse reply)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (is (nil? reply) assertion-message)
+ (is status assertion-message)
+ (when status
+ (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message)))))))) ; either failure code is accepted
+
+(deftest ^:parallel ^:integration-fast test-grpc-bidi-streaming-successful ; collectPackages is expected to yield one running summary per message and an OK status
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
+ (doseq [max-message-length [1000 100000]]
+ (let [num-messages 120
+ messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] ; random strings of length 1..max-message-length
+
+ (testing (str "independent mode " max-message-length " messages completion")
+ (log/info "starting streaming to and from server - independent mode test")
+ (let [cancel-threshold (inc num-messages) ; NOTE(review): presumably a threshold above the message count means the client never cancels - confirm in GrpcClient
+ from (rand-name "f")
+ correlation-id (str (rand-name) "-independent-complete")
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ ids (map #(str "id-inde-" %) (range num-messages))
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.collectPackages grpc-client request-headers ids from messages 10 false cancel-threshold)
+ summaries (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :service-id service-id
+ :summaries (map (fn [^CourierSummary s]
+ {:num-messages (.getNumMessages s)
+ :total-length (.getTotalLength s)})
+ summaries)}
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (log/info correlation-id "collecting independent packages...")
+ (is (= (count messages) (count summaries)) assertion-message)
+ (when (seq summaries)
+ (is (= (range 1 (inc (count messages))) (map #(.getNumMessages ^CourierSummary %) summaries)) ; message counts run 1..n
+ assertion-message)
+ (is (= (reductions + (map count messages)) (map #(.getTotalLength ^CourierSummary %) summaries)) ; lengths are running totals
+ assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (= "OK" (-> status .getCode str)) assertion-message)
+ (is (str/blank? (.getDescription status)) assertion-message))))
+
+ (testing (str "lock-step mode " max-message-length " messages completion")
+ (log/info "starting streaming to and from server - lock-step mode test")
+ (let [cancel-threshold (inc num-messages)
+ from (rand-name "f")
+ correlation-id (str (rand-name) "-lock-step-complete")
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ ids (map #(str "id-lock-" %) (range num-messages))
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true cancel-threshold) ; same call as above with the lock-step flag set to true
+ summaries (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id
+ :service-id service-id
+ :summaries (map (fn [^CourierSummary s]
+ {:num-messages (.getNumMessages s)
+ :total-length (.getTotalLength s)})
+ summaries)}
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (log/info correlation-id "collecting lock-step packages...")
+ (is (= (count messages) (count summaries)) assertion-message)
+ (when (seq summaries)
+ (is (= (range 1 (inc (count messages))) (map #(.getNumMessages ^CourierSummary %) summaries))
+ assertion-message)
+ (is (= (reductions + (map count messages)) (map #(.getTotalLength ^CourierSummary %) summaries))
+ assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (= "OK" (-> status .getCode str)) assertion-message)
+ (is (str/blank? (.getDescription status)) assertion-message))))))))))
+
+(deftest ^:parallel ^:integration-slow test-grpc-bidi-streaming-server-exit ; bidi stream where one message id carries an EXIT_* marker; expects the summaries sent before the exit and a failure status
+ (testing-using-waiter-url
+ (let [num-messages 120
+ num-iterations 3]
+ (doseq [max-message-length [1000 100000]]
+ (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
+ (dotimes [iteration num-iterations]
+ (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]]
+ (testing (str "lock-step mode " max-message-length " messages " mode)
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] ; a fresh service is started for every scenario
+ (with-service-cleanup
+ service-id
+ (let [exit-index (* iteration (/ num-messages num-iterations)) ; index of the message that carries the marker: 0, 40, 80
+ correlation-id (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length)
+ _ (log/info "collect packages cid" correlation-id "for"
+ {:iteration iteration :max-message-length max-message-length})
+ from (rand-name "f")
+ ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages)) ; only the id at exit-index gets the mode suffix
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true (inc num-messages))
+ message-summaries (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id
+ :exit-index exit-index
+ :iteration iteration
+ :service-id service-id
+ :summaries (map (fn [^CourierSummary s]
+ {:num-messages (.getNumMessages s)
+ :total-length (.getTotalLength s)})
+ message-summaries)}
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))
+ expected-summary-count (cond-> exit-index ; one extra summary is expected for EXIT_POST_RESPONSE
+ (= "EXIT_POST_RESPONSE" mode) inc)]
+ (log/info "result" assertion-message)
+ (is (= expected-summary-count (count message-summaries)) assertion-message)
+ (when (seq message-summaries)
+ (is (= (range 1 (inc expected-summary-count))
+ (map #(.getNumMessages ^CourierSummary %) message-summaries)) ; type hint avoids a reflective call, matching the successful bidi test
+ assertion-message)
+ (is (= (reductions + (map count (take expected-summary-count messages)))
+ (map #(.getTotalLength ^CourierSummary %) message-summaries))
+ assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message)))))))))))))
+
+(deftest ^:parallel ^:integration-slow test-grpc-bidi-streaming-server-cancellation ; bidi stream where one message id carries SEND_ERROR; expects the summaries sent before the error and a CANCELLED status
+ (testing-using-waiter-url
+ (let [num-messages 120
+ num-iterations 3]
+ (doseq [max-message-length [1000 100000]]
+ (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
+ (dotimes [iteration num-iterations]
+ (testing (str "lock-step mode " max-message-length " messages server error")
+ (let [mode "SEND_ERROR"
+ {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] ; a fresh service is started for every scenario
+ (with-service-cleanup
+ service-id
+ (let [error-index (* iteration (/ num-messages num-iterations)) ; index of the message that carries the marker: 0, 40, 80
+ correlation-id (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length)
+ from (rand-name "f")
+ ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages)) ; only the id at error-index gets the mode suffix
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ _ (log/info "collect packages cid" correlation-id "for"
+ {:iteration iteration :max-message-length max-message-length})
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.collectPackages grpc-client request-headers ids from messages 1 true (inc num-messages))
+ message-summaries (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id
+ :error-index error-index
+ :iteration iteration
+ :service-id service-id
+ :summaries (map (fn [^CourierSummary s] ; type hint avoids reflective calls, matching the sibling bidi tests
+ {:num-messages (.getNumMessages s)
+ :total-length (.getTotalLength s)})
+ message-summaries)}
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))
+ expected-summary-count error-index] ; no summary is expected for the message that triggers the error
+ (log/info "result" assertion-message)
+ (is (= expected-summary-count (count message-summaries)) assertion-message)
+ (when (seq message-summaries)
+ (is (= (range 1 (inc expected-summary-count))
+ (map #(.getNumMessages ^CourierSummary %) message-summaries))
+ assertion-message)
+ (is (= (reductions + (map count (take expected-summary-count messages)))
+ (map #(.getTotalLength ^CourierSummary %) message-summaries))
+ assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (= "CANCELLED" (-> status .getCode str)) assertion-message)
+ (is (= "Cancelled by server" (.getDescription status)) assertion-message))))))))))))
+
+(deftest ^:parallel ^:integration-fast test-grpc-client-streaming-successful ; aggregatePackages is expected to return a single summary covering all messages and an OK status
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
+ (doseq [max-message-length [1000 100000]]
+ (let [num-messages 120
+ messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
+
+ (testing (str max-message-length " messages completion")
+ (log/info "starting client streaming to server - aggregate packages test")
+ (let [cancel-threshold (inc num-messages)
+ from (rand-name "f")
+ correlation-id (rand-name)
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ ids (map #(str "id-" %) (range num-messages))
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.aggregatePackages grpc-client request-headers ids from messages 10 cancel-threshold)
+ ^CourierSummary summary (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :service-id service-id}
+ summary (assoc :summary {:num-messages (.getNumMessages summary)
+ :total-length (.getTotalLength summary)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (log/info correlation-id "aggregated packages...")
+ (is summary assertion-message) ; failure output now includes the correlation and service ids
+ (when summary
+ (is (= (count messages) (.getNumMessages summary)) assertion-message)
+ (is (= (reduce + (map count messages)) (.getTotalLength summary)) assertion-message))
+ (is status assertion-message)
+ (when status
+ (is (= "OK" (-> status .getCode str)) assertion-message)
+ (is (str/blank? (.getDescription status)) assertion-message))))))))))
+
+(deftest ^:parallel ^:integration-slow test-grpc-client-streaming-server-exit ; client stream where one message id carries an EXIT_* marker; expects no summary and a failure status
+ (testing-using-waiter-url
+ (let [num-messages 120
+ num-iterations 3]
+ (doseq [max-message-length [1000 100000]]
+ (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
+ (dotimes [iteration num-iterations]
+ (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]]
+ (testing (str max-message-length " messages " mode) ; aggregatePackages takes no lock-step flag, so the label omits it like the sibling client-streaming tests
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] ; a fresh service is started for every scenario
+ (with-service-cleanup
+ service-id
+ (let [exit-index (* iteration (/ num-messages num-iterations)) ; index of the message that carries the marker: 0, 40, 80
+ correlation-id (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length)
+ _ (log/info "aggregate packages cid" correlation-id "for"
+ {:iteration iteration :max-message-length max-message-length})
+ from (rand-name "f")
+ ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages)) ; only the id at exit-index gets the mode suffix
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.aggregatePackages grpc-client request-headers ids from messages 1 (inc num-messages))
+ ^CourierSummary message-summary (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id
+ :exit-index exit-index
+ :iteration iteration
+ :service-id service-id}
+ message-summary (assoc :summary {:num-messages (.getNumMessages message-summary)
+ :total-length (.getTotalLength message-summary)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (log/info "result" assertion-message)
+ (is (nil? message-summary) assertion-message)
+ (is status assertion-message)
+ (when status
+ (is (contains? #{"UNAVAILABLE" "INTERNAL"} (-> status .getCode str)) assertion-message)))))))))))))
+
+(deftest ^:parallel ^:integration-slow test-grpc-client-streaming-server-cancellation ; client stream where one message id carries SEND_ERROR; expects no summary and a CANCELLED status
+ (testing-using-waiter-url
+ (let [num-messages 120
+ num-iterations 3]
+ (doseq [max-message-length [1000 100000]]
+ (let [messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
+ (dotimes [iteration num-iterations]
+ (testing (str max-message-length " messages server error")
+ (let [mode "SEND_ERROR"
+ {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] ; a fresh service is started for every scenario
+ (with-service-cleanup
+ service-id
+ (let [error-index (* iteration (/ num-messages num-iterations)) ; index of the message that carries the marker: 0, 40, 80
+ correlation-id (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length)
+ from (rand-name "f")
+ ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages)) ; only the id at error-index gets the mode suffix
+ request-headers (assoc request-headers "x-cid" correlation-id)
+ _ (log/info "aggregate packages cid" correlation-id "for"
+ {:iteration iteration :max-message-length max-message-length})
+ grpc-client (initialize-grpc-client correlation-id host h2c-port)
+ rpc-result (.aggregatePackages grpc-client request-headers ids from messages 1 (inc num-messages))
+ ^CourierSummary message-summary (.result rpc-result)
+ ^Status status (.status rpc-result)
+ assertion-message (str (cond-> {:correlation-id correlation-id ; context attached to every assertion below
+ :error-index error-index
+ :iteration iteration
+ :service-id service-id}
+ message-summary (assoc :summary {:num-messages (.getNumMessages message-summary)
+ :total-length (.getTotalLength message-summary)})
+ status (assoc :status {:code (-> status .getCode str)
+ :description (.getDescription status)})))]
+ (log/info "result" assertion-message)
+ (is (nil? message-summary) assertion-message)
+ (is status assertion-message)
+ (when status
+ (is (= "CANCELLED" (-> status .getCode str)) assertion-message)
+ (is (= "Cancelled by server" (.getDescription status)) assertion-message))))))))))))
diff --git a/waiter/project.clj b/waiter/project.clj
index 79fd92a10..c9c40e898 100644
--- a/waiter/project.clj
+++ b/waiter/project.clj
@@ -32,7 +32,7 @@
:dependencies [[bidi "2.1.5"
:exclusions [prismatic/schema ring/ring-core]]
- [twosigma/courier "1.2.1"
+ [twosigma/courier "1.4.6"
:exclusions [com.google.guava/guava io.grpc/grpc-core]
:scope "test"]
;; avoids the following:
@@ -41,7 +41,7 @@
[io.grpc/grpc-core "1.20.0"
:exclusions [com.google.guava/guava]
:scope "test"]
- [twosigma/jet "0.7.10-20190701_144314-ga425faa"
+ [twosigma/jet "0.7.10-20190709_215542-ge7c9dbb"
:exclusions [org.mortbay.jetty.alpn/alpn-boot]]
[twosigma/clj-http "1.0.2-20180124_201819-gcdf23e5"
:exclusions [commons-codec commons-io org.clojure/tools.reader potemkin slingshot]]
diff --git a/waiter/src/waiter/process_request.clj b/waiter/src/waiter/process_request.clj
index 819a17871..a67f7f7e4 100644
--- a/waiter/src/waiter/process_request.clj
+++ b/waiter/src/waiter/process_request.clj
@@ -469,6 +469,61 @@
(log/warn "unable to abort as request not found inside response!"))]
(log/info "aborted backend request:" aborted))))
+(defn- forward-grpc-status-headers-in-trailers
+  "Adds logging for tracking response trailers for requests.
+   Since we always send some trailers, we need to repeat the grpc status headers in the trailers
+   to ensure client evaluates the request to the same grpc error.
+   Please see https://github.com/eclipse/jetty.project/issues/3829 for details.
+   Returns the response unchanged when it has no trailers channel or no grpc status headers;
+   otherwise returns it with :trailers replaced by a channel that receives the merged trailers map."
+  [{:keys [headers trailers] :as response}]
+  (if trailers
+    (let [correlation-id (cid/get-correlation-id) ; captured here so the go block logs under the request's cid
+          trailers-copy-ch (async/chan 1)
+          grpc-headers (select-keys headers ["grpc-message" "grpc-status"])]
+      (if (seq grpc-headers)
+        (do
+          (async/go
+            (cid/with-correlation-id
+              correlation-id
+              (try
+                (let [trailers-map (async/<! trailers) ; nil when the trailers channel closes without a value
+                      modified-trailers (merge grpc-headers trailers-map)] ; backend trailer values win over the copied headers
+                  (log/info "response trailers:" trailers-map)
+                  (async/>! trailers-copy-ch modified-trailers))
+                (catch Throwable th
+                  (log/error th "error in parsing response trailers")))
+              (log/info "closing response trailers channel")
+              (async/close! trailers-copy-ch)))
+          (assoc response :trailers trailers-copy-ch))
+        response))
+    response))
+
+(defn- handle-grpc-error-response
+  "Eagerly terminates grpc requests with error status headers.
+   We cannot rely on jetty to close the request for us in a timely manner,
+   please see https://github.com/eclipse/jetty.project/issues/3842 for details.
+   Always returns the response; the channels are closed only when the request is a grpc request,
+   the response carries a non-blank grpc-status other than \"0\" and the body is a channel."
+  [{:keys [abort-ch body error-chan trailers] :as response} request backend-proto reservation-status-promise]
+  (let [request-headers (:headers request)
+        {:strs [grpc-status]} (:headers response)
+        proto-version (hu/backend-protocol->http-version backend-proto)]
+    (when (and (hu/grpc? request-headers proto-version)
+               (not (str/blank? grpc-status))
+               (not= "0" grpc-status)
+               (au/chan? body))
+      (log/info "eagerly closing response body as grpc status is" grpc-status)
+      ;; mark the request as successful, grpc failures are reported in the headers
+      (deliver reservation-status-promise :success)
+      (when abort-ch
+        ;; disallow aborting the request as we deem the request a success and will trigger normal
+        ;; request completion by closing the body channel
+        (async/close! abort-ch))
+      ;; stop writing any content in the body from stream-http-response and trigger request completion
+      (async/close! body)
+      ;; do not expect any trailers either in the response; trailers may be absent, so guard against nil
+      (when trailers (async/close! trailers))
+      ;; eagerly close channel as request is deemed a success and avoid blocking in stream-http-response
+      (async/close! error-chan))
+    response))
+
(defn process-http-response
"Processes a response resulting from a http request.
It includes book-keeping for async requests and asynchronously streaming the content."
@@ -502,6 +557,8 @@
location (post-process-async-request-response-fn
service-id metric-group backend-proto instance (handler/make-auth-user-map request)
reason-map instance-request-properties location query-string))
+ (handle-grpc-error-response request backend-proto reservation-status-promise)
+ (forward-grpc-status-headers-in-trailers)
(assoc :body resp-chan)
(update-in [:headers] (fn update-response-headers [headers]
(utils/filterm #(not= "connection" (str/lower-case (str (key %)))) headers)))))))