Skip to content

Commit

Permalink
refactor(ApiGateway): Make HttpClient overridable in RequestInfo #NP-…
Browse files Browse the repository at this point in the history
…47966
  • Loading branch information
torbjokv committed Oct 29, 2024
1 parent 083fc45 commit 003b070
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.http.HttpClient;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -51,23 +52,16 @@ public abstract class ApiGatewayHandler<I, O> extends RestRequestHandler<I, O> {
public static final String ORIGIN_DELIMITER = ",";
public static final String FALLBACK_ORIGIN = "https://nva.sikt.no";

private final ObjectMapper objectMapper;

private Supplier<Map<String, String>> additionalSuccessHeadersSupplier;
private boolean isBase64Encoded;

public ApiGatewayHandler(Class<I> iclass) {
this(iclass, new Environment());
this(iclass, new Environment(), defaultRestObjectMapper, HttpClient.newBuilder().build());
}

public ApiGatewayHandler(Class<I> iclass, Environment environment) {
this(iclass, environment, defaultRestObjectMapper);
this.additionalSuccessHeadersSupplier = Collections::emptyMap;
}

public ApiGatewayHandler(Class<I> iclass, Environment environment, ObjectMapper objectMapper) {
super(iclass, environment);
this.objectMapper = objectMapper;
public ApiGatewayHandler(Class<I> iclass, Environment environment, ObjectMapper objectMapper, HttpClient httpClient) {
super(iclass, environment, objectMapper, httpClient);
this.additionalSuccessHeadersSupplier = Collections::emptyMap;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package nva.commons.apigateway;

import static nva.commons.apigateway.RestConfig.defaultRestObjectMapper;
import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.net.http.HttpClient;
import nva.commons.apigateway.exceptions.ApiGatewayException;
import nva.commons.core.Environment;
import nva.commons.core.JacocoGenerated;
Expand All @@ -18,17 +20,12 @@ public abstract class ApiGatewayProxyHandler<I, O> extends ApiGatewayHandler<I,

@JacocoGenerated
protected ApiGatewayProxyHandler(Class<I> iclass) {
this(iclass, new Environment());
this(iclass, new Environment(), defaultRestObjectMapper, HttpClient.newBuilder().build());
}

@JacocoGenerated
protected ApiGatewayProxyHandler(Class<I> iclass, Environment environment) {
super(iclass, environment);
}

@JacocoGenerated
protected ApiGatewayProxyHandler(Class<I> iclass, Environment environment, ObjectMapper objectMapper) {
super(iclass, environment, objectMapper);
protected ApiGatewayProxyHandler(Class<I> iclass, Environment environment, ObjectMapper objectMapper, HttpClient httpClient) {
super(iclass, environment, objectMapper, httpClient);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.net.http.HttpClient;
import java.time.Duration;
import nva.commons.apigateway.exceptions.BadRequestException;
import nva.commons.core.Environment;
Expand Down Expand Up @@ -29,8 +30,9 @@ public ApiS3GatewayHandler(Class<I> iclass,
S3Client s3client,
S3Presigner s3Presigner,
Environment environment,
ObjectMapper objectMapper) {
super(iclass, s3Presigner, environment, objectMapper);
ObjectMapper objectMapper,
HttpClient httpClient) {
super(iclass, s3Presigner, environment, objectMapper, httpClient);
this.s3client = s3client;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.net.HttpURLConnection;
import java.net.http.HttpClient;
import java.time.Duration;
import java.util.Map;
import nva.commons.apigateway.exceptions.ApiGatewayException;
Expand Down Expand Up @@ -31,8 +32,9 @@ public ApiS3PresignerGatewayHandler(Class<I> iclass, S3Presigner s3Presigner) {
public ApiS3PresignerGatewayHandler(Class<I> iclass,
S3Presigner s3Presigner,
Environment environment,
ObjectMapper objectMapper) {
super(iclass, environment, objectMapper);
ObjectMapper objectMapper,
HttpClient httpClient) {
super(iclass, environment, objectMapper, httpClient);
this.s3presigner = s3Presigner;
}

Expand Down
40 changes: 27 additions & 13 deletions apigateway/src/main/java/nva/commons/apigateway/RequestInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonPointer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.net.HttpHeaders;
import java.io.InputStream;
import java.net.URI;
Expand All @@ -55,14 +56,15 @@
import java.util.stream.Stream;
import no.unit.nva.auth.CognitoUserInfo;
import no.unit.nva.auth.FetchUserInfo;
import no.unit.nva.commons.json.JsonUtils;
import nva.commons.apigateway.exceptions.ApiIoException;
import nva.commons.apigateway.exceptions.BadRequestException;
import nva.commons.apigateway.exceptions.UnauthorizedException;
import nva.commons.core.JacocoGenerated;
import nva.commons.core.SingletonCollector;
import nva.commons.core.StringUtils;
import nva.commons.core.attempt.Failure;
import nva.commons.core.exceptions.ExceptionUtils;
import nva.commons.core.ioutils.IoUtils;
import nva.commons.core.paths.UriWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -71,9 +73,9 @@
public class RequestInfo {

public static final String ERROR_FETCHING_COGNITO_INFO = "Could not fetch user information from Cognito:{}";
private static final HttpClient DEFAULT_HTTP_CLIENT = HttpClient.newBuilder().build();
private static final Logger logger = LoggerFactory.getLogger(RequestInfo.class);
private final HttpClient httpClient;
private static final ObjectMapper mapper = defaultRestObjectMapper;
private HttpClient httpClient;
private final Supplier<URI> cognitoUri;
private final Supplier<URI> e2eTestsUserInfoUri;
@JsonProperty(HEADERS_FIELD)
Expand All @@ -100,29 +102,36 @@ public RequestInfo(HttpClient httpClient, Supplier<URI> cognitoUri, Supplier<URI
this.e2eTestsUserInfoUri = e2eTestsUserInfoUri;
}

public RequestInfo() {
private RequestInfo() {
this.headers = new HashMap<>();
this.pathParameters = new HashMap<>();
this.queryParameters = new HashMap<>();
this.multiValueQueryStringParameters = new HashMap<>();
this.otherProperties = new LinkedHashMap<>(); // ordinary HashMap and ConcurrentHashMap fail.
this.requestContext = defaultRestObjectMapper.createObjectNode();
this.httpClient = DEFAULT_HTTP_CLIENT;
this.httpClient = HttpClient.newBuilder().build();
this.cognitoUri = DEFAULT_COGNITO_URI;
this.e2eTestsUserInfoUri = RequestInfoConstants.E2E_TESTING_USER_INFO_ENDPOINT;
}

public static RequestInfo fromRequest(InputStream requestStream) {
return attempt(() -> JsonUtils.dtoObjectMapper.readValue(requestStream, RequestInfo.class)).orElseThrow();
public static RequestInfo fromRequest(InputStream requestStream, HttpClient httpClient) throws ApiIoException {
String inputString = IoUtils.streamToString(requestStream);
return fromString(inputString, httpClient);
}

public static RequestInfo fromString(String inputString, HttpClient httpClient) throws ApiIoException {
var requestInfo = new ApiMessageParser<>(mapper).getRequestInfo(inputString);
requestInfo.setHttpClient(httpClient);
return requestInfo;
}

@JsonIgnore
public String getHeader(String header) {
return getHeaders().entrySet().stream()
.filter(entry -> entry.getKey().equalsIgnoreCase(header))
.findFirst()
.map(Map.Entry::getValue)
.orElseThrow(() -> new IllegalArgumentException(MISSING_FROM_HEADERS + header));
.filter(entry -> entry.getKey().equalsIgnoreCase(header))
.findFirst()
.map(Map.Entry::getValue)
.orElseThrow(() -> new IllegalArgumentException(MISSING_FROM_HEADERS + header));
}

@JsonIgnore
Expand Down Expand Up @@ -199,6 +208,11 @@ public void setOtherProperties(Map<String, Object> otherProperties) {
this.otherProperties = otherProperties;
}

@JacocoGenerated
public void setHttpClient(HttpClient httpClient) {
this.httpClient = httpClient;
}

public Map<String, String> getHeaders() {
return headers;
}
Expand Down Expand Up @@ -340,7 +354,7 @@ public URI getPersonCristinId() throws UnauthorizedException {
@JsonIgnore
public URI getPersonAffiliation() throws UnauthorizedException {
return extractPersonAffiliationForTests().or(this::fetchPersonAffiliation)
.orElseThrow(UnauthorizedException::new);
.orElseThrow(UnauthorizedException::new);
}

@JsonIgnore
Expand Down Expand Up @@ -455,7 +469,7 @@ private List<AccessRightEntry> fetchAvailableRights() {
return userInfo
.map(CognitoUserInfo::getAccessRights)
.map(accessRightEntryStr -> AccessRightEntry.fromCsvForCustomer(accessRightEntryStr, userInfo.get()
.getCurrentCustomer()))
.getCurrentCustomer()))
.map(stream -> stream.collect(Collectors.toList()))
.orElseGet(Collections::emptyList);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import static nva.commons.core.exceptions.ExceptionUtils.stackTraceInSingleLine;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.net.HttpHeaders;
import com.google.common.net.MediaType;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.http.HttpClient;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -42,7 +44,9 @@ public abstract class RestRequestHandler<I, O> implements RequestStreamHandler {
protected final Environment environment;
private static final Logger logger = LoggerFactory.getLogger(RestRequestHandler.class);
private final transient Class<I> iclass;
private final transient ApiMessageParser<I> inputParser = new ApiMessageParser<>();
private final transient ApiMessageParser<I> inputParser;
protected final ObjectMapper objectMapper;
private final HttpClient httpClient;

protected transient OutputStream outputStream;
protected transient String allowedOrigin;
Expand Down Expand Up @@ -130,9 +134,13 @@ private MediaType defaultResponseContentTypeWhenNotSpecifiedByClientRequest() {
* @param iclass The class object of the input class.
* @param environment the Environment from where the handler will read ENV variables.
*/
public RestRequestHandler(Class<I> iclass, Environment environment) {
public RestRequestHandler(Class<I> iclass, Environment environment, ObjectMapper objectMapper,
HttpClient httpClient) {
this.iclass = iclass;
this.environment = environment;
this.inputParser = new ApiMessageParser<>(objectMapper);
this.objectMapper = objectMapper;
this.httpClient = httpClient;
}

@Override
Expand All @@ -145,7 +153,8 @@ public void handleRequest(InputStream inputStream, OutputStream outputStream, Co
inputObject = attempt(() -> parseInput(inputString))
.orElseThrow(this::parsingExceptionToBadRequestException);

RequestInfo requestInfo = inputParser.getRequestInfo(inputString);
RequestInfo requestInfo = RequestInfo.fromString(inputString, httpClient);
requestInfo.setHttpClient(httpClient);
setAllowedOrigin(requestInfo);

validateRequest(inputObject, requestInfo, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -89,6 +90,8 @@ class ApiGatewayHandlerTest {
private static final String PATH = "path1/path2/path3";
private Context context;
private Handler handler;
private HttpClient httpClient;
private Environment environment;

public static Stream<String> mediaTypeProvider() {
return Stream.of(MediaTypes.APPLICATION_JSON_LD.toString(), MediaTypes.APPLICATION_DATACITE_XML.toString(),
Expand All @@ -98,7 +101,10 @@ public static Stream<String> mediaTypeProvider() {
@BeforeEach
public void setup() {
context = new FakeContext();
handler = new Handler();
httpClient = mock(HttpClient.class);
environment = mock(Environment.class);
when(environment.readEnv("ALLOWED_ORIGIN")).thenReturn("*");
handler = new Handler(defaultRestObjectMapper, environment, httpClient);
}

@Test
Expand Down Expand Up @@ -394,7 +400,7 @@ void shouldReturnContentTypeMatchingSupportedMediaTypeWhenSupportedMediaTypeIsRe
@Test
void handlerSerializesBodyWithNonDefaultSerializationWhenDefaultSerializerIsOverridden() throws IOException {
ObjectMapper spiedMapper = spy(defaultRestObjectMapper);
var handler = new Handler(spiedMapper);
var handler = new Handler(spiedMapper, environment, httpClient);
var inputStream = requestWithHeaders();
var outputStream = outputStream();
handler.handleRequest(inputStream, outputStream, context);
Expand Down Expand Up @@ -425,7 +431,7 @@ void handlerSendsRedirectionWhenItReceivesARedirectException() throws IOExceptio

@Test
void shouldReturnJsonObjectWhenClientAsksForJson() throws Exception {
var handler = new RawStringResponseHandler(dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithHeaders();
var expected = objectMapper.convertValue(createBody(), RequestBody.class);
var outputStream = outputStream();
Expand All @@ -439,7 +445,7 @@ void shouldReturnJsonObjectWhenClientAsksForJson() throws Exception {

@Test
void shouldReturnXmlObjectWhenClientAsksForXml() throws Exception {
var handler = new RawStringResponseHandler(dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithAcceptXmlHeader();
var expected = objectMapper.convertValue(createBody(), RequestBody.class);
var outputStream = outputStream();
Expand All @@ -457,7 +463,7 @@ void shouldReturnAllOriginsWhenEnvironmentAllowsAllOrigins() throws IOException
var environment = mock(Environment.class);
when(environment.readEnv(ALLOWED_ORIGIN_ENV)).thenReturn("*");

var handler = new RawStringResponseHandler(environment, dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithHeaders();

var outputStream = outputStream();
Expand All @@ -474,7 +480,7 @@ void shouldReturnNvaFrontendProdWhenEnvironmentIsEmpty() throws IOException {
var environment = mock(Environment.class);
when(environment.readEnv(ALLOWED_ORIGIN_ENV)).thenReturn("");

var handler = new RawStringResponseHandler(environment, dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithHeaders();

var outputStream = outputStream();
Expand All @@ -492,7 +498,7 @@ void shouldEchoAllowedOriginWhenEnvironmentContainsOrigin() throws IOException {
var environment = mock(Environment.class);
when(environment.readEnv(ALLOWED_ORIGIN_ENV)).thenReturn("localhost, " + originInHeader + ", some-place-else");

var handler = new RawStringResponseHandler(environment, dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithHeaders();

var outputStream = outputStream();
Expand All @@ -511,7 +517,7 @@ void shouldReturnOneOfTheAllowedOriginsInEnvironmentWhenRequestOriginIsNotWhiteL
var environment = mock(Environment.class);
when(environment.readEnv(ALLOWED_ORIGIN_ENV)).thenReturn("localhost, some-place-else");

var handler = new RawStringResponseHandler(environment, dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithHeaders();

var outputStream = outputStream();
Expand All @@ -529,7 +535,7 @@ void shouldReturnFirstElementInAllowedOriginsListWhenOriginIsMissing() throws IO
var header2 = "https://example2.com";
var environment = mock(Environment.class);
when(environment.readEnv(ALLOWED_ORIGIN_ENV)).thenReturn(header1 + ", " + header2);
var handler = new RawStringResponseHandler(environment, dtoObjectMapper);
var handler = new RawStringResponseHandler(dtoObjectMapper, environment, httpClient);
var inputStream = requestWithMissingOriginHeader();
var outputStream = outputStream();
handler.handleRequest(inputStream, outputStream, context);
Expand Down
Loading

0 comments on commit 003b070

Please sign in to comment.