diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index b8f1948c3..be5b6b689 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -54,6 +54,8 @@ public class ConsumeGateway implements AutoCloseable { private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; private final ClusterRuntimeRegistry clusterRuntimeRegistry; + private volatile TopicConnectionsRuntime topicConnectionsRuntime; + private volatile TopicReader reader; private volatile boolean interrupted; private volatile String logRef; @@ -84,7 +86,7 @@ public void setup( final StreamingCluster streamingCluster = requestContext.application().getInstance().streamingCluster(); - final TopicConnectionsRuntime topicConnectionsRuntime = + topicConnectionsRuntime = topicConnectionsRuntimeRegistry .getTopicConnectionsRuntime(streamingCluster) .asTopicConnectionsRuntime(); @@ -203,6 +205,13 @@ private void closeReader() { log.warn("error closing reader", e); } } + if (topicConnectionsRuntime != null) { + try { + topicConnectionsRuntime.close(); + } catch (Exception e) { + log.warn("error closing runtime", e); + } + } } @Override diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index 4fba7d0ad..70ded0363 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -19,6 +19,7 @@ import ai.langstream.api.model.StreamingCluster; import ai.langstream.api.model.TopicDefinition; import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; @@ -37,7 +38,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; @@ -143,6 +146,43 @@ public void start( key, () -> setupProducer(resolvedTopicName, streamingCluster)); } + @AllArgsConstructor + static class TopicProducerAndRuntime implements TopicProducer { + private TopicProducer producer; + private TopicConnectionsRuntime runtime; + + @Override + public void start() { + producer.start(); + } + + @Override + public void close() { + producer.close(); + runtime.close(); + } + + @Override + public CompletableFuture write(Record record) { + return producer.write(record); + } + + @Override + public Object getNativeProducer() { + return producer.getNativeProducer(); + } + + @Override + public Object getInfo() { + return producer.getInfo(); + } + + @Override + public long getTotalIn() { + return producer.getTotalIn(); + } + } + protected TopicProducer setupProducer(String topic, StreamingCluster streamingCluster) { final TopicConnectionsRuntime topicConnectionsRuntime = @@ -157,7 +197,7 @@ protected TopicProducer setupProducer(String topic, StreamingCluster streamingCl null, streamingCluster, Map.of("topic", topic)); topicProducer.start(); log.debug("[{}] Started producer on topic {}", logRef, topic); - return topicProducer; + return new TopicProducerAndRuntime(topicProducer, topicConnectionsRuntime); } public void produceMessage(String payload) throws ProduceException { diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index 8e7db231b..7d1e83344 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -304,7 +304,6 @@ void testSimpleProduce() throws Exception { final String url = "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" .formatted(port); - produceJsonAndExpectOk(url, "{\"key\": \"my-key\", \"value\": \"my-value\"}"); produceJsonAndExpectOk(url, "{\"key\": \"my-key\"}"); produceJsonAndExpectOk(url, "{\"key\": \"my-key\", \"headers\": {\"h1\": \"v1\"}}"); @@ -569,6 +568,27 @@ void testService() throws Exception { produceJsonAndGetBody( url, "{\"key\": \"my-key2\", \"value\": \"my-value\", \"headers\": {\"header1\":\"value1\"}}")); + + // sorry but kafka can't keep up + final int numParallel = getStreamingCluster().type().equals("kafka") ? 5 : 30; + + List> futures1 = new ArrayList<>(); + for (int i = 0; i < numParallel; i++) { + CompletableFuture future = + CompletableFuture.runAsync( + () -> { + for (int j = 0; j < 10; j++) { + assertMessageContent( + new MsgRecord("my-key", "my-value", Map.of()), + produceJsonAndGetBody( + url, + "{\"key\": \"my-key\", \"value\": \"my-value\"}")); + } + }); + futures1.add(future); + } + CompletableFuture.allOf(futures1.toArray(new CompletableFuture[] {})) + .get(2, TimeUnit.MINUTES); } private void startTopicExchange(String logicalFromTopic, String logicalToTopic)