From 4907cf3e5a1089d64afdd7085d95209993e6cabd Mon Sep 17 00:00:00 2001 From: Shams Imam Date: Tue, 9 Jul 2019 08:56:46 -0500 Subject: [PATCH] adds support for propagating grpc server-side cancellations --- containers/test-apps/courier/pom.xml | 2 +- .../twosigma/waiter/courier/GrpcClient.java | 137 ++++++++-- .../twosigma/waiter/courier/GrpcServer.java | 136 ++++++++-- .../courier/src/main/proto/courier.proto | 6 +- waiter/integration/waiter/grpc_test.clj | 254 ++++++++++++++---- waiter/project.clj | 4 +- waiter/src/waiter/process_request.clj | 56 ++++ 7 files changed, 502 insertions(+), 93 deletions(-) diff --git a/containers/test-apps/courier/pom.xml b/containers/test-apps/courier/pom.xml index e6675bed8..5d875bae5 100644 --- a/containers/test-apps/courier/pom.xml +++ b/containers/test-apps/courier/pom.xml @@ -6,7 +6,7 @@ twosigma courier - 1.2.1 + 1.3.0-SNAPSHOT courier https://github.com/twosigma/waiter/tree/master/test-apps/courier diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java index 2c3295750..16b8481c8 100644 --- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java +++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcClient.java @@ -36,7 +36,9 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -148,6 +150,15 @@ public static CourierReply sendPackage(final String host, } catch (final StatusRuntimeException e) { logFunction.apply("RPC failed, status: " + e.getStatus()); return null; + } catch (ExecutionException e) { + final Status status = Status.fromThrowable(e.getCause()); + logFunction.apply("RPC execution failed: " + status); + return CourierReply + .newBuilder() + .setId(status.getCode().toString()) + .setMessage(status.getDescription()) + .setResponse("error") + .build(); } catch (final Exception e) { logFunction.apply("RPC failed, message: " + e.getMessage()); return null; @@ -167,12 +178,14 @@ public static CourierReply sendPackage(final String host, public static List collectPackages(final String host, final int port, final Map headers, - final String idPrefix, + final List ids, final String from, final List messages, final int interMessageSleepMs, - final boolean lockStepMode) throws InterruptedException { + final boolean lockStepMode, + final int cancelThreshold) throws InterruptedException { final ManagedChannel channel = initializeChannel(host, port); + final AtomicBoolean awaitChannelTermination = new AtomicBoolean(true); try { final Semaphore lockStep = new Semaphore(1); @@ -214,6 +227,16 @@ public void onError(final Throwable throwable) { logFunction.apply("releasing semaphore after receiving error"); lockStep.release(); } + if (throwable instanceof StatusRuntimeException) { + final StatusRuntimeException exception = (StatusRuntimeException) throwable; + final CourierSummary response = CourierSummary + .newBuilder() + .setNumMessages(0) + .setStatusCode(exception.getStatus().getCode().name()) + .setStatusDescription(exception.getStatus().getDescription()) + .build(); + resultList.add(response); + } } @Override @@ -229,11 +252,17 @@ private void resolveResponsePromise() { }); for (int i = 0; i < messages.size(); i++) { + if (i >= cancelThreshold) { + logFunction.apply("cancelling sending messages"); + awaitChannelTermination.set(false); + channel.shutdownNow(); + throw new CancellationException("Cancel threshold reached: " + cancelThreshold); + } if (errorSignal.get()) { logFunction.apply("aborting sending messages as error was discovered"); break; } - final String requestId = idPrefix + i; + final String requestId = ids.get(i); if (lockStepMode) { logFunction.apply("acquiring semaphore before sending request " + requestId); lockStep.acquire(); @@ -264,39 +293,105 @@ private void resolveResponsePromise() { } } finally { - shutdownChannel(channel); + if (awaitChannelTermination.get()) { + shutdownChannel(channel); + } } } + public static List collectPackages(final String host, + final int port, + final Map headers, + final String idPrefix, + final String from, + final List messages, + final int interMessageSleepMs, + final boolean lockStepMode, + final int cancelThreshold) throws InterruptedException { + + final List ids = new ArrayList<>(messages.size()); + for (int i = 0; i < messages.size(); i++) { + ids.add(idPrefix + i); + } + + return collectPackages(host, port, headers, ids, from, messages, interMessageSleepMs, lockStepMode, cancelThreshold); + } + /** * Greet server. If provided, the first element of {@code args} is the name to use in the * greeting. */ public static void main(final String... args) throws Exception { /* Access a service running on the local machine on port 8080 */ + final long startTimeMillis = System.currentTimeMillis(); final String host = "localhost"; final int port = 8080; final HashMap headers = new HashMap<>(); - final String id = UUID.randomUUID().toString(); - final String user = "Jim"; - final StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 100_000; i++) { - sb.append("a"); - if (i % 1000 == 0) { - sb.append("."); + if (false) { + final String id = UUID.randomUUID().toString() + ".SEND_ERROR"; + final String user = "Jim"; + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 100_000; i++) { + sb.append("a"); + if (i % 1000 == 0) { + sb.append("."); + } + } + + headers.put("x-cid", "cid-send-package." + startTimeMillis); + final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString()); + logFunction.apply("sendPackage response = " + courierReply); + } + + if (false) { + headers.put("x-cid", "cid-collect-packages-complete." + startTimeMillis); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final List courierSummaries = + collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size()); + logFunction.apply("collectPackages response size = " + courierSummaries.size()); + if (!courierSummaries.isEmpty()) { + final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1); + logFunction.apply("collectPackages[complete] summary = " + courierSummary.toString()); } } - final CourierReply courierReply = sendPackage(host, port, headers, id, user, sb.toString()); - logFunction.apply("sendPackage response = " + courierReply); - - final List messages = IntStream.range(0, 100).mapToObj(i -> "message-" + i).collect(Collectors.toList()); - final List courierSummaries = - collectPackages(host, port, headers, "id-", "User", messages, 10, true); - logFunction.apply("collectPackages response size = " + courierSummaries.size()); - if (!courierSummaries.isEmpty()) { - final CourierSummary courierSummary = courierSummaries.get(courierSummaries.size() - 1); - logFunction.apply("collectPackages summary = " + courierSummary.toString()); + + if (false) { + headers.put("x-cid", "cid-collect-packages-cancel." + startTimeMillis); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final List courierSummaries = + collectPackages(host, port, headers, "id-", "User", messages, 100, true, messages.size() / 2); + logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + } + + if (false) { + headers.put("x-cid", "cid-collect-packages-server-pre-cancel." + startTimeMillis); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".EXIT_PRE_RESPONSE"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final List courierSummaries = + collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1); + logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + } + + if (false) { + headers.put("x-cid", "cid-collect-packages-server-post-cancel." + startTimeMillis); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".EXIT_POST_RESPONSE"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final List courierSummaries = + collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1); + logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); + } + + if (true) { + headers.put("x-cid", "cid-collect-packages-server-error." + startTimeMillis); + final List ids = IntStream.range(0, 10).mapToObj(i -> "id-" + i).collect(Collectors.toList()); + ids.set(5, ids.get(5) + ".SEND_ERROR"); + final List messages = IntStream.range(0, 10).mapToObj(i -> "message-" + i).collect(Collectors.toList()); + final List courierSummaries = + collectPackages(host, port, headers, ids, "User", messages, 100, true, messages.size() + 1); + logFunction.apply("collectPackages[cancel] summary = " + courierSummaries); } } } diff --git a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java index 9b64f3205..111eec3c0 100644 --- a/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java +++ b/containers/test-apps/courier/src/main/java/com/twosigma/waiter/courier/GrpcServer.java @@ -22,9 +22,18 @@ import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.logging.Logger; @@ -33,7 +42,6 @@ public class GrpcServer { private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName()); - private Server server; void start(final int port) throws IOException { @@ -76,19 +84,39 @@ public void sendPackage(final CourierRequest request, final StreamObserver() { private long numMessages = 0; @@ -100,19 +128,48 @@ public void onNext(final CourierRequest courierRequest) { numMessages += 1; totalLength += courierRequest.getMessage().length(); + LOGGER.severe("Summary of collected packages: numMessages=" + numMessages + + " with totalLength=" + totalLength); + + if (courierRequest.getId().contains("EXIT_PRE_RESPONSE")) { + sleep(1000); + LOGGER.info("Exiting server abruptly"); + System.exit(1); + } else if (courierRequest.getId().contains("SEND_ERROR")) { + final StatusRuntimeException error = Status.CANCELLED + .withCause(new RuntimeException(courierRequest.getId())) + .withDescription("Cancelled by server") + .asRuntimeException(); + LOGGER.info("Sending cancelled by server error"); + responseObserver.onError(error); + } else { + final CourierSummary courierSummary = CourierSummary + .newBuilder() + .setNumMessages(numMessages) + .setTotalLength(totalLength) + .build(); + LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId()); + responseObserver.onNext(courierSummary); + } - final CourierSummary courierSummary = CourierSummary - .newBuilder() - .setNumMessages(numMessages) - .setTotalLength(totalLength) - .build(); - LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId()); - responseObserver.onNext(courierSummary); + if (courierRequest.getId().contains("EXIT_POST_RESPONSE")) { + sleep(1000); + LOGGER.info("Exiting server abruptly"); + System.exit(1); + } + } + + private void sleep(final int durationMillis) { + try { + Thread.sleep(durationMillis); + } catch (Exception e) { + e.printStackTrace(); + } } @Override public void onError(final Throwable throwable) { - LOGGER.severe("Error in collecting packages" + throwable.getMessage()); + LOGGER.severe("Error in collecting packages: " + throwable.getMessage()); responseObserver.onError(throwable); } @@ -142,6 +199,7 @@ public ServerCall.Listener interceptCall( new ForwardingServerCall.SimpleForwardingServerCall(serverCall) { @Override public void sendHeaders(final Metadata responseHeaders) { + LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]"); logMetadata(requestMetadata, "response"); if (correlationId != null) { LOGGER.info("response linked to cid: " + correlationId); @@ -149,13 +207,51 @@ public void sendHeaders(final Metadata responseHeaders) { } super.sendHeaders(responseHeaders); } + + @Override + public void sendMessage(final RespT response) { + LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]"); + super.sendMessage(response); + } + + @Override + public void close(final Status status, final Metadata trailers) { + LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers); + super.close(status, trailers); + } }; - return serverCallHandler.startCall(wrapperCall, requestMetadata); + final ServerCall.Listener listener = serverCallHandler.startCall(wrapperCall, requestMetadata); + return new ServerCall.Listener() { + public void onMessage(final ReqT message) { + LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]"); + listener.onMessage(message); + } + + public void onHalfClose() { + LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]"); + listener.onHalfClose(); + } + + public void onCancel() { + LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]"); + listener.onCancel(); + } + + public void onComplete() { + LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]"); + listener.onComplete(); + } + + public void onReady() { + LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]"); + listener.onReady(); + } + }; } private void logMetadata(final Metadata metadata, final String label) { final Set metadataKeys = metadata.keys(); - LOGGER.info(label + " metadata keys = " + metadataKeys); + LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys); for (final String key : metadataKeys) { final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER)); LOGGER.info(label + " metadata " + key + " = " + value); diff --git a/containers/test-apps/courier/src/main/proto/courier.proto b/containers/test-apps/courier/src/main/proto/courier.proto index a9b960b75..29617dbc3 100644 --- a/containers/test-apps/courier/src/main/proto/courier.proto +++ b/containers/test-apps/courier/src/main/proto/courier.proto @@ -22,7 +22,6 @@ option objc_class_prefix = "TSCP"; package courier; -// The greeting service definition. service Courier { // Sends a package. rpc SendPackage (CourierRequest) returns (CourierReply); @@ -30,22 +29,21 @@ service Courier { rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary); } -// The request message containing the package's name. message CourierRequest { string id = 1; string from = 2; string message = 3; } -// The response message containing the package response. message CourierReply { string id = 1; string message = 2; string response = 3; } -// The response message containing the package response. message CourierSummary { int64 num_messages = 1; int64 total_length = 2; + string status_code = 3; + string status_description = 4; } diff --git a/waiter/integration/waiter/grpc_test.clj b/waiter/integration/waiter/grpc_test.clj index 5965a5535..431a3013d 100644 --- a/waiter/integration/waiter/grpc_test.clj +++ b/waiter/integration/waiter/grpc_test.clj @@ -22,6 +22,11 @@ (:import (com.twosigma.waiter.courier GrpcClient) (java.util.function Function))) +;; initialize logging on the grpc client +(GrpcClient/setLogFunction (reify Function + (apply [_ message] + (log/info message)))) + (defn- basic-grpc-service-parameters [] (let [courier-command (courier-server-command "${PORT0} ${PORT1}")] @@ -29,12 +34,15 @@ {:x-waiter-backend-proto "h2c" :x-waiter-cmd courier-command :x-waiter-cmd-type "shell" + :x-waiter-concurrency-level 32 :x-waiter-cpus 0.2 :x-waiter-debug true :x-waiter-grace-period-secs 120 :x-waiter-health-check-port-index 1 :x-waiter-health-check-proto "http" :x-waiter-idle-timeout-mins 10 + :x-waiter-max-instances 1 + :x-waiter-min-instances 1 :x-waiter-mem 512 :x-waiter-name (rand-name) :x-waiter-ports 2 @@ -45,65 +53,221 @@ [length] (apply str (take length (repeatedly #(char (+ (rand 26) 65)))))) -(deftest ^:parallel ^:integration-fast test-basic-grpc-server - (testing-using-waiter-url - (GrpcClient/setLogFunction (reify Function - (apply [_ message] - (log/info message)))) - (let [[host _] (str/split waiter-url #":") - h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url)) - request-headers (basic-grpc-service-parameters) - {:keys [cookies headers] :as response} (make-request waiter-url "/waiter-ping" :headers request-headers) - cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies)) - service-id (get headers "x-waiter-service-id")] +(defn ping-courier-service + [waiter-url request-headers] + (make-request waiter-url "/waiter-ping" :headers request-headers)) + +(defn start-courier-instance + [waiter-url] + (let [[host _] (str/split waiter-url #":") + h2c-port (Integer/parseInt (retrieve-h2c-port waiter-url)) + request-headers (basic-grpc-service-parameters) + {:keys [cookies headers] :as response} (ping-courier-service waiter-url request-headers) + cookie-header (str/join "; " (map #(str (:name %) "=" (:value %)) cookies)) + service-id (get headers "x-waiter-service-id") + request-headers (assoc request-headers + "cookie" cookie-header + "x-waiter-timeout" "60000")] + (assert-response-status response 200) + (is service-id) + (log/info "ping cid:" (get headers "x-cid")) + (log/info "service-id:" service-id) + (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)] + (is (= "received-response" (:result ping-response)) (str ping-response)) + (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check")) + (assert-response-status ping-response 200) + (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state) + (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state)) + (str service-state))) + (assert-service-on-all-routers waiter-url service-id cookies) - (assert-response-status response 200) - (is service-id) + {:cookies cookies + :h2c-port h2c-port + :host host + :request-headers request-headers + :service-id service-id})) +(deftest ^:parallel ^:integration-fast test-grpc-unary-call + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] (with-service-cleanup service-id - - (let [{:keys [ping-response service-state]} (some-> response :body try-parse-json walk/keywordize-keys)] - (is (= "received-response" (:result ping-response)) (str ping-response)) - (is (str/starts-with? (str (some-> ping-response :headers :server)) "courier-health-check")) - (assert-response-status ping-response 200) - (is (or (= {:exists? true :healthy? true :service-id service-id :status "Running"} service-state) - (= {:exists? true :healthy? false :service-id service-id :status "Starting"} service-state)) - (str service-state))) - (testing "small request and reply" (log/info "starting small request and reply test") (let [id (rand-name "m") from (rand-name "f") content (rand-str 1000) - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) + request-headers (assoc request-headers "x-cid" (rand-name)) reply (GrpcClient/sendPackage host h2c-port request-headers id from content)] (is reply) (when reply (is (= id (.getId reply))) (is (= content (.getMessage reply))) - (is (= "received" (.getResponse reply)))))) + (is (= "received" (.getResponse reply)))))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-unary-call-server-error + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (testing "small request and reply" + (log/info "starting small request and reply test") + (let [id (str (rand-name "m") ".SEND_ERROR") + from (rand-name "f") + content (rand-str 1000) + request-cid (rand-name) + _ (log/info "cid:" request-cid) + request-headers (assoc request-headers "x-cid" request-cid) + reply (GrpcClient/sendPackage host h2c-port request-headers id from content)] + (is reply) + (when reply + (is (= "CANCELLED" (.getId reply))) + (is (= "Cancelled by server" (.getMessage reply))) + (is (= "error" (.getResponse reply)))))))))) + +(deftest ^:parallel ^:integration-fast test-grpc-streaming-successful + (testing-using-waiter-url + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (doseq [max-message-length [1000 10000 100000]] + (let [num-messages 100 + messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length)))))] - (testing "streaming to and from server" - (doseq [max-message-length [100 1000 10000 100000]] - (let [messages (doall (repeatedly 200 #(rand-str (inc (rand-int max-message-length)))))] + (testing (str "independent mode " max-message-length " messages completion") + (log/info "starting streaming to and from server - independent mode test") + (let [cancel-threshold (inc num-messages) + from (rand-name "f") + collect-cid (str (rand-name) "-independent-complete") + request-headers (assoc request-headers "x-cid" collect-cid) + summaries (GrpcClient/collectPackages + host h2c-port request-headers "m-" from messages 10 false cancel-threshold)] + (log/info collect-cid "collecting independent packages...") + (is (= (count messages) (count summaries))) + (when (seq summaries) + (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) + (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))) - (testing (str "independent mode " max-message-length " messages") - (log/info "starting streaming to and from server - independent mode test") - (let [from (rand-name "f") - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) - summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 false)] - (is (= (count messages) (count summaries))) - (when (seq summaries) - (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) - (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))) + (testing (str "lock-step mode " max-message-length " messages completion") + (log/info "starting streaming to and from server - lock-step mode test") + (let [cancel-threshold (inc num-messages) + from (rand-name "f") + collect-cid (str (rand-name) "-lock-step-complete") + request-headers (assoc request-headers "x-cid" collect-cid) + summaries (GrpcClient/collectPackages + host h2c-port request-headers "m-" from messages 1 true cancel-threshold)] + (log/info collect-cid "collecting lock-step packages...") + (is (= (count messages) (count summaries))) + (when (seq summaries) + (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) + (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries)))))))))))) - (testing (str "lock-step mode " max-message-length " messages") - (log/info "starting streaming to and from server - lock-step mode test") - (let [from (rand-name "f") - request-headers (assoc request-headers "cookie" cookie-header "x-cid" (rand-name)) - summaries (GrpcClient/collectPackages host h2c-port request-headers "m-" from messages 1 true)] - (is (= (count messages) (count summaries))) - (when (seq summaries) - (is (= (range 1 (inc (count messages))) (map #(.getNumMessages %) summaries))) - (is (= (reductions + (map count messages)) (map #(.getTotalLength %) summaries))))))))))))) +(deftest ^:parallel ^:integration-slow test-grpc-streaming-server-exit + (testing-using-waiter-url + (let [num-messages 100 + num-iterations 5] + (dotimes [iteration num-iterations] + (doseq [max-message-length [1000 10000 100000]] + (doseq [mode ["EXIT_PRE_RESPONSE" "EXIT_POST_RESPONSE"]] + (testing (str "lock-step mode " max-message-length " messages exits pre-response") + (let [{:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [exit-index (* iteration (/ num-messages num-iterations)) + collect-cid (str (rand-name) "." mode "." exit-index "-" num-messages "." max-message-length) + _ (log/info "collect packages cid" collect-cid "for" + {:iteration iteration :max-message-length max-message-length}) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % exit-index) (str "." mode))) (range num-messages)) + messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length))))) + request-headers (assoc request-headers "x-cid" collect-cid) + summaries (GrpcClient/collectPackages + host h2c-port request-headers ids from messages 1 true (inc num-messages)) + assertion-message (str {:exit-index exit-index + :collect-cid collect-cid + :iteration iteration + :service-id service-id + :summaries (map (fn [s] + {:num-messages (.getNumMessages s) + :status-code (.getStatusCode s) + :total-length (.getTotalLength s)}) + summaries)}) + message-summaries (take (dec (count summaries)) summaries) + status-summary (last summaries) + expected-summary-count (cond-> exit-index + (= "EXIT_POST_RESPONSE" mode) inc)] + (log/info "result" assertion-message) ;; TODO shams convert to log + (is (= expected-summary-count (count message-summaries)) assertion-message) + (when (seq message-summaries) + (is (= (range 1 (inc expected-summary-count)) + (map #(.getNumMessages %) message-summaries)) + assertion-message) + (is (= (reductions + (map count (take expected-summary-count messages))) + (map #(.getTotalLength %) message-summaries)) + assertion-message)) + (is status-summary) + (when status-summary + (log/info "server exit summary" status-summary) + (is (contains? #{"UNAVAILABLE" "INTERNAL"} (.getStatusCode status-summary)) assertion-message) + (is (zero? (.getNumMessages status-summary)) assertion-message)))))))))))) + +(deftest ^:parallel ^:integration-slow test-grpc-streaming-server-cancellation + (testing-using-waiter-url + (let [num-messages 100 + num-iterations 5] + (dotimes [iteration num-iterations] + (doseq [max-message-length [1000 10000 100000]] + (testing (str "lock-step mode " max-message-length " messages server error") + (let [mode "SEND_ERROR" + {:keys [h2c-port host request-headers service-id]} (start-courier-instance waiter-url)] + (with-service-cleanup + service-id + (let [error-index (* iteration (/ num-messages num-iterations)) + collect-cid (str (rand-name) "." mode "." error-index "-" num-messages "." max-message-length) + from (rand-name "f") + ids (map #(str "id-" (cond-> % (= % error-index) (str "." mode))) (range num-messages)) + messages (doall (repeatedly num-messages #(rand-str (inc (rand-int max-message-length))))) + request-headers (assoc request-headers "x-cid" collect-cid) + _ (println "collect packages cid" collect-cid "for" + {:iteration iteration :max-message-length max-message-length}) + summaries (GrpcClient/collectPackages + host h2c-port request-headers ids from messages 1 true (inc num-messages)) + assertion-message (str {:collect-cid collect-cid + :error-index error-index + :iteration iteration + :service-id service-id + :summaries (map (fn [s] + {:num-messages (.getNumMessages s) + :status-code (.getStatusCode s) + :status-description (.getStatusDescription s) + :total-length (.getTotalLength s)}) + summaries)}) + expected-summary-count error-index + message-summaries (take (dec (count summaries)) summaries) + status-summary (last summaries)] + (log/info "result" assertion-message) + (is (= expected-summary-count (count message-summaries)) assertion-message) + (when (seq message-summaries) + (is (= (range 1 (inc expected-summary-count)) + (map #(.getNumMessages %) message-summaries)) + assertion-message) + (is (= (reductions + (map count (take expected-summary-count messages))) + (map #(.getTotalLength %) message-summaries)) + assertion-message)) + (is status-summary) + (when status-summary + (log/info "server cancellation summary" status-summary) + (if (zero? iteration) + (do + ;; TODO this if-block should not be needed, cancellation should be propagated correctly to the client + (is (= "INTERNAL" (.getStatusCode status-summary)) + assertion-message) + (is (str/includes? (.getStatusDescription status-summary) "Received Rst Stream") + assertion-message)) + (do + (is (= "CANCELLED" (.getStatusCode status-summary)) + assertion-message) + (is (= "Cancelled by server" (.getStatusDescription status-summary)) + assertion-message))) + (is (zero? (.getNumMessages status-summary)) + assertion-message))))))))))) diff --git a/waiter/project.clj b/waiter/project.clj index 79fd92a10..ec266cf9a 100644 --- a/waiter/project.clj +++ b/waiter/project.clj @@ -32,7 +32,7 @@ :dependencies [[bidi "2.1.5" :exclusions [prismatic/schema ring/ring-core]] - [twosigma/courier "1.2.1" + [twosigma/courier "1.3.0-SNAPSHOT" :exclusions [com.google.guava/guava io.grpc/grpc-core] :scope "test"] ;; avoids the following: @@ -41,7 +41,7 @@ [io.grpc/grpc-core "1.20.0" :exclusions [com.google.guava/guava] :scope "test"] - [twosigma/jet "0.7.10-20190701_144314-ga425faa" + [twosigma/jet "0.7.10-20190709_215542-ge7c9dbb" :exclusions [org.mortbay.jetty.alpn/alpn-boot]] [twosigma/clj-http "1.0.2-20180124_201819-gcdf23e5" :exclusions [commons-codec commons-io org.clojure/tools.reader potemkin slingshot]] diff --git a/waiter/src/waiter/process_request.clj b/waiter/src/waiter/process_request.clj index 8c736c57f..ef405e97a 100644 --- a/waiter/src/waiter/process_request.clj +++ b/waiter/src/waiter/process_request.clj @@ -506,6 +506,60 @@ (log/warn "unable to abort as request not found inside response!"))] (log/info "aborted backend request:" aborted)))) +(defn- forward-grpc-status-headers-in-trailers + "Adds logging for tracking response trailers for requests. + Since we always send some trailers, we need to repeat the grpc status headers in the trailers + to ensure client evaluates the request to the same grpc error. + Please see https://github.com/eclipse/jetty.project/issues/3829 for details." + [{:keys [headers trailers] :as response}] + (if trailers + (let [correlation-id (cid/get-correlation-id) + trailers-copy-ch (async/chan 5) + grpc-headers (select-keys headers ["grpc-message" "grpc-status"])] + (if (seq grpc-headers) + (do + (async/go + (cid/with-correlation-id + correlation-id + (try + (let [trailers-map (async/! trailers-copy-ch modified-trailers))) + (catch Throwable th + (log/error th "error in parsing response trailers"))) + (log/info "closing response trailers channel") + (async/close! trailers-copy-ch))) + (assoc response :trailers trailers-copy-ch)) + response)) + response)) + +(defn- handle-grpc-response + "Eagerly terminates grpc requests with error status headers. + We cannot rely on jetty to close the request for us in a timely manner, + please see https://github.com/eclipse/jetty.project/issues/3842 for details." + [{:keys [abort-ch body error-chan] :as response} request backend-proto reservation-status-promise] + (let [request-headers (:headers request) + {:strs [grpc-status]} (:headers response) + proto-version (hu/backend-protocol->http-version backend-proto)] + (when (and (hu/grpc? request-headers proto-version) + (not (str/blank? grpc-status)) + (not= "0" grpc-status) + (au/chan? body)) + (log/info "eagerly closing response body as grpc status is" grpc-status) + (when abort-ch + ;; disallow aborting the request + (async/close! abort-ch)) + ;; mark the request as successful, grpc failures are reported in the headers + (deliver reservation-status-promise :success) + ;; stop writing any content in the body + (async/close! body) + (async/close! error-chan)) + response)) + (defn process-http-response "Processes a response resulting from a http request. It includes book-keeping for async requests and asynchronously streaming the content." @@ -539,6 +593,8 @@ location (post-process-async-request-response-fn service-id metric-group backend-proto instance (handler/make-auth-user-map request) reason-map instance-request-properties location query-string)) + (forward-grpc-status-headers-in-trailers) + (handle-grpc-response request backend-proto reservation-status-promise) (assoc :body resp-chan) (update-in [:headers] (fn update-response-headers [headers] (utils/filterm #(not= "connection" (str/lower-case (str (key %)))) headers)))))))