From 539d599a06a729d62dffd0d9d3aaef4dddc07df1 Mon Sep 17 00:00:00 2001 From: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Date: Thu, 21 Nov 2024 18:33:36 -0800 Subject: [PATCH] Fix to better handle lambda responses when they are empty or null or not a valid json (#5211) * Fix to better handle lambda responses when they are empty or null or not a valid json Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * UTs for strict mode response comparison Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * Additional UTs for strict mode and aggregate mode Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * removed unused method and better method naming Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * doExecute method testing Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * better exception message Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * Add IT for aggregate mode cases Signed-off-by: Srikanth Govindarajan * Testing with presence of tags Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * Add IT to test behaviour for different lambda responses Signed-off-by: Srikanth Govindarajan * removed unused imports Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> * fix checkstyle Signed-off-by: Srikanth Govindarajan * fix checkstyle Signed-off-by: Srikanth Govindarajan * Address comments Signed-off-by: Srikanth Govindarajan --------- Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Srikanth Govindarajan Co-authored-by: Srikanth Govindarajan --- data-prepper-plugins/aws-lambda/README.md | 29 ++- .../lambda/processor/LambdaProcessorIT.java | 104 +++++++++- .../lambda/processor/LambdaProcessor.java | 38 ++-- .../StrictResponseEventHandlingStrategy.java | 10 +- ...rictResponseModeNotRespectedException.java | 7 + .../lambda/processor/LambdaProcessorTest.java | 184 +++++++++++------- ...rictResponseEventHandlingStrategyTest.java | 16 +- .../lambda/utils/LambdaTestSetupUtil.java | 10 + ...mbda-processor-aggregate-mode-config.yaml} | 0 9 files changed, 288 insertions(+), 110 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/exception/StrictResponseModeNotRespectedException.java rename data-prepper-plugins/aws-lambda/src/test/resources/{lambda-processor-unequal-success-config.yaml => lambda-processor-aggregate-mode-config.yaml} (100%) diff --git a/data-prepper-plugins/aws-lambda/README.md b/data-prepper-plugins/aws-lambda/README.md index 099b390702..25806f3d61 100644 --- a/data-prepper-plugins/aws-lambda/README.md +++ b/data-prepper-plugins/aws-lambda/README.md @@ -49,8 +49,35 @@ The following command runs the integration tests: -Dtests.lambda.processor.region="us-east-1" \ -Dtests.lambda.processor.functionName="test-lambda-processor" \ -Dtests.lambda.processor.sts_role_arn="arn:aws:iam::<>:role/lambda-role" +``` - +Lambda handler used to test: +``` +def lambda_handler(event, context): + input_arr = event.get('osi_key', []) + output = [] + if len(input_arr) == 1: + input = input_arr[0] + if "returnNone" in input: + return + if "returnString" in input: + return "RandomString" + if "returnObject" in input: + return input_arr[0] + if "returnEmptyArray" in input: + return output + if "returnNull" in input: + return "null" + if "returnEmptyMapinArray" in input: + return [{}] + for input in input_arr: + input["_out_"] = "transformed"; + for k,v in input.items(): + if type(v) is str: + input[k] = v.upper() + output.append(input) + + return output ``` diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java index 5ea7115bbf..8203819fcb 100644 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -6,12 +6,24 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import org.mockito.Mock; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -50,16 +62,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) public class LambdaProcessorIT { @@ -95,6 +97,7 @@ public void setup() { lambdaRegion = System.getProperty("tests.lambda.processor.region"); functionName = System.getProperty("tests.lambda.processor.functionName"); role = System.getProperty("tests.lambda.processor.sts_role_arn"); + pluginMetrics = mock(PluginMetrics.class); pluginSetting = mock(PluginSetting.class); when(pluginSetting.getPipelineName()).thenReturn("pipeline"); @@ -232,6 +235,87 @@ public void testWithFailureTags() throws Exception { } } + @ParameterizedTest + @ValueSource(strings = {"returnNull", "returnEmptyArray", "returnString", "returnEmptyMapinArray", "returnNone"}) + public void testAggregateMode_WithVariousResponses(String input) { + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); // Aggregate mode + when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure")); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + List> records = createRecord(input); + + Collection> results = lambdaProcessor.doExecute(records); + + switch (input) { + case "returnNull": + case "returnEmptyArray": + case "returnString": + case "returnNone": + assertTrue(results.isEmpty(), "Events should be dropped for null, empty array, or string response"); + break; + case "returnEmptyMapinArray": + assertEquals(1, results.size(), "Should have one event in result for empty map in array"); + assertTrue(results.stream().allMatch(record -> record.getData().toMap().isEmpty()), + "Result should be an empty map"); + break; + default: + fail("Unexpected input: " + input); + } + } + + @ParameterizedTest + @ValueSource(strings = {"returnNone", "returnString", "returnObject", "returnEmptyArray", "returnNull", "returnEmptyMapinArray"}) + public void testStrictMode_WithVariousResponses(String input) { + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); // Strict mode + when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure")); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + List> records = createRecord(input); + + Collection> results = lambdaProcessor.doExecute(records); + + switch (input) { + case "returnNone": + case "returnString": + case "returnEmptyArray": + case "returnNull": + assertEquals(1, results.size(), "Should return original record with failure tag"); + assertTrue(results.iterator().next().getData().getMetadata().getTags().contains("lambda_failure"), + "Result should contain lambda_failure tag"); + break; + case "returnObject": + assertEquals(1, results.size(), "Should return one record"); + assertEquals(records.get(0).getData().toMap(), results.iterator().next().getData().toMap(), + "Returned record should match input record"); + break; + case "returnEmptyMapinArray": + assertEquals(1, results.size(), "Should return one record"); + assertTrue(results.iterator().next().getData().toMap().isEmpty(), + "Returned record should be an empty map"); + break; + } + } + + private List> createRecord(String input) { + List> records = new ArrayList<>(); + Map map = new HashMap<>(); + map.put(input, 42); + EventMetadata metadata = DefaultEventMetadata.builder() + .withEventType("event") + .build(); + final Event event = JacksonEvent.builder() + .withData(map) + .withEventType("event") + .withEventMetadata(metadata) + .build(); + records.add(new Record<>(event)); + + return records; + } + + private void validateResultsForAggregateMode(Collection> results) { List> resultRecords = new ArrayList<>(results); for (int i = 0; i < resultRecords.size(); i++) { diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 719a09eee6..786939f5a1 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -10,6 +10,7 @@ import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; @@ -25,6 +26,7 @@ import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; +import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; @@ -48,9 +50,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; -import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess; - @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { @@ -61,6 +60,8 @@ public class LambdaProcessor extends AbstractProcessor, Record, Record tagsOnFailure; private final LambdaAsyncClient lambdaAsyncClient; @@ -102,6 +104,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginSetting pl this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE); this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE); + this.lambdaResponseRecordsCounter = pluginMetrics.counter(LAMBDA_RESPONSE_RECORDS_COUNTER); this.whenCondition = lambdaProcessorConfig.getWhenCondition(); this.tagsOnFailure = lambdaProcessorConfig.getTagsOnFailure(); @@ -163,6 +166,8 @@ public Collection> doExecute(Collection> records) { new OutputCodecContext()); } catch (Exception e) { LOG.error(NOISY, "Error while sending records to Lambda", e); + numberOfRecordsFailedCounter.increment(recordsToLambda.size()); + numberOfRequestsFailedCounter.increment(); resultRecords.addAll(addFailureTags(recordsToLambda)); } @@ -211,28 +216,19 @@ List> convertLambdaResponseToEvent(Buffer flushedBuffer, List parsedEvents = new ArrayList<>(); SdkBytes payload = lambdaResponse.payload(); - // Handle null or empty payload - if (payload == null || payload.asByteArray().length == 0) { - LOG.warn(NOISY, - "Lambda response payload is null or empty, dropping the original events"); - return responseStrategy.handleEvents(parsedEvents, originalRecords); + // Considering "null" payload as empty response from lambda and not parsing it. + if (!(NO_RETURN_RESPONSE.equals(payload.asUtf8String()))) { + //Convert using response codec + InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); + responseCodec.parse(inputStream, record -> { + Event event = record.getData(); + parsedEvents.add(event); + }); } - - //Convert using response codec - InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); - responseCodec.parse(inputStream, record -> { - Event event = record.getData(); - parsedEvents.add(event); - }); - - if (parsedEvents.isEmpty()) { - throw new RuntimeException( - "Lambda Response could not be parsed, returning original events"); - } - LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize()); + lambdaResponseRecordsCounter.increment(parsedEvents.size()); return responseStrategy.handleEvents(parsedEvents, originalRecords); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java index 2744534c8d..7f804ecd96 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java @@ -8,6 +8,7 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; +import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException; import java.util.ArrayList; import java.util.List; @@ -19,8 +20,13 @@ public class StrictResponseEventHandlingStrategy implements ResponseEventHandlin public List> handleEvents(List parsedEvents, List> originalRecords) { if (parsedEvents.size() != originalRecords.size()) { - throw new RuntimeException( - "Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch."); + throw new StrictResponseModeNotRespectedException( + "Event count mismatch. The aws_lambda processor is configured with response_events_match set to true. " + + "The Lambda function responded with a different number of events. " + + "Either set response_events_match to false or investigate your " + + "Lambda function to ensure that it returns the same number of " + + "events and provided as input. parsedEvents size = " + parsedEvents.size() + + ", Original events size = " + originalRecords.size()); } List> resultRecords = new ArrayList<>(); diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/exception/StrictResponseModeNotRespectedException.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/exception/StrictResponseModeNotRespectedException.java new file mode 100644 index 0000000000..9e4d8972ad --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/exception/StrictResponseModeNotRespectedException.java @@ -0,0 +1,7 @@ +package org.opensearch.dataprepper.plugins.lambda.processor.exception; + +public class StrictResponseModeNotRespectedException extends RuntimeException { + public StrictResponseModeNotRespectedException(final String message) { + super(message); + } +} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index caed598787..032bea1530 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -5,16 +5,34 @@ package org.opensearch.dataprepper.plugins.lambda.processor; +import com.fasterxml.jackson.core.JsonParseException; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import static org.junit.jupiter.params.provider.Arguments.arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import org.mockito.Mock; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -36,12 +54,17 @@ import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleRecord; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import java.util.Arrays; @@ -52,26 +75,6 @@ import java.util.function.Consumer; import java.util.stream.Stream; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyDouble; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.createLambdaConfigurationFromYaml; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleRecord; @MockitoSettings(strictness = Strictness.LENIENT) public class LambdaProcessorTest { @@ -127,49 +130,21 @@ public class LambdaProcessorTest { private LambdaAsyncClient lambdaAsyncClient; - private static Stream getLambdaResponseConversionSamples() { - return Stream.of( - arguments("lambda-processor-success-config.yaml", null), - arguments("lambda-processor-success-config.yaml", SdkBytes.fromByteArray("{}".getBytes())), - arguments("lambda-processor-success-config.yaml", SdkBytes.fromByteArray("[]".getBytes())) - ); - } - @BeforeEach - public void setUp() throws Exception { + public void setUp() { MockitoAnnotations.openMocks(this); when(pluginSetting.getName()).thenReturn("testProcessor"); when(pluginSetting.getPipelineName()).thenReturn("testPipeline"); -/* - // Mock PluginMetrics counters and timers - when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS))).thenReturn( - numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED))).thenReturn( - numberOfRecordsFailedCounter); - when(pluginMetrics.counter(eq(NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA))).thenReturn( - numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(eq(NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA))).thenReturn( - numberOfRecordsFailedCounter); - when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); -*/ - // Mock AWS Authentication Options when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("testRole"); - // Mock BatchOptions and ThresholdOptions - // Mock PluginFactory to return the mocked responseCodec when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn( new JsonInputCodec(new JsonInputCodecConfig())); - // Instantiate the LambdaProcessor manually - - -// populatePrivateFields(); - //setPrivateField(lambdaProcessor, "pluginMetrics", pluginMetrics); // Mock InvokeResponse when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\":\"value\"}]")); when(invokeResponse.statusCode()).thenReturn(200); // Success status code @@ -179,9 +154,6 @@ public void setUp() throws Exception { invokeResponse); when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); - // Mock Response Codec parse method -// doNothing().when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); - } private void populatePrivateFields(LambdaProcessor lambdaProcessor) throws Exception { @@ -415,9 +387,6 @@ public void testDoExecute_SuccessfulProcessing(String configFileName) throws Exc @ValueSource(strings = {"lambda-processor-success-config.yaml"}) public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProcessing(String configFileName) throws Exception { - // Arrange - - // Mock LambdaResponse with a valid payload containing two events String payloadString = "[{\"key\":\"value1\"}, {\"key\":\"value2\"}]"; SdkBytes sdkBytes = SdkBytes.fromByteArray(payloadString.getBytes()); @@ -466,13 +435,9 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc } @ParameterizedTest - @ValueSource(strings = {"lambda-processor-unequal-success-config.yaml"}) + @ValueSource(strings = {"lambda-processor-aggregate-mode-config.yaml"}) public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulProcessing(String configFileName) throws Exception { - // Arrange - // Set responseEventsMatch to false - - // Mock LambdaResponse with a valid payload containing three events String payloadString = "[{\"key\":\"value1\"}, {\"key\":\"value2\"}, {\"key\":\"value3\"}]"; SdkBytes sdkBytes = SdkBytes.fromByteArray(payloadString.getBytes()); @@ -528,9 +493,35 @@ public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulPr assertEquals(3, resultRecords.size(), "ResultRecords should contain three records."); } + + private static Stream getLambdaResponseConversionSamplesForStrictAndAggregateMode() { + return Stream.of( + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), null, RuntimeException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), "null", StrictResponseModeNotRespectedException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), "random string", JsonParseException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("{}".getBytes()), StrictResponseModeNotRespectedException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[]".getBytes()), StrictResponseModeNotRespectedException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[{\"key\":\"val\"}]".getBytes()), null, 1), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[{\"key\":\"val\"}, {\"key\":\"val\"}]".getBytes()), StrictResponseModeNotRespectedException.class, 0), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(2), SdkBytes.fromByteArray("[{\"key\":\"val\"}, {\"key\":\"val\"}]".getBytes()), null, 2), + //Aggregate mode + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), null, RuntimeException.class, 0), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), "null", null, 0), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), "random string", JsonParseException.class, 0), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("{}".getBytes()), null, 0), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[]".getBytes()), null, 0), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(2), SdkBytes.fromByteArray("[{\"key\":\"val\"}]".getBytes()), null, 1) + + ); + } + @ParameterizedTest - @MethodSource("getLambdaResponseConversionSamples") - public void testConvertLambdaResponseToEvent_ExpectException_when_request_response_do_not_match(String configFile, SdkBytes lambdaReponse) { + @MethodSource("getLambdaResponseConversionSamplesForStrictAndAggregateMode") + public void testConvertLambdaResponseToEvent_for_strict_and_aggregate_mode(String configFile, + List> originalRecords, + SdkBytes lambdaResponse, + Class expectedException, + int expectedFinalRecordCount) throws IOException { // Arrange // Set responseEventsMatch to false LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFile); @@ -538,19 +529,78 @@ public void testConvertLambdaResponseToEvent_ExpectException_when_request_respon lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); InvokeResponse invokeResponse = mock(InvokeResponse.class); // Mock LambdaResponse with a valid payload containing three events - when(invokeResponse.payload()).thenReturn(lambdaReponse); + when(invokeResponse.payload()).thenReturn(lambdaResponse); when(invokeResponse.statusCode()).thenReturn(200); // Success status code - int randomCount = (int) (Math.random() * 10); - List> originalRecords = getSampleEventRecords(randomCount); Buffer buffer = new InMemoryBuffer(lambdaProcessorConfig.getBatchOptions().getKeyName()); for (Record originalRecord : originalRecords) { buffer.addRecord(originalRecord); } // Act - assertThrows(RuntimeException.class, () -> localLambdaProcessor.convertLambdaResponseToEvent(buffer, invokeResponse), - "For Strict mode request and response size from lambda should match"); + if (null != expectedException) { + assertThrows(expectedException, () -> localLambdaProcessor.convertLambdaResponseToEvent(buffer, invokeResponse)); + } else { + List> records = localLambdaProcessor.convertLambdaResponseToEvent(buffer, invokeResponse); + assertEquals(expectedFinalRecordCount, records.size(), String.format("Expected %s size of records", expectedFinalRecordCount)); + } + } + + private static Stream getDoExecuteSamplesForStrictAndAggregateMode() { + List> firstSample = getSampleEventRecords(1); + + List> secondSample = getSampleEventRecords(1); + List> thirdSample = getSampleEventRecords(1); + List> fourthSample = getSampleEventRecords(1); + List> fifthSample = getSampleEventRecords(1); + String fifthSampleJsonString = fifthSample.get(0).getData().toJsonString(); + fifthSampleJsonString = "[" + fifthSampleJsonString + "]"; + + return Stream.of( + arguments("lambda-processor-success-config.yaml", firstSample, null, firstSample, true), + arguments("lambda-processor-success-config.yaml", secondSample, "null", secondSample, true), + arguments("lambda-processor-success-config.yaml", thirdSample, "random string",thirdSample, true), + arguments("lambda-processor-success-config.yaml", fourthSample, SdkBytes.fromByteArray("[]".getBytes()), fourthSample, true), + arguments("lambda-processor-success-config.yaml", fifthSample, SdkBytes.fromByteArray(fifthSampleJsonString.getBytes()), fifthSample, false)/*, + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[{\"key\":\"val\"}, {\"key\":\"val\"}]".getBytes()),Collections.emptyList()), + arguments("lambda-processor-success-config.yaml", getSampleEventRecords(2), SdkBytes.fromByteArray("[{\"key\":\"val\"}, {\"key\":\"val\"}]".getBytes()), Collections.emptyList()), + //Aggregate mode + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), null, Collections.emptyList()), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), "null", Collections.emptyList()), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), "random string", Collections.emptyList()), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("{}".getBytes()), Collections.emptyList()), + arguments("lambda-processor-aggregate-mode-config.yaml", getSampleEventRecords(1), SdkBytes.fromByteArray("[]".getBytes()), Collections.emptyList()) +*/ + ); } + @ParameterizedTest + @MethodSource("getDoExecuteSamplesForStrictAndAggregateMode") + public void testDoExecute_for_strict_and_aggregate_mode(String configFile, + List> originalRecords, + SdkBytes lambdaResponse, + Collection> expectedRecords, + boolean validateTags) throws Exception { + // Arrange + // Set responseEventsMatch to false + LambdaProcessorConfig lambdaProcessorConfig = createLambdaConfigurationFromYaml(configFile); + LambdaProcessor localLambdaProcessor = new LambdaProcessor(pluginFactory, pluginSetting, + lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); + + InvokeResponse invokeResponse = mock(InvokeResponse.class); + LambdaAsyncClient lambdaAsyncClient = mock(LambdaAsyncClient.class); + setPrivateField(localLambdaProcessor, "lambdaAsyncClient", lambdaAsyncClient); + CompletableFuture invokeFuture = CompletableFuture.completedFuture(invokeResponse); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); + // Mock LambdaResponse with a valid payload containing three events + when(invokeResponse.payload()).thenReturn(lambdaResponse); + when(invokeResponse.statusCode()).thenReturn(200); // Success status code + + Collection> records = localLambdaProcessor.doExecute(originalRecords); + assertEquals(expectedRecords, records); + if (validateTags) { + Record record = records.iterator().next(); + assertEquals("[lambda_failure]", record.getData().getMetadata().getTags().toString()); + } + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java index 3962f264b9..6ebc8c71ba 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java @@ -5,19 +5,19 @@ package org.opensearch.dataprepper.plugins.lambda.processor; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; +import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleParsedEvents; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleEventRecords; -import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleParsedEvents; - public class StrictResponseEventHandlingStrategyTest { @@ -53,11 +53,9 @@ public void testHandleEvents_WithMismatchingEventCount_ShouldThrowException() { List> originalRecords = getSampleEventRecords(firstRandomCount + 10); // Act & Assert - RuntimeException exception = assertThrows(RuntimeException.class, () -> + RuntimeException exception = assertThrows(StrictResponseModeNotRespectedException.class, () -> strictResponseEventHandlingStrategy.handleEvents(parsedEvents, originalRecords) ); - - assertEquals("Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch.", exception.getMessage()); } @Test diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java index 1453c30e46..0bf9a2fd25 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/utils/LambdaTestSetupUtil.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventType; import org.opensearch.dataprepper.model.event.JacksonEvent; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.record.RecordMetadata; @@ -14,8 +15,10 @@ import java.io.IOException; import java.io.InputStream; +import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.UUID; public class LambdaTestSetupUtil { @@ -55,6 +58,13 @@ public static Event getSampleEvent() { return JacksonEvent.fromMessage(UUID.randomUUID().toString()); } + public static Event getSampleEventWithAttributes(String key, String value) { + return JacksonEvent.fromEvent(JacksonEvent.builder().withEventType(EventType.DOCUMENT.toString()) + .withTimeReceived(Instant.now()) + .withEventMetadataAttributes(Map.of(key, value)) + .build()); + } + public static List> getSampleEventRecords(int count) { List> originalRecords = new ArrayList<>(); for (int i = 0; i < count; i++) { diff --git a/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-unequal-success-config.yaml b/data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-aggregate-mode-config.yaml similarity index 100% rename from data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-unequal-success-config.yaml rename to data-prepper-plugins/aws-lambda/src/test/resources/lambda-processor-aggregate-mode-config.yaml