Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pritham Marupaka committed Jan 7, 2025
1 parent d34f615 commit c6dfe0c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.palantir.dialogue.Response;

// TODO(pm): use the new EndpointErrorDecoder
public final class ConjureErrorDecoder implements ErrorDecoder {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* items:
* - we don't want to use `String` for the error identifier. Let's create an `ErrorName` class.
* - re-consider using a map for the deserializersForEndpointBaseType field. is there a more direct way to get this info
*/

/** Package private internal API. */
Expand All @@ -65,7 +65,7 @@ final class ConjureBodySerDe implements BodySerDe {
private final Deserializer<Optional<InputStream>> optionalBinaryInputStreamDeserializer;
private final Deserializer<Void> emptyBodyDeserializer;
private final LoadingCache<Type, Serializer<?>> serializers;
private final LoadingCache<Type, EncodingDeserializerRegistry<?>> deserializers;
private final LoadingCache<Type, EncodingDeserializerForEndpointRegistry<?>> deserializers;
private final EmptyContainerDeserializer emptyContainerDeserializer;

/**
Expand All @@ -75,32 +75,49 @@ final class ConjureBodySerDe implements BodySerDe {
*/
ConjureBodySerDe(
List<WeightedEncoding> rawEncodings,
ErrorDecoder errorDecoder,
ErrorDecoder _errorDecoder,
EmptyContainerDeserializer emptyContainerDeserializer,
CaffeineSpec cacheSpec) {
List<WeightedEncoding> encodings = decorateEncodings(rawEncodings);
this.encodingsSortedByWeight = sortByWeight(encodings);
Preconditions.checkArgument(encodings.size() > 0, "At least one Encoding is required");
// note(pm): why do the weighted encoding thing? can we just pass in the default encoding?
this.defaultEncoding = encodings.get(0).encoding();
this.emptyContainerDeserializer = emptyContainerDeserializer;
this.binaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
this.binaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
emptyContainerDeserializer,
BinaryEncoding.MARKER);
this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerRegistry<>(
BinaryEncoding.MARKER,
DeserializerArgs.<InputStream>builder()
.withBaseType(BinaryEncoding.MARKER)
.withExpectedResult(BinaryEncoding.MARKER)
.build());
this.optionalBinaryInputStreamDeserializer = new EncodingDeserializerForEndpointRegistry<>(
ImmutableList.of(BinaryEncoding.INSTANCE),
errorDecoder,
emptyContainerDeserializer,
BinaryEncoding.OPTIONAL_MARKER);
this.emptyBodyDeserializer = new EmptyBodyDeserializer(errorDecoder);
BinaryEncoding.OPTIONAL_MARKER,
DeserializerArgs.<Optional<InputStream>>builder()
.withBaseType(BinaryEncoding.OPTIONAL_MARKER)
.withExpectedResult(BinaryEncoding.OPTIONAL_MARKER)
.build());
this.emptyBodyDeserializer =
new EmptyBodyDeserializer(new EndpointErrorDecoder<>(Map.of(), encodingsSortedByWeight));
// Class unloading: Not supported, Jackson keeps strong references to the types
// it sees: https://github.com/FasterXML/jackson-databind/issues/489
this.serializers = Caffeine.from(cacheSpec)
.build(type -> new EncodingSerializerRegistry<>(defaultEncoding, TypeMarker.of(type)));
this.deserializers = Caffeine.from(cacheSpec)
.build(type -> new EncodingDeserializerRegistry<>(
encodingsSortedByWeight, errorDecoder, emptyContainerDeserializer, TypeMarker.of(type)));
this.deserializers = Caffeine.from(cacheSpec).build(type -> buildCacheEntry(TypeMarker.of(type)));
}

private <T> EncodingDeserializerForEndpointRegistry<?> buildCacheEntry(TypeMarker<T> typeMarker) {
return new EncodingDeserializerForEndpointRegistry<>(
encodingsSortedByWeight,
emptyContainerDeserializer,
typeMarker,
DeserializerArgs.<T>builder()
.withBaseType(typeMarker)
.withExpectedResult(typeMarker)
.build());
}

private static List<WeightedEncoding> decorateEncodings(List<WeightedEncoding> input) {
Expand Down Expand Up @@ -235,108 +252,7 @@ private static final class EncodingSerializerContainer<T> {
}
}

private static final class EncodingDeserializerRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<T>> encodings;
private final ErrorDecoder errorDecoder;
private final Optional<String> acceptValue;
private final Supplier<Optional<T>> emptyInstance;
private final TypeMarker<T> token;

