diff --git a/containers/test-apps/courier/pom.xml b/containers/test-apps/courier/pom.xml
index e6675bed8..bd507a0c3 100644
--- a/containers/test-apps/courier/pom.xml
+++ b/containers/test-apps/courier/pom.xml
@@ -6,7 +6,7 @@
twosigma
courier
- 1.2.1
+ 1.3.0
courier
https://github.com/twosigma/waiter/tree/master/test-apps/courier
diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
index 2c3295750..16b8481c8 100644
--- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
+++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java
@@ -36,7 +36,9 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -148,6 +150,15 @@ public static CourierReply sendPackage(final String host,
} catch (final StatusRuntimeException e) {
logFunction.apply("RPC failed, status: " + e.getStatus());
return null;
+ } catch (ExecutionException e) {
+ final Status status = Status.fromThrowable(e.getCause());
+ logFunction.apply("RPC execution failed: " + status);
+ return CourierReply
+ .newBuilder()
+ .setId(status.getCode().toString())
+ .setMessage(status.getDescription())
+ .setResponse("error")
+ .build();
} catch (final Exception e) {
logFunction.apply("RPC failed, message: " + e.getMessage());
return null;
@@ -167,12 +178,14 @@ public static CourierReply sendPackage(final String host,
public static List collectPackages(final String host,
final int port,
final Map headers,
- final String idPrefix,
+ final List ids,
final String from,
final List messages,
final int interMessageSleepMs,
- final boolean lockStepMode) throws InterruptedException {
+ final boolean lockStepMode,
+ final int cancelThreshold) throws InterruptedException {
final ManagedChannel channel = initializeChannel(host, port);
+ final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true);
try {
final Semaphore lockStep = new Semaphore(1);
@@ -214,6 +227,16 @@ public void onError(final Throwable throwable) {
logFunction.apply("releasing semaphore after receiving error");
lockStep.release();
}
+ if (throwable instanceof StatusRuntimeException) {
+     // Record the terminal gRPC status as a zero-message summary so the caller can inspect it.
+     final StatusRuntimeException exception = (StatusRuntimeException) throwable;
+     // A Status description is optional; protobuf setters throw NullPointerException on null.
+     final String description = exception.getStatus().getDescription();
+     final CourierSummary response = CourierSummary
+         .newBuilder()
+         .setNumMessages(0)
+         .setStatusCode(exception.getStatus().getCode().name())
+         .setStatusDescription(description == null ? "" : description)
+         .build();
+     resultList.add(response);
+ }
}
@Override
@@ -229,11 +252,17 @@ private void resolveResponsePromise() {
});
for (int i = 0; i < messages.size(); i++) {
+ if (i >= cancelThreshold) {
+ logFunction.apply("cancelling sending messages");
+ awaitChannelTermination.set(false);
+ channel.shutdownNow();
+ throw new CancellationException("Cancel threshold reached: " + cancelThreshold);
+ }
if (errorSignal.get()) {
logFunction.apply("aborting sending messages as error was discovered");
break;
}
- final String requestId = idPrefix + i;
+ final String requestId = ids.get(i);
if (lockStepMode) {
logFunction.apply("acquiring semaphore before sending request " + requestId);
lockStep.acquire();
@@ -264,39 +293,105 @@ private void resolveResponsePromise() {
}
} finally {
- shutdownChannel(channel);
+ if (awaitChannelTermination.get()) {
+ shutdownChannel(channel);
+ }
}
}
+ public static List collectPackages(final String host,
+ final int port,
+ final Map headers,
+ final String idPrefix,
+ final String from,
+ final List messages,
+ final int interMessageSleepMs,
+ final boolean lockStepMode,
+ final int cancelThreshold) throws InterruptedException {
+
+ final List ids = new ArrayList<>(messages.size());
+ for (int i = 0; i < messages.size(); i++) {
+ ids.add(idPrefix + i);
+ }
+
+ return collectPackages(host, port, headers, ids, from, messages, interMessageSleepMs, lockStepMode, cancelThreshold);
+ }
+
/**
* Greet server. If provided, the first element of {@code args} is the name to use in the
* greeting.
*/
public static void main(final String... args) throws Exception {
/* Access a service running on the local machine on port 8080 */
+ // Manual test harness: each scenario below is switched on or off by editing its if-condition.
+ // The start time is appended to every x-cid header value set by the scenarios.
+ final long startTimeMillis = System.currentTimeMillis();
final String host = "localhost";
final int port = 8080;
final HashMap headers = new HashMap<>();
- final String id = UUID.randomUUID().toString();
- final String user = "Jim";
- final StringBuilder sb = new StringBuilder();
- for (int i = 0; i < 100_000; i++) {
- sb.append("a");
- if (i % 1000 == 0) {
- sb.append(".");
+ // Scenario: unary sendPackage call whose id carries the SEND_ERROR marker.
+ if (false) {
+ final String id = UUID.randomUUID().toString() + ".SEND_ERROR";
+ final String user = "Jim";
+ final StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 100_000; i++) {
+ sb.append("a");
+ if (i % 1000 == 0) {
+ sb.append(".");
+ }
+ }
+
+ headers.put("x-cid", "cid-send-package." + startTimeMillis);
+ final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString());
+ logFunction.apply("sendPackage response = " + courierReply);
+ }
+
+ // Scenario: streaming call with the cancel threshold equal to the message count.
+ if (false) {
+ headers.put("x-cid", "cid-collect-packages-complete." + startTimeMillis);
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final List courierSummaries =
+ collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size());
+ logFunction.apply("collectPackages response size = " + courierSummaries.size());
+ if (!courierSummaries.isEmpty()) {
+ final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1);
+ logFunction.apply("collectPackages[complete] summary = " + courierSummary.toString());
}
}
- final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString());
- logFunction.apply("sendPackage response = " + courierReply);
-
- final List messages = IntStream.range(0, 100).mapToObj(i -> "message-" + i).collect(Collectors.toList());
- final List courierSummaries =
- collectPackages(host, port, headers, "id-", "User", messages, 10, true);
- logFunction.apply("collectPackages response size = " + courierSummaries.size());
- if (!courierSummaries.isEmpty()) {
- final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1);
- logFunction.apply("collectPackages summary = " + courierSummary.toString());
+
+ // Scenario: streaming call with the cancel threshold set to half the message count.
+ if (false) {
+ headers.put("x-cid", "cid-collect-packages-cancel." + startTimeMillis);
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final List courierSummaries =
+ collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size() / 2);
+ logFunction.apply("collectPackages[cancel] summary = " + courierSummaries);
+ }
+
+ // Scenario: the sixth request id carries the EXIT_PRE_RESPONSE marker.
+ if (false) {
+ headers.put("x-cid", "cid-collect-packages-server-pre-cancel." + startTimeMillis);
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final List courierSummaries =
+ collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
+ // label matches the scenario (was the copy-pasted "[cancel]")
+ logFunction.apply("collectPackages[server-pre-cancel] summary = " + courierSummaries);
+ }
+
+ // Scenario: the sixth request id carries the EXIT_POST_RESPONSE marker.
+ if (false) {
+ headers.put("x-cid", "cid-collect-packages-server-post-cancel." + startTimeMillis);
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".EXIT_POST_RESPONSE");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final List courierSummaries =
+ collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
+ // label matches the scenario (was the copy-pasted "[cancel]")
+ logFunction.apply("collectPackages[server-post-cancel] summary = " + courierSummaries);
+ }
+
+ // Scenario: the sixth request id carries the SEND_ERROR marker.
+ if (true) {
+ headers.put("x-cid", "cid-collect-packages-server-error." + startTimeMillis);
+ final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList());
+ ids.set(5, ids.get(5) + ".SEND_ERROR");
+ final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList());
+ final List courierSummaries =
+ collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1);
+ // label matches the scenario (was the copy-pasted "[cancel]")
+ logFunction.apply("collectPackages[server-error] summary = " + courierSummaries);
}
}
}
diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
index 9b64f3205..111eec3c0 100644
--- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
+++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java
@@ -22,9 +22,18 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
@@ -33,7 +42,6 @@
public class GrpcServer {
private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());
-
private Server server;
void start(final int port) throws IOException {
@@ -76,19 +84,39 @@ public void sendPackage(final CourierRequest request, final StreamObserver() {
private long numMessages = 0;
@@ -100,19 +128,48 @@ public void onNext(final CourierRequest courierRequest) {
numMessages += 1;
totalLength += courierRequest.getMessage().length();
+ LOGGER.severe("Summary of collected packages: numMessages=" + numMessages +
+ " with totalLength=" + totalLength);
+
+ if (courierRequest.getId().contains("EXIT_PRE_RESPONSE")) {
+ sleep(1000);
+ LOGGER.info("Exiting server abruptly");
+ System.exit(1);
+ } else if (courierRequest.getId().contains("SEND_ERROR")) {
+ final StatusRuntimeException error = Status.CANCELLED
+ .withCause(new RuntimeException(courierRequest.getId()))
+ .withDescription("Cancelled by server")
+ .asRuntimeException();
+ LOGGER.info("Sending cancelled by server error");
+ responseObserver.onError(error);
+ } else {
+ final CourierSummary courierSummary = CourierSummary
+ .newBuilder()
+ .setNumMessages(numMessages)
+ .setTotalLength(totalLength)
+ .build();
+ LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
+ responseObserver.onNext(courierSummary);
+ }
- final CourierSummary courierSummary = CourierSummary
- .newBuilder()
- .setNumMessages(numMessages)
- .setTotalLength(totalLength)
- .build();
- LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
- responseObserver.onNext(courierSummary);
+ if (courierRequest.getId().contains("EXIT_POST_RESPONSE")) {
+ sleep(1000);
+ LOGGER.info("Exiting server abruptly");
+ System.exit(1);
+ }
+ }
+
+ /**
+  * Pauses the current thread for {@code durationMillis} milliseconds.
+  * When the thread is interrupted while sleeping, the interruption is logged and the
+  * thread's interrupt status is restored instead of being silently discarded.
+  */
+ private void sleep(final int durationMillis) {
+     try {
+         Thread.sleep(durationMillis);
+     } catch (final InterruptedException e) {
+         LOGGER.warning("Interrupted while sleeping: " + e.getMessage());
+         // Thread.sleep clears the flag when it throws; set it again for code further up the stack
+         Thread.currentThread().interrupt();
+     }
+ }
@Override
public void onError(final Throwable throwable) {
- LOGGER.severe("Error in collecting packages" + throwable.getMessage());
+ // Log the reported failure, then forward the same throwable to the response observer.
+ LOGGER.severe("Error in collecting packages: " + throwable.getMessage());
responseObserver.onError(throwable);
}
@@ -142,6 +199,7 @@ public ServerCall.Listener interceptCall(
new ForwardingServerCall.SimpleForwardingServerCall(serverCall) {
@Override
public void sendHeaders(final Metadata responseHeaders) {
+ LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
logMetadata(requestMetadata, "response");
if (correlationId != null) {
LOGGER.info("response linked to cid: " + correlationId);
@@ -149,13 +207,51 @@ public void sendHeaders(final Metadata responseHeaders) {
}
super.sendHeaders(responseHeaders);
}
+
+ /**
+  * Logs that a response message is being sent for the current correlation id and then
+  * delegates to the wrapped call.
+  */
+ @Override
+ public void sendMessage(final RespT response) {
+     final String logLine = "GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]";
+     LOGGER.info(logLine);
+     super.sendMessage(response);
+ }
+
+ /**
+  * Logs the closing status and trailers for the current correlation id and then
+  * delegates to the wrapped call.
+  */
+ @Override
+ public void close(final Status status, final Metadata trailers) {
+     final String logLine =
+         "GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers;
+     LOGGER.info(logLine);
+     super.close(status, trailers);
+ }
};
- return serverCallHandler.startCall(wrapperCall, requestMetadata);
+ final ServerCall.Listener listener = serverCallHandler.startCall(wrapperCall, requestMetadata);
+ // Wrap the listener returned by the handler: every callback is logged together with the
+ // correlation id and then delegated unchanged to the wrapped listener.
+ return new ServerCall.Listener() {
+     @Override
+     public void onMessage(final ReqT message) {
+         LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
+         listener.onMessage(message);
+     }
+
+     @Override
+     public void onHalfClose() {
+         LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
+         listener.onHalfClose();
+     }
+
+     @Override
+     public void onCancel() {
+         LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
+         listener.onCancel();
+     }
+
+     @Override
+     public void onComplete() {
+         LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
+         listener.onComplete();
+     }
+
+     @Override
+     public void onReady() {
+         LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
+         listener.onReady();
+     }
+ };
}
private void logMetadata(final Metadata metadata, final String label) {
final Set metadataKeys = metadata.keys();
- LOGGER.info(label + " metadata keys = " + metadataKeys);
+ LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
for (final String key : metadataKeys) {
final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER));
LOGGER.info(label + " metadata " + key + " = " + value);
diff --git a/containers/test-apps/courier/src/main/proto/courier.proto b/containers/test-apps/courier/src/main/proto/courier.proto
index a9b960b75..29617dbc3 100644
--- a/containers/test-apps/courier/src/main/proto/courier.proto
+++ b/containers/test-apps/courier/src/main/proto/courier.proto
@@ -22,7 +22,6 @@ option objc_class_prefix = "TSCP";
package courier;
-// The greeting service definition.
service Courier {
// Sends a package.
rpc SendPackage (CourierRequest) returns (CourierReply);
@@ -30,22 +29,21 @@ service Courier {
rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary);
}
-// The request message containing the package's name.
message CourierRequest {
string id = 1;
string from = 2;
string message = 3;
}
-// The response message containing the package response.
message CourierReply {
string id = 1;
string message = 2;
string response = 3;
}
-// The response message containing the package response.
message CourierSummary {
int64 num_messages = 1;
int64 total_length = 2;
+ // Name of the gRPC status code; the courier client fills this in when a stream ends with an error.
+ string status_code = 3;
+ // Description of that gRPC status; filled in together with status_code.
+ string status_description = 4;
}
diff --git a/waiter/integration/waiter/grpc_test.clj b/waiter/integration/waiter/grpc_test.clj
index 5965a5535..431a3013d 100644
--- a/waiter/integration/waiter/grpc_test.clj
+++ b/waiter/integration/waiter/grpc_test.clj
@@ -22,6 +22,11 @@
(:import (com.twosigma.waiter.courier GrpcClient)
(java.util.function Function)))
+;; initialize logging on the grpc client
+(GrpcClient/setLogFunction (reify Function
+ (apply [_ message]
+ (log/info message))))
+
(defn- basic-grpc-service-parameters
[]
(let [courier-command (courier-server-command "${PORT0} ${PORT1}")]
@@ -29,12 +34,15 @@
{:x-waiter-backend-proto "h2c"
:x-waiter-cmd courier-command
:x-waiter-cmd-type "shell"
+ :x-waiter-concurrency-level 32
:x-waiter-cpus 0.2
:x-waiter-debug true
:x-waiter-grace-period-secs 120
:x-waiter-health-check-port-index 1
:x-waiter-health-check-proto "http"
:x-waiter-idle-timeout-mins 10
+ :x-waiter-max-instances 1
+ :x-waiter-min-instances 1
:x-waiter-mem 512
:x-waiter-name (rand-name)
:x-waiter-ports 2
@@ -45,65 +53,221 @@
[length]
(apply str (take length (repeatedly #(char (+ (rand 26) 65))))))
-(deftest ^:parallel ^:integration-fast test-basic-grpc-server
- (testing-using-waiter-url
- (GrpcClient/setLogFunction (reify Function
- (apply [_ message]
- (log/info message))))
- (let [[host _] (str/split waiter-url #":")
- h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url))
- request-headers (basic-grpc-service-parameters)
- {:keys [cookies headers] :as response} (make-request waiter-url "/waiter-ping" :headers request-headers)
- cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies))
- service-id (get headers "x-waiter-service-id")]
+(defn ping-courier-service
+ "Sends a request to the /waiter-ping endpoint of waiter-url using the provided
+ request headers and returns the response from make-request."
+ [waiter-url request-headers]
+ (make-request waiter-url "/waiter-ping" :headers request-headers))
+
+(defn start-courier-instance
+ "Pings the courier service described by basic-grpc-service-parameters and asserts that
+ the ping succeeded, that the service reports as Running or Starting and that it is
+ known on all routers.
+ Returns a map with :cookies, :h2c-port, :host, :service-id and :request-headers, where
+ the request headers are extended with the cookie header built from the ping response
+ and an x-waiter-timeout of 60000."
+ [waiter-url]
+ (let [[host _] (str/split waiter-url #":")
+ h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url))
+ request-headers (basic-grpc-service-parameters)
+ {:keys [cookies headers] :as response} (ping-courier-service waiter-url request-headers)
+ cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies))
+ service-id (get headers "x-waiter-service-id")
+ request-headers (assoc request-headers
+ "cookie" cookie-header
+ "x-waiter-timeout" "60000")]
+ (assert-response-status response 200)
+ (is service-id)
+ (log/info "ping cid:" (get headers "x-cid"))
+ (log/info "service-id:" service-id)
+ (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)]
+ (is (= "received-response" (:result ping-response)) (str ping-response))
+ (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check"))
+ (assert-response-status ping-response 200)
+ (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state)
+ (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state))
+ (str service-state)))
+ (assert-service-on-all-routers waiter-url service-id cookies)
- (assert-response-status response 200)
- (is service-id)
+ {:cookies cookies
+ :h2c-port h2c-port
+ :host host
+ :request-headers request-headers
+ :service-id service-id}))
+(deftest ^:parallel ^:integration-fast test-grpc-unary-call
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
(with-service-cleanup
service-id
-
- (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)]
- (is (= "received-response" (:result ping-response)) (str ping-response))
- (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check"))
- (assert-response-status ping-response 200)
- (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state)
- (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state))
- (str service-state)))
-
(testing "small request and reply"
(log/info "starting small request and reply test")
(let [id (rand-name "m")
from (rand-name "f")
content (rand-str 1000)
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
+ request-headers (assoc request-headers "x-cid" (rand-name))
reply (GrpcClient/sendPackage host h2c-port request-headers id from content)]
(is reply)
(when reply
(is (= id (.getId reply)))
(is (= content (.getMessage reply)))
- (is (= "received" (.getResponse reply))))))
+ (is (= "received" (.getResponse reply))))))))))
+
+;; Verifies that a unary call answered with an error by the courier server is reported to
+;; the client as an \"error\" reply carrying the grpc status code and description.
+(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-error
+  (testing-using-waiter-url
+    (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+      (with-service-cleanup
+        service-id
+        (testing "unary call with server error"
+          (log/info "starting unary call with server error test")
+          ;; the id carries the SEND_ERROR marker; the assertions below expect the call to
+          ;; come back with the CANCELLED status and the 'Cancelled by server' description
+          (let [id (str (rand-name "m") ".SEND_ERROR")
+                from (rand-name "f")
+                content (rand-str 1000)
+                request-cid (rand-name)
+                _ (log/info "cid:" request-cid)
+                request-headers (assoc request-headers "x-cid" request-cid)
+                reply (GrpcClient/sendPackage host h2c-port request-headers id from content)]
+            (is reply)
+            (when reply
+              (is (= "CANCELLED" (.getId reply)))
+              (is (= "Cancelled by server" (.getMessage reply)))
+              (is (= "error" (.getResponse reply))))))))))
+
+
+(deftest ^:parallel ^:integration-fast test-grpc-streaming-successful
+ (testing-using-waiter-url
+ (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+ (with-service-cleanup
+ service-id
+ (doseq [max-message-length [1000 10000 100000]]
+ (let [num-messages 100
+ messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))]
- (testing "streaming to and from server"
- (doseq [max-message-length [100 1000 10000 100000]]
- (let [messages (doall (repeatedly 200 #(rand-str (inc (rand-int max-message-length)))))]
+ (testing (str "independent mode " max-message-length " messages completion")
+ (log/info "starting streaming to and from server - independent mode test")
+ (let [cancel-threshold (inc num-messages)
+ from (rand-name "f")
+ collect-cid (str (rand-name) "-independent-complete")
+ request-headers (assoc request-headers "x-cid" collect-cid)
+ summaries (GrpcClient/collectPackages
+ host h2c-port request-headers "m-" from messages 10 false cancel-threshold)]
+ (log/info collect-cid "collecting independent packages...")
+ (is (= (count messages) (count summaries)))
+ (when (seq summaries)
+ (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
+ (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))
- (testing (str "independent mode " max-message-length " messages")
- (log/info "starting streaming to and from server - independent mode test")
- (let [from (rand-name "f")
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
- summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 false)]
- (is (= (count messages) (count summaries)))
- (when (seq summaries)
- (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
- (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))
+ (testing (str "lock-step mode " max-message-length " messages completion")
+ (log/info "starting streaming to and from server - lock-step mode test")
+ (let [cancel-threshold (inc num-messages)
+ from (rand-name "f")
+ collect-cid (str (rand-name) "-lock-step-complete")
+ request-headers (assoc request-headers "x-cid" collect-cid)
+ summaries (GrpcClient/collectPackages
+ host h2c-port request-headers "m-" from messages 1 true cancel-threshold)]
+ (log/info collect-cid "collecting lock-step packages...")
+ (is (= (count messages) (count summaries)))
+ (when (seq summaries)
+ (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
+ (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))))))))
- (testing (str "lock-step mode " max-message-length " messages")
- (log/info "starting streaming to and from server - lock-step mode test")
- (let [from (rand-name "f")
- request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name))
- summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 true)]
- (is (= (count messages) (count summaries)))
- (when (seq summaries)
- (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries)))
- (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))))))))))
+;; Streams messages in lock-step mode to a fresh courier instance and marks one request id
+;; with an exit mode (EXIT_PRE_RESPONSE or EXIT_POST_RESPONSE). Asserts that the summaries
+;; received before the exit are intact and that the final entry is a status summary with
+;; zero messages and an UNAVAILABLE or INTERNAL status code.
+(deftest ^:parallel ^:integration-slow test-grpc-streaming-server-exit
+  (testing-using-waiter-url
+    (let [num-messages 100
+          num-iterations 5]
+      (dotimes [iteration num-iterations]
+        (doseq [max-message-length [1000 10000 100000]]
+          (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]]
+            ;; the label names the actual mode so a failure identifies the scenario
+            (testing (str "lock-step mode " max-message-length " messages " mode)
+              (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+                (with-service-cleanup
+                  service-id
+                  (let [exit-index (* iteration (/ num-messages num-iterations))
+                        collect-cid (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length)
+                        _ (log/info "collect packages cid" collect-cid "for"
+                                    {:iteration iteration :max-message-length max-message-length})
+                        from (rand-name "f")
+                        ;; only the request at exit-index carries the exit marker in its id
+                        ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages))
+                        messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))
+                        request-headers (assoc request-headers "x-cid" collect-cid)
+                        summaries (GrpcClient/collectPackages
+                                    host h2c-port request-headers ids from messages 1 true (inc num-messages))
+                        assertion-message (str {:exit-index exit-index
+                                                :collect-cid collect-cid
+                                                :iteration iteration
+                                                :service-id service-id
+                                                :summaries (map (fn [s]
+                                                                  {:num-messages (.getNumMessages s)
+                                                                   :status-code (.getStatusCode s)
+                                                                   :total-length (.getTotalLength s)})
+                                                                summaries)})
+                        ;; the last entry is expected to be the status summary, the rest are message summaries
+                        message-summaries (take (dec (count summaries)) summaries)
+                        status-summary (last summaries)
+                        ;; in post-response mode the marked request is still answered before the exit
+                        expected-summary-count (cond-> exit-index
+                                                 (= "EXIT_POST_RESPONSE" mode) inc)]
+                    (log/info "result" assertion-message)
+                    (is (= expected-summary-count (count message-summaries)) assertion-message)
+                    (when (seq message-summaries)
+                      (is (= (range 1 (inc expected-summary-count))
+                             (map #(.getNumMessages %) message-summaries))
+                          assertion-message)
+                      (is (= (reductions + (map count (take expected-summary-count messages)))
+                             (map #(.getTotalLength %) message-summaries))
+                          assertion-message))
+                    (is status-summary)
+                    (when status-summary
+                      (log/info "server exit summary" status-summary)
+                      (is (contains? #{"UNAVAILABLE" "INTERNAL"} (.getStatusCode status-summary)) assertion-message)
+                      (is (zero? (.getNumMessages status-summary)) assertion-message))))))))))))
+
+;; Streams messages in lock-step mode to a fresh courier instance and marks one request id
+;; with SEND_ERROR. Asserts that the summaries received before the error are intact and that
+;; the final entry is a status summary with zero messages describing the failure.
+(deftest ^:parallel ^:integration-slow test-grpc-streaming-server-cancellation
+  (testing-using-waiter-url
+    (let [num-messages 100
+          num-iterations 5]
+      (dotimes [iteration num-iterations]
+        (doseq [max-message-length [1000 10000 100000]]
+          (testing (str "lock-step mode " max-message-length " messages server error")
+            (let [mode "SEND_ERROR"
+                  {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)]
+              (with-service-cleanup
+                service-id
+                (let [error-index (* iteration (/ num-messages num-iterations))
+                      collect-cid (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length)
+                      ;; use the logger (not println), consistent with the server-exit test
+                      _ (log/info "collect packages cid" collect-cid "for"
+                                  {:iteration iteration :max-message-length max-message-length})
+                      from (rand-name "f")
+                      ;; only the request at error-index carries the error marker in its id
+                      ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages))
+                      messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))
+                      request-headers (assoc request-headers "x-cid" collect-cid)
+                      summaries (GrpcClient/collectPackages
+                                  host h2c-port request-headers ids from messages 1 true (inc num-messages))
+                      assertion-message (str {:collect-cid collect-cid
+                                              :error-index error-index
+                                              :iteration iteration
+                                              :service-id service-id
+                                              :summaries (map (fn [s]
+                                                                {:num-messages (.getNumMessages s)
+                                                                 :status-code (.getStatusCode s)
+                                                                 :status-description (.getStatusDescription s)
+                                                                 :total-length (.getTotalLength s)})
+                                                              summaries)})
+                      expected-summary-count error-index
+                      ;; the last entry is expected to be the status summary, the rest are message summaries
+                      message-summaries (take (dec (count summaries)) summaries)
+                      status-summary (last summaries)]
+                  (log/info "result" assertion-message)
+                  (is (= expected-summary-count (count message-summaries)) assertion-message)
+                  (when (seq message-summaries)
+                    (is (= (range 1 (inc expected-summary-count))
+                           (map #(.getNumMessages %) message-summaries))
+                        assertion-message)
+                    (is (= (reductions + (map count (take expected-summary-count messages)))
+                           (map #(.getTotalLength %) message-summaries))
+                        assertion-message))
+                  (is status-summary)
+                  (when status-summary
+                    (log/info "server cancellation summary" status-summary)
+                    (if (zero? iteration)
+                      (do
+                        ;; TODO this if-block should not be needed, cancellation should be propagated correctly to the client
+                        (is (= "INTERNAL" (.getStatusCode status-summary))
+                            assertion-message)
+                        (is (str/includes? (.getStatusDescription status-summary) "Received Rst Stream")
+                            assertion-message))
+                      (do
+                        (is (= "CANCELLED" (.getStatusCode status-summary))
+                            assertion-message)
+                        (is (= "Cancelled by server" (.getStatusDescription status-summary))
+                            assertion-message)))
+                    (is (zero? (.getNumMessages status-summary))
+                        assertion-message)))))))))))
diff --git a/waiter/project.clj b/waiter/project.clj
index 79fd92a10..cfe18553d 100644
--- a/waiter/project.clj
+++ b/waiter/project.clj
@@ -32,7 +32,7 @@
:dependencies [[bidi "2.1.5"
:exclusions [prismatic/schema ring/ring-core]]
- [twosigma/courier "1.2.1"
+ [twosigma/courier "1.3.0"
:exclusions [com.google.guava/guava io.grpc/grpc-core]
:scope "test"]
;; avoids the following:
@@ -41,7 +41,7 @@
[io.grpc/grpc-core "1.20.0"
:exclusions [com.google.guava/guava]
:scope "test"]
- [twosigma/jet "0.7.10-20190701_144314-ga425faa"
+ [twosigma/jet "0.7.10-20190709_215542-ge7c9dbb"
:exclusions [org.mortbay.jetty.alpn/alpn-boot]]
[twosigma/clj-http "1.0.2-20180124_201819-gcdf23e5"
:exclusions [commons-codec commons-io org.clojure/tools.reader potemkin slingshot]]
diff --git a/waiter/src/waiter/process_request.clj b/waiter/src/waiter/process_request.clj
index 8c736c57f..ef405e97a 100644
--- a/waiter/src/waiter/process_request.clj
+++ b/waiter/src/waiter/process_request.clj
@@ -506,6 +506,60 @@
(log/warn "unable to abort as request not found inside response!"))]
(log/info "aborted backend request:" aborted))))
+(defn- forward-grpc-status-headers-in-trailers
+ "Adds logging for tracking response trailers for requests.
+ Since we always send some trailers, we need to repeat the grpc status headers in the trailers
+ to ensure client evaluates the request to the same grpc error.
+ Please see https://github.com/eclipse/jetty.project/issues/3829 for details."
+ [{:keys [headers trailers] :as response}]
+ (if trailers
+ (let [correlation-id (cid/get-correlation-id)
+ trailers-copy-ch (async/chan 5)
+ grpc-headers (select-keys headers ["grpc-message" "grpc-status"])]
+ (if (seq grpc-headers)
+ (do
+ (async/go
+ (cid/with-correlation-id
+ correlation-id
+ (try
+ (let [trailers-map (async/! trailers-copy-ch modified-trailers)))
+ (catch Throwable th
+ (log/error th "error in parsing response trailers")))
+ (log/info "closing response trailers channel")
+ (async/close! trailers-copy-ch)))
+ (assoc response :trailers trailers-copy-ch))
+ response))
+ response))
+
+(defn- handle-grpc-response
+  "Short-circuits grpc responses whose grpc-status header reports a failure.
+   This applies when the request is a grpc request, the response has a non-blank
+   grpc-status other than \"0\" and the response body is a channel. In that case the abort
+   channel (when present) is closed so the request can no longer be aborted, the reservation
+   is marked as :success, and the body and error channels are closed.
+   The response is returned unchanged in every case.
+   Jetty cannot be relied upon to close such requests in a timely manner,
+   please see https://github.com/eclipse/jetty.project/issues/3842 for details."
+  [{:keys [abort-ch body error-chan] :as response} request backend-proto reservation-status-promise]
+  (let [grpc-status (get (:headers response) "grpc-status")
+        proto-version (hu/backend-protocol->http-version backend-proto)
+        grpc-failure? (and (hu/grpc? (:headers request) proto-version)
+                           (not (str/blank? grpc-status))
+                           (not= "0" grpc-status)
+                           (au/chan? body))]
+    (when grpc-failure?
+      (log/info "eagerly closing response body as grpc status is" grpc-status)
+      ;; closing the abort channel disallows aborting the request
+      (some-> abort-ch async/close!)
+      ;; grpc failures are reported in the headers, so the request itself counts as successful
+      (deliver reservation-status-promise :success)
+      ;; no further content will be written to the body
+      (async/close! body)
+      (async/close! error-chan))
+    response))
+
+
(defn process-http-response
"Processes a response resulting from a http request.
It includes book-keeping for async requests and asynchronously streaming the content."
@@ -539,6 +593,8 @@
location (post-process-async-request-response-fn
service-id metric-group backend-proto instance (handler/make-auth-user-map request)
reason-map instance-request-properties location query-string))
+ (forward-grpc-status-headers-in-trailers)
+ (handle-grpc-response request backend-proto reservation-status-promise)
(assoc :body resp-chan)
(update-in [:headers] (fn update-response-headers [headers]
(utils/filterm #(not= "connection" (str/lower-case (str (key %)))) headers)))))))