Skip to content
This repository has been archived by the owner on Mar 22, 2023. It is now read-only.

adds support for propagating grpc server-side cancellations #844

Merged
merged 11 commits into from
Jul 11, 2019
Merged
2 changes: 1 addition & 1 deletion containers/test-apps/courier/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>twosigma</groupId>
<artifactId>courier</artifactId>
<version>1.2.1</version>
<version>1.4.3</version>

<name>courier</name>
<url>https://github.com/twosigma/waiter/tree/master/test-apps/courier</url>
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;

import java.io.IOException;
Expand All @@ -34,6 +37,14 @@ public class GrpcServer {

private final static Logger LOGGER = Logger.getLogger(GrpcServer.class.getName());

private static void sleep(final int durationMillis) {
try {
Thread.sleep(durationMillis);
} catch (final Exception ex) {
ex.printStackTrace();
}
}

private Server server;

void start(final int port) throws IOException {
Expand Down Expand Up @@ -76,44 +87,89 @@ public void sendPackage(final CourierRequest request, final StreamObserver<Couri
"id=" + request.getId() + ", " +
"from=" + request.getFrom() + ", " +
"message.length=" + request.getMessage().length() + "}");
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:sendPackage CourierRequest{" + "id=" + request.getId() + "} was cancelled");
});
}
if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else {
final CourierReply reply = CourierReply
.newBuilder()
.setId(request.getId())
.setMessage(request.getMessage())
.setResponse("received")
.build();
LOGGER.info("Sending CourierReply for id=" + reply.getId());
responseObserver.onNext(reply);
responseObserver.onCompleted();
}
}

@Override
public StreamObserver<CourierRequest> collectPackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest courierRequest) {
LOGGER.info("Received CourierRequest id=" + courierRequest.getId());
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += courierRequest.getMessage().length();
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + courierRequest.getId());
responseObserver.onNext(courierSummary);
if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
} else {
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending CourierSummary for id=" + request.getId());
responseObserver.onNext(courierSummary);
}

if (Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
}
}

@Override
public void onError(final Throwable throwable) {
LOGGER.severe("Error in collecting packages" + throwable.getMessage());
responseObserver.onError(throwable);
public void onError(final Throwable th) {
LOGGER.severe("Error in collecting packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
Expand All @@ -123,6 +179,63 @@ public void onCompleted() {
}
};
}

@Override
public StreamObserver<CourierRequest> aggregatePackages(final StreamObserver<CourierSummary> responseObserver) {

if (responseObserver instanceof ServerCallStreamObserver) {
((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> {
LOGGER.info("CancelHandler:collectPackages() was cancelled");
});
}
return new StreamObserver<CourierRequest>() {

private long numMessages = 0;
private long totalLength = 0;

@Override
public void onNext(final CourierRequest request) {
LOGGER.info("Received CourierRequest id=" + request.getId());

numMessages += 1;
totalLength += request.getMessage().length();
LOGGER.info("Summary of collected packages: numMessages=" + numMessages +
" with totalLength=" + totalLength);

if (Variant.EXIT_PRE_RESPONSE.equals(request.getVariant()) || Variant.EXIT_POST_RESPONSE.equals(request.getVariant())) {
sleep(1000);
LOGGER.info("Exiting server abruptly");
System.exit(1);
} else if (Variant.SEND_ERROR.equals(request.getVariant())) {
final StatusRuntimeException error = Status.CANCELLED
.withCause(new RuntimeException(request.getId()))
.withDescription("Cancelled by server")
.asRuntimeException();
LOGGER.info("Sending cancelled by server error");
responseObserver.onError(error);
}
}

@Override
public void onError(final Throwable th) {
LOGGER.severe("Error in aggregating packages: " + th.getMessage());
responseObserver.onError(th);
}

@Override
public void onCompleted() {
LOGGER.severe("Completed aggregating packages");
final CourierSummary courierSummary = CourierSummary
.newBuilder()
.setNumMessages(numMessages)
.setTotalLength(totalLength)
.build();
LOGGER.info("Sending aggregated CourierSummary");
responseObserver.onNext(courierSummary);
responseObserver.onCompleted();
}
};
}
}

private static class GrpcServerInterceptor implements ServerInterceptor {
Expand All @@ -142,20 +255,59 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {
@Override
public void sendHeaders(final Metadata responseHeaders) {
LOGGER.info("GrpcServerInterceptor.sendHeaders[cid=" + correlationId + "]");
logMetadata(requestMetadata, "response");
if (correlationId != null) {
LOGGER.info("response linked to cid: " + correlationId);
responseHeaders.put(xCidKey, correlationId);
}
super.sendHeaders(responseHeaders);
}

@Override
public void sendMessage(final RespT response) {
LOGGER.info("GrpcServerInterceptor.sendMessage[cid=" + correlationId + "]");
super.sendMessage(response);
}

@Override
public void close(final Status status, final Metadata trailers) {
LOGGER.info("GrpcServerInterceptor.close[cid=" + correlationId + "] " + status + ", " + trailers);
super.close(status, trailers);
}
};
return serverCallHandler.startCall(wrapperCall, requestMetadata);
final ServerCall.Listener<ReqT> listener = serverCallHandler.startCall(wrapperCall, requestMetadata);
return new ServerCall.Listener<ReqT>() {
public void onMessage(final ReqT message) {
LOGGER.info("GrpcServerInterceptor.onMessage[cid=" + correlationId + "]");
listener.onMessage(message);
}

public void onHalfClose() {
LOGGER.info("GrpcServerInterceptor.onHalfClose[cid=" + correlationId + "]");
listener.onHalfClose();
}

public void onCancel() {
LOGGER.info("GrpcServerInterceptor.onCancel[cid=" + correlationId + "]");
listener.onCancel();
}

public void onComplete() {
LOGGER.info("GrpcServerInterceptor.onComplete[cid=" + correlationId + "]");
listener.onComplete();
}

public void onReady() {
LOGGER.info("GrpcServerInterceptor.onReady[cid=" + correlationId + "]");
listener.onReady();
}
};
}

private void logMetadata(final Metadata metadata, final String label) {
final Set<String> metadataKeys = metadata.keys();
LOGGER.info(label + " metadata keys = " + metadataKeys);
LOGGER.info(label + "@" + metadata.hashCode() + " metadata keys = " + metadataKeys);
for (final String key : metadataKeys) {
final String value = metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER));
LOGGER.info(label + " metadata " + key + " = " + value);
Expand Down
16 changes: 11 additions & 5 deletions containers/test-apps/courier/src/main/proto/courier.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,35 @@ option objc_class_prefix = "TSCP";

package courier;

// The greeting service definition.
service Courier {
// Sends a package.
rpc SendPackage (CourierRequest) returns (CourierReply);
// Processes a stream of messages
// Processes a stream of messages, returns a stream of messages
rpc CollectPackages (stream CourierRequest) returns (stream CourierSummary);
// Processes a stream of messages, returns a unary message
rpc AggregatePackages (stream CourierRequest) returns (CourierSummary);
}

enum Variant {
NORMAL = 0;
SEND_ERROR = 1;
EXIT_PRE_RESPONSE = 2;
EXIT_POST_RESPONSE = 3;
}

// The request message containing the package's name.
message CourierRequest {
string id = 1;
string from = 2;
string message = 3;
Variant variant = 4;
}

// The response message containing the package response.
message CourierReply {
string id = 1;
string message = 2;
string response = 3;
}

// The response message containing the package response.
message CourierSummary {
int64 num_messages = 1;
int64 total_length = 2;
Expand Down
Loading