EncodingDeserializerRegistry(
List<Encoding> encodings,
ErrorDecoder errorDecoder,
EmptyContainerDeserializer empty,
TypeMarker<T> token) {
this.encodings = encodings.stream()
.map(encoding -> new EncodingDeserializerContainer<>(encoding, token))
.collect(ImmutableList.toImmutableList());
this.errorDecoder = errorDecoder;
this.token = token;
this.emptyInstance = Suppliers.memoize(() -> empty.tryGetEmptyInstance(token));
// Encodings are applied to the accept header in the order of preference based on the provided list.
this.acceptValue =
Optional.of(encodings.stream().map(Encoding::getContentType).collect(Collectors.joining(", ")));
}

@Override
public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (errorDecoder.isError(response)) {
throw errorDecoder.decode(response);
} else if (response.code() == 204) {
// TODO(dfox): what if we get a 204 for a non-optional type???
// TODO(dfox): support http200 & body=null
// TODO(dfox): what if we were expecting an empty list but got {}?
Optional<T> maybeEmptyInstance = emptyInstance.get();
if (maybeEmptyInstance.isPresent()) {
return maybeEmptyInstance.get();
}
throw new SafeRuntimeException(
"Unable to deserialize non-optional response type from 204", SafeArg.of("type", token));
}

Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
if (!contentType.isPresent()) {
throw new SafeIllegalArgumentException(
"Response is missing Content-Type header",
SafeArg.of("received", response.headers().keySet()));
}
Encoding.Deserializer<T> deserializer = getResponseDeserializer(contentType.get());
T deserialized = deserializer.deserialize(response.body());
// deserializer has taken on responsibility for closing the response body
closeResponse = false;
return deserialized;
} catch (IOException e) {
throw new SafeRuntimeException(
"Failed to deserialize response stream",
e,
SafeArg.of("contentType", response.getFirstHeader(HttpHeaders.CONTENT_TYPE)),
SafeArg.of("type", token));
} finally {
if (closeResponse) {
response.close();
}
}
}

@Override
public Optional<String> accepts() {
return acceptValue;
}

/** Returns the {@link EncodingDeserializerContainer} to use to deserialize the request body. */
@SuppressWarnings("ForLoopReplaceableByForEach")
// performance sensitive code avoids iterator allocation
Encoding.Deserializer<T> getResponseDeserializer(String contentType) {
for (int i = 0; i < encodings.size(); i++) {
EncodingDeserializerContainer<T> container = encodings.get(i);
if (container.encoding.supportsContentType(contentType)) {
return container.deserializer;
}
}
return throwingDeserializer(contentType);
}

private Encoding.Deserializer<T> throwingDeserializer(String contentType) {
return input -> {
try {
input.close();
} catch (RuntimeException | IOException e) {
log.warn("Failed to close InputStream", e);
}
throw new SafeRuntimeException(
"Unsupported Content-Type",
SafeArg.of("received", contentType),
SafeArg.of("supportedEncodings", encodings));
};
}
}

