Skip to content
This repository has been archived by the owner on Mar 22, 2023. It is now read-only.

Commit

Permalink
adds support for propagating grpc server-side cancellations (#844)
Browse files Browse the repository at this point in the history
* adds support for propagating grpc server-side cancellations

* adds messages to test assertions

* closes trailers eagerly to allow pure header responses to be published asap

* removes unnecessary when clauses

* avoids overloading the id field in CourierRequest
introduces separate variant field to determine cancellation operation
adds unary server exit case
cleans up return values from the grpc client
adds assertion on grpc health check response body

* documents reasons for closing the channels

* adds tests for client-side streaming and server-side cancellations with unary responses

* initializes grpc client result Status to UNKNOWN instead of OK
asserts on individual fields on the ping response

* improves assertion messages
avoids warnings due to presence of overloaded methods in grpc client

* improves assertion messages
avoids warnings due to presence of overloaded methods in grpc client
adds correlation-id to logs from grpc client

* avoids race in reading status on error path in grpc-client
  • Loading branch information
shamsimam authored and sradack committed Jul 11, 2019
1 parent c6b2af1 commit dd6bd41
Show file tree
Hide file tree
Showing 7 changed files with 1,038 additions and 156 deletions.
2 changes: 1 addition & 1 deletion containers/test-apps/courier/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>twosigma</groupId>
<artifactId>courier</artifactId>
<version>1.2.1</version>
<version>1.4.6</version>

<name>courier</name>
<url>https://github.com/twosigma/waiter/tree/master/test-apps/courier</url>
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;

import java.io.IOException;
Expand All @@ -34,6 +37,14 @@ public class GrpcServer {

private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());

private static void sleep(final int durationMillis) {
try {
Thread.sleep(durationMillis);
} catch (final Exception ex) {
ex.printStackTrace();
}
}

private Server server;

void start(final int port) throws IOException {
Expand Down Expand Up @@ -76,44 +87,89 @@ public void sendPackage(final CourierRequest request, final StreamObserver<Couri
"id=" + request.getId() + ", " +
"from=" + request.getFrom() + ", " +
"message.length=" + request.getMessage().length() + "}");
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:sendPackage CourierRequest{" + "id=" + request.getId() + "} was cancelled");
});
}
if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else {
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
}
}

@Override
public StreamObserver<CourierRequest> collectPackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest courierRequest) {
LOGGER.info("Received CourierRequest id=" + courierRequest.getId());
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += courierRequest.getMessage().length();
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
responseObserver.onNext(courierSummary);
if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else {
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + request.getId());
responseObserver.onNext(courierSummary);
}

if (Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
}
}

@Override
public void onError(final Throwable throwable) {
LOGGER.severe("Error in collecting packages" + throwable.getMessage());
responseObserver.onError(throwable);
public void onError(final Throwable th) {
LOGGER.severe("Error in collecting packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
Expand All @@ -123,6 +179,63 @@ public void onCompleted() {
}
};
}

@Override
public StreamObserver<CourierRequest> aggregatePackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
}
}

@Override
public void onError(final Throwable th) {
LOGGER.severe("Error in aggregating packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
public void onCompleted() {
LOGGER.severe("Completed aggregating packages");
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending aggregated CourierSummary");
responseObserver.onNext(courierSummary);
responseObserver.onCompleted();
}
};
}
}

private static class GrpcServerInterceptor implements ServerInterceptor {
Expand All @@ -142,20 +255,59 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {
@Override
public void sendHeaders(final Metadata responseHeaders) {
LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
logMetadata(requestMetadata, "response");
if (correlationId != null) {
LOGGER.info("response linked to cid: " + correlationId);
responseHeaders.put(xCidKey, correlationId);
}
super.sendHeaders(responseHeaders);
}

@Override
public void sendMessage(final RespT response) {
LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]");
super.sendMessage(response);
}

@Override
public void close(final Status status, final Metadata trailers) {
LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers);
super.close(status, trailers);
}
};
return serverCallHandler.startCall(wrapperCall, requestMetadata);
final ServerCall.Listener<ReqT> listener = serverCallHandler.startCall(wrapperCall, requestMetadata);
return new ServerCall.Listener<ReqT>() {
public void onMessage(final ReqT message) {
LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
listener.onMessage(message);
}

public void onHalfClose() {
LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
listener.onHalfClose();
}

public void onCancel() {
LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
listener.onCancel();
}

public void onComplete() {
LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
listener.onComplete();
}

public void onReady() {
LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
listener.onReady();
}
};
}

private void logMetadata(final Metadata metadata, final String label) {
final Set<String> metadataKeys = metadata.keys();
LOGGER.info(label + " metadata keys = " + metadataKeys);
LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
for (final String key : metadataKeys) {
final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER));
LOGGER.info(label + " metadata " + key + " = " + value);
Expand Down
16 changes: 11 additions & 5 deletions containers/test-apps/courier/src/main/proto/courier.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,35 @@ option objc_class_prefix = "TSCP";

package courier;

// The greeting service definition.
service Courier {
// Sends a package.
rpc SendPackage (CourierRequest) returns (CourierReply);
// Processes a stream of messages
// Processes a stream of messages, returns a stream of messages
rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary);
// Processes a stream of messages, returns a unary message
rpc AggregatePackages (stream CourierRequest) returns (CourierSummary);
}

enum Variant {
NORMAL = 0;
SEND_ERROR = 1;
EXIT_PRE_RESPONSE = 2;
EXIT_POST_RESPONSE = 3;
}

// The request message containing the package's name.
message CourierRequest {
string id = 1;
string from = 2;
string message = 3;
Variant variant = 4;
}

// The response message containing the package response.
message CourierReply {
string id = 1;
string message = 2;
string response = 3;
}

// The response message containing the package response.
message CourierSummary {
int64 num_messages = 1;
int64 total_length = 2;
Expand Down
Loading

0 comments on commit dd6bd41

Please sign in to comment.