Skip to content

Commit

Permalink
Improve exceptions/resumption in StreamingResponseHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
Randgalt committed Jun 26, 2024
1 parent 5f27621 commit 3efbe80
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,20 @@ public void setLevelDebug()
loggerProc = debugLogger;
}

public RequestLoggingSession newOrCurrentRequestSession(Request request, SigningServiceType serviceType)
public RequestLoggingSession newRequestSession(Request request, SigningServiceType serviceType)
{
return sessions.computeIfAbsent(new Key(request, serviceType), _ -> internalRequestSession(request, serviceType));
return sessions.compute(new Key(request, serviceType), (key, current) -> {
checkState(current == null, "There is already a logging session for the request: " + key);
return internalNewRequestSession(request, serviceType);
});
}

public Optional<RequestLoggingSession> currentRequestSession(Request request, SigningServiceType serviceType)
{
return Optional.ofNullable(sessions.get(new Key(request, serviceType)));
}

private RequestLoggingSession internalRequestSession(Request request, SigningServiceType serviceType)
private RequestLoggingSession internalNewRequestSession(Request request, SigningServiceType serviceType)
{
if (!loggerProc.isEnabled()) {
return () -> {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,29 @@
*/
package io.trino.s3.proxy.server.rest;

import com.google.common.base.Throwables;
import io.airlift.http.client.HeaderName;
import io.airlift.http.client.HttpStatus;
import io.airlift.http.client.Request;
import io.airlift.http.client.Response;
import io.airlift.http.client.ResponseHandler;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.AsyncResponse;
import jakarta.ws.rs.core.StreamingOutput;

import java.io.InputStream;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.airlift.http.client.ResponseHandlerUtils.propagate;
import static jakarta.ws.rs.core.Response.Status.INTERNAL_SERVER_ERROR;
import static java.util.Objects.requireNonNull;

class StreamingResponseHandler
implements ResponseHandler<Void, RuntimeException>
{
private final AsyncResponse asyncResponse;
private final RequestLoggingSession requestLoggingSession;
private final AtomicBoolean hasBeenResumed = new AtomicBoolean(false);

StreamingResponseHandler(AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession)
{
Expand All @@ -44,7 +50,8 @@ public Void handleException(Request request, Exception exception)
requestLoggingSession.logException(exception);
requestLoggingSession.close();

throw propagate(request, exception);
resume(exception);
return null;
}

@Override
Expand All @@ -62,7 +69,10 @@ public Void handle(Request request, Response response)
};

try (requestLoggingSession) {
jakarta.ws.rs.core.Response.ResponseBuilder responseBuilder = jakarta.ws.rs.core.Response.status(response.getStatusCode()).entity(streamingOutput);
jakarta.ws.rs.core.Response.ResponseBuilder responseBuilder = jakarta.ws.rs.core.Response.status(response.getStatusCode());
if (HttpStatus.familyForStatusCode(response.getStatusCode()) == HttpStatus.Family.SUCCESSFUL) {
responseBuilder.entity(streamingOutput);
}
response.getHeaders()
.keySet()
.stream()
Expand All @@ -74,9 +84,27 @@ public Void handle(Request request, Response response)

// this will block until StreamingOutput completes

asyncResponse.resume(responseBuilder.build());
resume(responseBuilder.build());
}

return null;
}

@SuppressWarnings("ThrowableNotThrown")
private void resume(Object result)
{
switch (result) {
case WebApplicationException exception -> resume(exception.getResponse());
case Throwable exception when Throwables.getRootCause(exception) instanceof WebApplicationException webApplicationException -> resume(webApplicationException.getResponse());
case Throwable exception -> resume(jakarta.ws.rs.core.Response.status(INTERNAL_SERVER_ERROR.getStatusCode(), Optional.ofNullable(exception.getMessage()).orElse("Unknown error")).build());
default -> {
if (hasBeenResumed.compareAndSet(false, true)) {
asyncResponse.resume(result);
}
else {
throw new WebApplicationException("Could not resume with response: " + result, INTERNAL_SERVER_ERROR);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
Request remoteRequest = remoteRequestBuilder.build();

executorService.submit(() -> {
StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, requestLoggingSession);
try {
httpClient.execute(remoteRequest, new StreamingResponseHandler(asyncResponse, requestLoggingSession));
httpClient.execute(remoteRequest, responseHandler);
}
catch (Throwable e) {
requestLoggingSession.logException(e);
requestLoggingSession.close();
asyncResponse.resume(e);
responseHandler.handleException(remoteRequest, new RuntimeException(e));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void s3DeleteWithPath(@Context ContainerRequest containerRequest, @Suspen
private void handler(ContainerRequest containerRequest, AsyncResponse asyncResponse)
{
Request request = fromRequest(containerRequest);
RequestLoggingSession requestLoggingSession = requestLoggerController.newOrCurrentRequestSession(request, SigningServiceType.S3);
RequestLoggingSession requestLoggingSession = requestLoggerController.newRequestSession(request, SigningServiceType.S3);
try {
ParsedS3Request parsedS3Request = parseRequest(request);

Expand Down

0 comments on commit 3efbe80

Please sign in to comment.