private static final class EncodingDeserializerForEndpointRegistry<T> implements Deserializer<T> {

private static final SafeLogger log = SafeLoggerFactory.get(EncodingDeserializerForEndpointRegistry.class);
private final ImmutableList<EncodingDeserializerContainer<? extends T>> encodings;
private final EndpointErrorDecoder<T> endpointErrorDecoder;
Expand Down Expand Up @@ -367,7 +283,6 @@ public T deserialize(Response response) {
boolean closeResponse = true;
try {
if (endpointErrorDecoder.isError(response)) {
// TODO(pm): This needs to return T for the new deserializer API, but throw an exception for the old
return endpointErrorDecoder.decode(response);
} else if (response.code() == 204) {
Optional<T> maybeEmptyInstance = emptyInstance.get();
Expand Down Expand Up @@ -457,9 +372,9 @@ public String toString() {
}

private static final class EmptyBodyDeserializer implements Deserializer<Void> {
private final ErrorDecoder errorDecoder;
private final EndpointErrorDecoder<?> errorDecoder;

EmptyBodyDeserializer(ErrorDecoder errorDecoder) {
EmptyBodyDeserializer(EndpointErrorDecoder<?> errorDecoder) {
this.errorDecoder = errorDecoder;
}

Expand All @@ -469,7 +384,7 @@ public Void deserialize(Response response) {
// We should not fail if a server that previously returned nothing starts returning a response
try (Response unused = response) {
if (errorDecoder.isError(response)) {
throw errorDecoder.decode(response);
errorDecoder.decode(response);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,32 @@ private T decodeInternal(Response response) {
}

Optional<String> contentType = response.getFirstHeader(HttpHeaders.CONTENT_TYPE);
// Use a factory: given contentType, create the deserailizer.
// Use a factory: given contentType, create the deserializer.
// We need Encoding.Deserializer here. That depends on the encoding.
if (contentType.isPresent() && Encodings.matchesContentType("application/json", contentType.get())) {
String jsonContentType = "application/json";
if (contentType.isPresent() && Encodings.matchesContentType(jsonContentType, contentType.get())) {
try {
JsonNode node = MAPPER.readTree(body);
if (node.get("errorName") != null) {
// TODO(pm): Update this to use some struct instead of errorName.
TypeMarker<? extends T> container = Optional.ofNullable(
errorNameToTypeMap.get(node.get("errorName").asText()))
.orElseThrow();
for (int i = 0; i < encodings.size(); i++) {
Encoding encoding = encodings.get(i);
if (encoding.supportsContentType(contentType.get())) {
return encoding.deserializer(container)
.deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)));
}
if (node.get("errorName") == null) {
throwSerializableError(body, code);
}
// TODO(pm): Update this to use some struct instead of errorName.
Optional<TypeMarker<? extends T>> maybeContainer = Optional.ofNullable(
errorNameToTypeMap.get(node.get("errorName").asText()));
if (maybeContainer.isEmpty()) {
// This thrown exception will be caught below. Refactor.
throwSerializableError(body, code);
}
for (int i = 0; i < encodings.size(); i++) {
Encoding encoding = encodings.get(i);
if (encoding.supportsContentType(jsonContentType)) {
return encoding.deserializer(maybeContainer.get())
.deserialize(new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)));
}
} else {
SerializableError serializableError = MAPPER.readValue(body, SerializableError.class);
throw new RemoteException(serializableError, code);
}
} catch (RemoteException remoteException) {
// rethrow the created remote exception
throw remoteException;
} catch (Exception e) {
throw new UnknownRemoteException(code, body);
}
Expand All @@ -156,6 +161,11 @@ private T decodeInternal(Response response) {
throw new UnknownRemoteException(code, body);
}

private static void throwSerializableError(String body, int code) throws IOException {
SerializableError serializableError = MAPPER.readValue(body, SerializableError.class);
throw new RemoteException(serializableError, code);
}

private static String toString(InputStream body) throws IOException {
try (Reader reader = new InputStreamReader(body, StandardCharsets.UTF_8)) {
return CharStreams.toString(reader);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.palantir.conjure.java.api.errors.ErrorType;
import com.palantir.conjure.java.api.errors.RemoteException;
import com.palantir.conjure.java.api.errors.SerializableError;
import com.palantir.conjure.java.api.errors.ServiceException;
import com.palantir.conjure.java.serialization.ObjectMappers;
import com.palantir.dialogue.BinaryRequestBody;
import com.palantir.dialogue.BodySerDe;
import com.palantir.dialogue.RequestBody;
Expand All @@ -47,6 +50,7 @@
@ExtendWith(MockitoExtension.class)
public class ConjureBodySerDeTest {

private static final ObjectMapper SERVER_MAPPER = ObjectMappers.newServerObjectMapper();
private static final TypeMarker<String> TYPE = new TypeMarker<String>() {};
private static final TypeMarker<Optional<String>> OPTIONAL_TYPE = new TypeMarker<Optional<String>>() {};

Expand Down Expand Up @@ -137,14 +141,12 @@ public void testRequestUnknownContentType() throws IOException {
}

@Test
public void testErrorsDecoded() {
TestResponse response = new TestResponse().code(400);

public void testErrorsDecoded() throws JsonProcessingException {
ServiceException serviceException = new ServiceException(ErrorType.INVALID_ARGUMENT);
SerializableError serialized = SerializableError.forException(serviceException);
errorDecoder = mock(ErrorDecoder.class);
when(errorDecoder.isError(response)).thenReturn(true);
when(errorDecoder.decode(response)).thenReturn(new RemoteException(serialized, 400));
TestResponse response = TestResponse.withBody(
SERVER_MAPPER.writeValueAsString(SerializableError.forException(serviceException)))
.code(400)
.contentType("application/json");

BodySerDe serializers = conjureBodySerDe("text/plain");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ public void testDeserializeCustomErrors() throws IOException {
EndpointErrorsConjureBodySerDeTest.EndpointReturnBaseType value =
serializers.deserializer(deserializerArgs).deserialize(response);

assertThat(value).isInstanceOf(ErrorForEndpoint.class);
assertThat(value)
.extracting("errorCode", "errorName", "errorInstanceId", "args")
.containsExactly(
Expand Down

0 comments on commit c6dfe0c

Please sign in to comment.