Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1084: decode body if base64 is enable #1085

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
package com.amazonaws.serverless.proxy.spring;

import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.nio.charset.UnsupportedCharsetException;
import java.util.Base64;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.commons.io.Charsets;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest;
import org.springframework.cloud.function.serverless.web.ServerlessMVC;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMapAdapter;
import org.springframework.util.StringUtils;

import com.amazonaws.serverless.proxy.AsyncInitializationWrapper;
import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter;
import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter;
import com.amazonaws.serverless.proxy.RequestReader;
Expand Down Expand Up @@ -120,10 +122,17 @@ private static HttpServletRequest generateRequest1(String request, Context lambd
MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders());
httpRequest.setHeaders(headers);
}
if (StringUtils.hasText(v1Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8));
}
if (StringUtils.hasText(v1Request.getBody())) {
deki marked this conversation as resolved.
Show resolved Hide resolved
if (v1Request.getHeaders().get(HttpHeaders.CONTENT_TYPE)==null) {
httpRequest.setContentType("application/json");
}
if (v1Request.isBase64Encoded()) {
httpRequest.setContent(Base64.getMimeDecoder().decode(v1Request.getBody()));
} else {
Charset charseEncoding = parseCharacterEncoding(v1Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
deki marked this conversation as resolved.
Show resolved Hide resolved
httpRequest.setContent(v1Request.getBody().getBytes(charseEncoding));
}
}
if (v1Request.getRequestContext() != null) {
httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext());
httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb());
Expand All @@ -149,11 +158,19 @@ private static HttpServletRequest generateRequest2(String request, Context lambd
populateQueryStringparameters(v2Request.getQueryStringParameters(), httpRequest);

v2Request.getHeaders().forEach(httpRequest::setHeader);

if (StringUtils.hasText(v2Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8));
}


if (StringUtils.hasText(v2Request.getBody())) {
deki marked this conversation as resolved.
Show resolved Hide resolved
if (v2Request.getHeaders().get(HttpHeaders.CONTENT_TYPE)==null) {
httpRequest.setContentType("application/json");
}
if (v2Request.isBase64Encoded()) {
httpRequest.setContent(Base64.getMimeDecoder().decode(v2Request.getBody()));
} else {
Charset charseEncoding = parseCharacterEncoding(v2Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
httpRequest.setContent(v2Request.getBody().getBytes(charseEncoding));
}
}
httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext());
httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables());
httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request);
Expand All @@ -180,4 +197,36 @@ private static <T> T readValue(String json, Class<T> clazz, ObjectMapper mapper)
}
}

static final String HEADER_KEY_VALUE_SEPARATOR = "=";
static final String HEADER_VALUE_SEPARATOR = ";";
static final String ENCODING_VALUE_KEY = "charset";
static protected Charset parseCharacterEncoding(String contentTypeHeader) {
deki marked this conversation as resolved.
Show resolved Hide resolved
// we only look at content-type because content-encoding should only be used for
// "binary" requests such as gzip/deflate.
Charset defaultCharset = StandardCharsets.UTF_8;
if (contentTypeHeader == null) {
return defaultCharset;
}

String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR);
if (contentTypeValues.length <= 1) {
return defaultCharset;
}

for (String contentTypeValue : contentTypeValues) {
if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) {
String[] encodingValues = contentTypeValue.split(HEADER_KEY_VALUE_SEPARATOR);
if (encodingValues.length <= 1) {
return defaultCharset;
}
try {
return Charsets.toCharset(encodingValues[1]);
} catch (UnsupportedCharsetException ex) {
return defaultCharset;
}
}
}
return defaultCharset;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.cloud.function.serverless.web.ServerlessServletContext;
import org.springframework.util.CollectionUtils;

import com.amazonaws.serverless.proxy.spring.servletapp.MessageData;
Expand Down Expand Up @@ -214,7 +210,7 @@ public static Collection<String> data() {
public void validateComplesrequest(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST",
"/foo/male/list/24", "{\"name\":\"bob\"}", null));
"/foo/male/list/24", "{\"name\":\"bob\"}", false,null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -229,7 +225,7 @@ public void validateComplesrequest(String jsonEvent) throws Exception {
@ParameterizedTest
public void testAsyncPost(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}", null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}",false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -242,7 +238,7 @@ public void testAsyncPost(String jsonEvent) throws Exception {
public void testValidate400(String jsonEvent) throws Exception {
initServletAppTest();
UserData ud = new UserData();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -258,27 +254,48 @@ public void testValidate200(String jsonEvent) throws Exception {
ud.setFirstName("bob");
ud.setLastName("smith");
ud.setEmail("[email protected]");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("VALID", result.get("body"));
}

@MethodSource("data")
@ParameterizedTest
public void testValidate200Base64(String jsonEvent) throws Exception {
initServletAppTest();
UserData ud = new UserData();
ud.setFirstName("bob");
ud.setLastName("smith");
ud.setEmail("[email protected]");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate",
Base64.getMimeEncoder().encodeToString(mapper.writeValueAsString(ud).getBytes()),true, null));

ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("VALID", result.get("body"));
}


@MethodSource("data")
@ParameterizedTest
public void messageObject_parsesObject_returnsCorrectMessage(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
mapper.writeValueAsString(new MessageData("test message")), null));
mapper.writeValueAsString(new MessageData("test message")),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("test message", result.get("body"));
}



@SuppressWarnings({"unchecked" })
@MethodSource("data")
@ParameterizedTest
Expand All @@ -289,43 +306,46 @@ void messageObject_propertiesInContentType_returnsCorrectMessage(String jsonEven
headers.put(HttpHeaders.CONTENT_TYPE, "application/json;v=1");
headers.put(HttpHeaders.ACCEPT, "application/json;v=1");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
mapper.writeValueAsString(new MessageData("test message")), headers));
mapper.writeValueAsString(new MessageData("test message")),false, headers));

ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals("test message", result.get("body"));
}

private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
Map requestMap = mapper.readValue(jsonEvent, Map.class);
if (requestMap.get("version").equals("2.0")) {
return generateHttpRequest2(requestMap, method, path, body, headers);
return generateHttpRequest2(requestMap, method, path, body, isBase64Encoded,headers);
}
return generateHttpRequest(requestMap, method, path, body, headers);
return generateHttpRequest(requestMap, method, path, body,isBase64Encoded, headers);
}

@SuppressWarnings({ "unchecked"})
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
requestMap.put("path", path);
requestMap.put("httpMethod", method);
requestMap.put("body", body);
requestMap.put("isBase64Encoded", isBase64Encoded);
if (!CollectionUtils.isEmpty(headers)) {
requestMap.put("headers", headers);
}
return mapper.writeValueAsBytes(requestMap);
}

@SuppressWarnings({ "unchecked"})
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
Map map = mapper.readValue(API_GATEWAY_EVENT_V2, Map.class);
Map http = (Map) ((Map) map.get("requestContext")).get("http");
http.put("path", path);
http.put("method", method);
map.put("body", body);
map.put("isBase64Encoded", isBase64Encoded);
if (!CollectionUtils.isEmpty(headers)) {
map.put("headers", headers);
}
System.out.println(map);
return mapper.writeValueAsBytes(map);
}
}
Loading