From bd798986f6cf047a848943a96c09678b3f252554 Mon Sep 17 00:00:00 2001 From: Jordan Zimmerman Date: Wed, 26 Jun 2024 08:01:09 +0100 Subject: [PATCH] Improve exceptions/resumption in StreamingResponseHandler --- .../server/rest/StreamingResponseHandler.java | 36 ++++++++++++++++--- .../proxy/server/rest/TrinoS3ProxyClient.java | 7 ++-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/StreamingResponseHandler.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/StreamingResponseHandler.java index ba969688..c207fd8f 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/StreamingResponseHandler.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/StreamingResponseHandler.java @@ -13,16 +13,21 @@ */ package io.trino.s3.proxy.server.rest; +import com.google.common.base.Throwables; import io.airlift.http.client.HeaderName; +import io.airlift.http.client.HttpStatus; import io.airlift.http.client.Request; import io.airlift.http.client.Response; import io.airlift.http.client.ResponseHandler; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.container.AsyncResponse; import jakarta.ws.rs.core.StreamingOutput; import java.io.InputStream; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; -import static io.airlift.http.client.ResponseHandlerUtils.propagate; +import static jakarta.ws.rs.core.Response.Status.INTERNAL_SERVER_ERROR; import static java.util.Objects.requireNonNull; class StreamingResponseHandler @@ -30,6 +35,7 @@ class StreamingResponseHandler { private final AsyncResponse asyncResponse; private final RequestLoggingSession requestLoggingSession; + private final AtomicBoolean hasBeenResumed = new AtomicBoolean(false); StreamingResponseHandler(AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession) { @@ -44,7 +50,8 @@ public Void handleException(Request request, Exception exception) requestLoggingSession.logException(exception); requestLoggingSession.close(); - throw propagate(request, exception); + resume(exception); + return null; } @Override @@ -62,7 +69,10 @@ public Void handle(Request request, Response response) }; try (requestLoggingSession) { - jakarta.ws.rs.core.Response.ResponseBuilder responseBuilder = jakarta.ws.rs.core.Response.status(response.getStatusCode()).entity(streamingOutput); + jakarta.ws.rs.core.Response.ResponseBuilder responseBuilder = jakarta.ws.rs.core.Response.status(response.getStatusCode()); + if (HttpStatus.familyForStatusCode(response.getStatusCode()) == HttpStatus.Family.SUCCESSFUL) { + responseBuilder.entity(streamingOutput); + } response.getHeaders() .keySet() .stream() @@ -74,9 +84,27 @@ public Void handle(Request request, Response response) // this will block until StreamingOutput completes - asyncResponse.resume(responseBuilder.build()); + resume(responseBuilder.build()); } return null; } + + @SuppressWarnings("ThrowableNotThrown") + private void resume(Object result) + { + switch (result) { + case WebApplicationException exception -> resume(exception.getResponse()); + case Throwable exception when Throwables.getRootCause(exception) instanceof WebApplicationException webApplicationException -> resume(webApplicationException.getResponse()); + case Throwable exception -> resume(jakarta.ws.rs.core.Response.status(INTERNAL_SERVER_ERROR.getStatusCode(), Optional.ofNullable(exception.getMessage()).orElse("Unknown error")).build()); + default -> { + if (hasBeenResumed.compareAndSet(false, true)) { + asyncResponse.resume(result); + } + else { + throw new WebApplicationException("Could not resume with response: " + result, INTERNAL_SERVER_ERROR); + } + } + } + } } diff --git a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java index 4ba4ab94..4daff262 100644 --- a/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-s3-proxy/src/main/java/io/trino/s3/proxy/server/rest/TrinoS3ProxyClient.java @@ -154,13 +154,12 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques Request remoteRequest = remoteRequestBuilder.build(); executorService.submit(() -> { + StreamingResponseHandler responseHandler = new StreamingResponseHandler(asyncResponse, requestLoggingSession); try { - httpClient.execute(remoteRequest, new StreamingResponseHandler(asyncResponse, requestLoggingSession)); + httpClient.execute(remoteRequest, responseHandler); } catch (Throwable e) { - requestLoggingSession.logException(e); - requestLoggingSession.close(); - asyncResponse.resume(e); + responseHandler.handleException(remoteRequest, new RuntimeException(e)); } }); }