From 3f8538d3f4c85fa62dd1356325cf83408be0bc3e Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Thu, 14 Dec 2023 17:51:56 +0100 Subject: [PATCH] Add more tests and add back validation for some agents (#747) --- .../agents/grpc/AbstractGrpcAgent.java | 5 ++- .../agents/grpc/PythonGrpcAgentProcessor.java | 2 +- .../agents/grpc/PythonGrpcAgentSink.java | 2 +- .../agents/grpc/PythonGrpcAgentSource.java | 2 +- .../agents/grpc/PythonGrpcServer.java | 12 +++++- .../ai/agents/commons/jstl/JstlEvaluator.java | 3 ++ .../ai/agents/commons/jstl/JstlFunctions.java | 4 ++ .../oss/streaming/ai/ComputeStep.java | 5 +++ .../oss/streaming/ai/model/ComputeField.java | 2 +- .../streaming/ai/jstl/JstlEvaluatorTest.java | 12 ++++++ .../ai/GenAIToolKitFunctionAgentProvider.java | 10 +++++ .../steps/AIChatCompletionsConfiguration.java | 2 +- .../agents/ai/steps/ComputeConfiguration.java | 2 +- .../impl/uti/ClassConfigValidator.java | 18 +++++++- .../ai/langstream/tests/PythonAgentsIT.java | 6 ++- .../tests/util/BaseEndToEndTest.java | 3 +- .../apps/python-source/pipeline.yaml | 13 ++++-- .../apps/python-source/python/example.py | 2 +- .../runtime/impl/k8s/GenAIAgentsTest.java | 42 +++++++++++++++++++ ...GenAIToolKitFunctionAgentProviderTest.java | 4 +- 20 files changed, 133 insertions(+), 18 deletions(-) diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java index c21d5dcff..6b728c2fd 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java @@ -99,7 +99,10 @@ public void start() throws Exception { throw new IllegalStateException("Channel not initialized"); } blockingStub = - AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); + AgentServiceGrpc.newBlockingStub(channel) + .withMaxInboundMessageSize(Integer.MAX_VALUE) + .withMaxOutboundMessageSize(Integer.MAX_VALUE) + .withDeadlineAfter(30, TimeUnit.SECONDS); asyncStub = AgentServiceGrpc.newStub(channel) .withWaitForReady() diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java index da3d9e367..b207aabfe 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java @@ -49,7 +49,7 @@ public void start() throws Exception { public synchronized void close() throws Exception { super.close(); if (server != null) { - server.close(false); + server.close(true); } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java index 78c9bdb0d..6aebe6a1d 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java @@ -47,7 +47,7 @@ public void start() throws Exception { public synchronized void close() throws Exception { super.close(); if (server != null) { - server.close(false); + server.close(true); } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java index 54b03abb5..bfda10b5f 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java @@ -47,7 +47,7 @@ public void start() throws Exception { public synchronized void close() throws Exception { super.close(); if (server != null) { - server.close(false); + server.close(true); } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java index 3c86a9ed7..25ae0e0eb 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java @@ -90,7 +90,10 @@ public ManagedChannel start() throws Exception { ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build(); AgentServiceGrpc.AgentServiceBlockingStub stub = - AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); + AgentServiceGrpc.newBlockingStub(channel) + .withMaxInboundMessageSize(Integer.MAX_VALUE) + .withMaxOutboundMessageSize(Integer.MAX_VALUE) + .withDeadlineAfter(30, TimeUnit.SECONDS); for (int i = 0; ; i++) { try { stub.agentInfo(Empty.getDefaultInstance()); @@ -101,7 +104,12 @@ public ManagedChannel start() throws Exception { throw e; } log.info("Waiting for python agent to start"); - Thread.sleep(1000); + try { + Thread.sleep(1000); + } catch (InterruptedException interruptedException) { + log.info("Sleep interrupted"); + break; + } } } return channel; diff --git a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java index c8360288c..53dbf26e6 100644 --- a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java +++ b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlEvaluator.java @@ -45,6 +45,9 @@ public JstlEvaluator(String expression, Class type) { @SneakyThrows private void registerFunctions() { + this.expressionContext + .getFunctionMapper() + .mapFunction("fn", "length", JstlFunctions.class.getMethod("length", Object.class)); this.expressionContext .getFunctionMapper() .mapFunction("fn", "toJson", JstlFunctions.class.getMethod("toJson", Object.class)); diff --git a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java index 9d4055937..bddc4b780 100644 --- a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java +++ b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java @@ -187,6 +187,10 @@ public static Map emptyMap() { return Map.of(); } + public static long length(Object o) { + return o == null ? 0 : toString(o).length(); + } + public static Map mapOf(Object... field) { Map result = new HashMap<>(); for (int i = 0; i < field.length; i += 2) { diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java index f97c4d404..f6e6b3d2c 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeStep.java @@ -39,6 +39,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import lombok.Builder; +import lombok.extern.slf4j.Slf4j; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; @@ -46,6 +47,7 @@ /** Computes a field dynamically based on JSTL expressions and adds it to the key or the value . */ @Builder +@Slf4j public class ComputeStep implements TransformStep { public static final long MILLIS_PER_DAY = TimeUnit.DAYS.toMillis(1); @Builder.Default private final List fields = new ArrayList<>(); @@ -85,6 +87,9 @@ public void process(MutableRecord mutableRecord) { .filter(f -> "header.properties".equals(f.getScope())) .collect(Collectors.toList()), mutableRecord); + } catch (RuntimeException error) { + log.error("Error while computing fields on record {}", mutableRecord, error); + throw error; } } diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeField.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeField.java index ecfe60bf5..68773539c 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeField.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/ComputeField.java @@ -89,7 +89,7 @@ public ComputeField build() { this.evaluator = new JstlEvaluator<>(String.format("${%s}", this.expression), getJavaType()); } catch (ELException ex) { - throw new IllegalArgumentException("invalid expression: " + "expression", ex); + throw new IllegalArgumentException("invalid expression: " + expression, ex); } return new ComputeField(name, evaluator, type, scope, optional); } diff --git a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/jstl/JstlEvaluatorTest.java b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/jstl/JstlEvaluatorTest.java index 15f0f8c51..810566f57 100644 --- a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/jstl/JstlEvaluatorTest.java +++ b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/jstl/JstlEvaluatorTest.java @@ -65,6 +65,18 @@ void testPrimitiveValue() { assertEquals("test-message", value); } + @Test + void testLength() { + MutableRecord primitiveStringContext = + Utils.createContextWithPrimitiveRecord(Schema.STRING, "test-message", ""); + + String value = + new JstlEvaluator<>("${fn:length(value)}", String.class) + .evaluate(primitiveStringContext); + + assertEquals("12", value); + } + @Test void testNowFunction() { MutableRecord primitiveStringContext = diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java index 71f0fce2c..720509ed2 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java +++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java @@ -90,6 +90,16 @@ protected final ComponentType getComponentType(AgentConfiguration agentConfigura return ComponentType.PROCESSOR; } + @Override + protected Class getAgentConfigModelClass(String type) { + StepConfigurationInitializer stepConfigurationInitializer = STEP_TYPES.get(type); + log.info( + "Validating agent configuration model for type {} with {}", + type, + stepConfigurationInitializer.getAgentConfigurationModelClass()); + return stepConfigurationInitializer.getAgentConfigurationModelClass(); + } + public interface TopicConfigurationGenerator { void generateTopicConfiguration(String topicName); } diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/AIChatCompletionsConfiguration.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/AIChatCompletionsConfiguration.java index 050287593..ac4f51199 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/AIChatCompletionsConfiguration.java +++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/AIChatCompletionsConfiguration.java @@ -76,7 +76,7 @@ public static class ChatMessage { """ Role of the message. The role is used to identify the speaker in the chat. """, - required = true) + required = false) private String role; @ConfigProperty( diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeConfiguration.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeConfiguration.java index 1c6d2d69a..c891bb831 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeConfiguration.java +++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeConfiguration.java @@ -68,7 +68,7 @@ public static class Field { The type field is not required for the message headers [destinationTopic, messageKey, properties.] and STRING will be used. For the value and key, if it is not provided, then the type will be inferred from the result of the expression evaluation. """, - required = true) + required = false) private String type; @ConfigProperty( diff --git a/langstream-core/src/main/java/ai/langstream/impl/uti/ClassConfigValidator.java b/langstream-core/src/main/java/ai/langstream/impl/uti/ClassConfigValidator.java index 6f5f4dcfb..d3fb2052a 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/uti/ClassConfigValidator.java +++ b/langstream-core/src/main/java/ai/langstream/impl/uti/ClassConfigValidator.java @@ -314,6 +314,22 @@ private static void validateProperty( if (propertyValue.getExtendedValidationType() != null) { validateExtendedValidationType(propertyValue.getExtendedValidationType(), actualValue); } + + if (propertyValue.getItems() != null && actualValue != null) { + if (actualValue instanceof Collection collection) { + for (Object o : collection) { + validateProperty( + entityRef, fullPropertyKey, o, propertyValue.getItems(), propertyKey); + } + } else { + validateProperty( + entityRef, + fullPropertyKey, + actualValue, + propertyValue.getItems(), + propertyKey); + } + } } @Data @@ -531,7 +547,7 @@ private static void validateExtendedValidationType( case EL_EXPRESSION -> { if (actualValue instanceof String expression) { log.info("Validating EL expression: {}", expression); - new JstlEvaluator(actualValue.toString(), Object.class); + new JstlEvaluator("${" + actualValue + "}", Object.class); } else if (actualValue instanceof Collection collection) { log.info("Validating EL expressions {}", collection); for (Object o : collection) { diff --git a/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java index bfb324b48..422906ddb 100644 --- a/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java +++ b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java @@ -115,8 +115,12 @@ public void testSource() { .formatted(applicationId) .split(" ")); log.info("Output: {}", output); + String bigPayload = "test".repeat(10000); + String value = "the length is " + bigPayload.length(); Assertions.assertTrue( - output.contains("{\"record\":{\"key\":null,\"value\":\"test\",\"headers\":{}}")); + output.contains( + "{\"record\":{\"key\":null,\"value\":\"" + value + "\",\"headers\":{}}"), + "Output doesn't contain the expected payload: " + output); deleteAppAndAwaitCleanup(tenant, applicationId); } diff --git a/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java b/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java index 7a65c1d5b..1b299b1bd 100644 --- a/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java +++ b/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java @@ -1235,7 +1235,8 @@ protected static String deployLocalApplication( final String command = "bin/langstream apps %s %s -app /tmp/app -i /tmp/instance.yaml -s /tmp/secrets.yaml" .formatted(isUpdate ? "update" : "deploy", applicationId); - executeCommandOnClient((beforeCmd + command).split(" ")); + String logs = executeCommandOnClient((beforeCmd + command).split(" ")); + log.info("Logs after deploy: {}", logs); return podUids; } diff --git a/langstream-e2e-tests/src/test/resources/apps/python-source/pipeline.yaml b/langstream-e2e-tests/src/test/resources/apps/python-source/pipeline.yaml index 6e8b69411..21f5a6200 100644 --- a/langstream-e2e-tests/src/test/resources/apps/python-source/pipeline.yaml +++ b/langstream-e2e-tests/src/test/resources/apps/python-source/pipeline.yaml @@ -22,12 +22,19 @@ topics: creation-mode: create-if-not-exists schema: type: string +resources: + size: 2 pipeline: - name: "Source using Python" - resources: - size: 2 id: "test-python-source" type: "python-source" + configuration: + className: example.TestSource + - name: "Compute length (because we cannot write a big message to Kafka)" + id: "compute-length" + type: "compute" output: ls-test-output configuration: - className: example.TestSource \ No newline at end of file + fields: + - name: "value" + expression: "fn:concat('the length is ', fn:length(value))" diff --git a/langstream-e2e-tests/src/test/resources/apps/python-source/python/example.py b/langstream-e2e-tests/src/test/resources/apps/python-source/python/example.py index c6f20a98b..ef4e63292 100644 --- a/langstream-e2e-tests/src/test/resources/apps/python-source/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/python-source/python/example.py @@ -26,7 +26,7 @@ def read(self): if not self.sent: logging.info("Sending the record") self.sent = True - return [SimpleRecord("test")] + return [SimpleRecord("test" * 10000)] return [] def commit(self, records): diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/GenAIAgentsTest.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/GenAIAgentsTest.java index 7bc72f5b2..8cd2b2949 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/GenAIAgentsTest.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/GenAIAgentsTest.java @@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import ai.langstream.api.model.Application; @@ -841,4 +842,45 @@ public void testForceAiService() throws Exception { assertNull(configuration.get("service")); } } + + @Test + public void testValidateBadComputeStep() throws Exception { + Application applicationInstance = + ModelBuilder.buildApplicationInstance( + Map.of( + "module.yaml", + """ + module: "module-1" + id: "pipeline-1" + topics: + - name: "input-topic" + creation-mode: create-if-not-exists + pipeline: + - name: "compute" + id: "step1" + type: "compute" + input: "input-topic" + configuration: + fields: + - name: value + expression: "fn:concat('something', fn:len(value))" + """), + buildInstanceYaml(), + null) + .getApplication(); + + try (ApplicationDeployer deployer = + ApplicationDeployer.builder() + .registry(new ClusterRuntimeRegistry()) + .pluginsRegistry(new PluginsRegistry()) + .build()) { + Exception e = + assertThrows( + Exception.class, + () -> { + deployer.createImplementation("app", applicationInstance); + }); + assertEquals("Function [fn:len] not found", e.getMessage()); + } + } } diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java index d12b9f8d9..39655a829 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java @@ -187,7 +187,7 @@ public void testDocumentation() { }, "role" : { "description" : "Role of the message. The role is used to identify the speaker in the chat.", - "required" : true, + "required" : false, "type" : "string" } } @@ -456,7 +456,7 @@ public void testDocumentation() { }, "type" : { "description" : "The type of the computed field. This\\n will translate to the schema type of the new field in the transformed message.\\n The following types are currently supported :STRING, INT8, INT16, INT32, INT64, FLOAT, DOUBLE, BOOLEAN, DATE, TIME, TIMESTAMP, LOCAL_DATE_TIME, LOCAL_TIME, LOCAL_DATE, INSTANT.\\n The type field is not required for the message headers [destinationTopic, messageKey, properties.] and STRING will be used.\\n For the value and key, if it is not provided, then the type will be inferred from the result of the expression evaluation.", - "required" : true, + "required" : false, "type" : "string" } }