Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gateway: fix Pulsar topics with non default tenant #761

Merged
merged 3 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@

import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicOffsetPosition;
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
Expand All @@ -42,44 +45,14 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsumeGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

@Getter
public static class ProduceException extends Exception {

private final ProduceResponse.Status status;

public ProduceException(String message, ProduceResponse.Status status) {
super(message);
this.status = status;
}
}

public static class ProduceGatewayRequestValidator
implements GatewayRequestHandler.GatewayRequestValidator {
@Override
public List<String> getAllRequiredParameters(Gateway gateway) {
return gateway.getParameters();
}

@Override
public void validateOptions(Map<String, String> options) {
for (Map.Entry<String, String> option : options.entrySet()) {
switch (option.getKey()) {
default -> throw new IllegalArgumentException(
"Unknown option " + option.getKey());
}
}
}
}

private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;

private volatile TopicReader reader;
private volatile boolean interrupted;
Expand All @@ -88,8 +61,11 @@ public void validateOptions(Map<String, String> options) {
private AuthenticatedGatewayRequestContext requestContext;
private List<Function<Record, Boolean>> filters;

public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
public ConsumeGateway(
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
}

public void setup(
Expand Down Expand Up @@ -124,9 +100,16 @@ public void setup(
default -> TopicOffsetPosition.absolute(
Base64.getDecoder().decode(positionParameter));
};
TopicDefinition topicDefinition = requestContext.application().resolveTopic(topic);
StreamingClusterRuntime streamingClusterRuntime =
clusterRuntimeRegistry.getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();
reader =
topicConnectionsRuntime.createReader(
streamingCluster, Map.of("topic", topic), position);
streamingCluster, Map.of("topic", resolvedTopicName), position);
reader.start();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.SimpleRecord;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
Expand Down Expand Up @@ -79,15 +83,18 @@ public void validateOptions(Map<String, String> options) {
}

private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final TopicProducerCache topicProducerCache;
private TopicProducer producer;
private List<Header> commonHeaders;
private String logRef;

public ProduceGateway(
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
this.topicProducerCache = topicProducerCache;
}

Expand All @@ -113,17 +120,27 @@ public void start(
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

TopicDefinition topicDefinition = requestContext.application().resolveTopic(topic);
StreamingClusterRuntime streamingClusterRuntime =
clusterRuntimeRegistry.getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();

// we need to cache the producer per topic and per config, since an application update could
// change the configuration
final TopicProducerCache.Key key =
new TopicProducerCache.Key(
requestContext.tenant(),
requestContext.applicationId(),
requestContext.gateway().getId(),
topic,
resolvedTopicName,
configString);
producer =
topicProducerCache.getOrCreate(key, () -> setupProducer(topic, streamingCluster));
topicProducerCache.getOrCreate(
key, () -> setupProducer(resolvedTopicName, streamingCluster));
}

protected TopicProducer setupProducer(String topic, StreamingCluster streamingCluster) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
Expand Down Expand Up @@ -78,6 +79,7 @@ public class GatewayResource {
protected static final ObjectMapper MAPPER = new ObjectMapper();
protected static final String SERVICE_REQUEST_ID_HEADER = "langstream-service-request-id";
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final TopicProducerCache topicProducerCache;
private final ApplicationStore applicationStore;
private final GatewayRequestHandler gatewayRequestHandler;
Expand Down Expand Up @@ -121,6 +123,7 @@ ProduceResponse produce(
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry,
topicProducerCache)) {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
Expand Down Expand Up @@ -259,12 +262,14 @@ private CompletableFuture<ResponseEntity> handleServiceWithTopics(
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry,
topicProducerCache); ) {

final ConsumeGateway consumeGateway =
new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry);
completableFuture.thenRunAsync(
() -> {
if (consumeGateway != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.runner;

import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@Slf4j
public class ClusterRuntimeRegistryBean {
@Bean
public ClusterRuntimeRegistry registry() {
return new ClusterRuntimeRegistry();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ai.langstream.apigateway.websocket;

import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.TopicProducerCache;
Expand Down Expand Up @@ -49,6 +50,7 @@ public class WebSocketConfig implements WebSocketConfigurer {

private final ApplicationStore applicationStore;
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final GatewayRequestHandler gatewayRequestHandler;
private final TopicProducerCache topicProducerCache;
private final ExecutorService consumeThreadPool =
Expand All @@ -64,19 +66,22 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
CONSUME_PATH)
.addHandler(
new ProduceHandler(
applicationStore,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
PRODUCE_PATH)
.addHandler(
new ChatHandler(
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
CHAT_PATH)
.setAllowedOrigins("*")
Expand All @@ -93,5 +98,6 @@ public ServletServerContainerFactoryBean createWebSocketContainer() {
@PreDestroy
public void onDestroy() {
consumeThreadPool.shutdown();
clusterRuntimeRegistry.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import ai.langstream.api.events.GatewayEventData;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.code.SimpleRecord;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
Expand All @@ -52,14 +56,17 @@ public abstract class AbstractHandler extends TextWebSocketHandler {
protected static final String ATTRIBUTE_PRODUCE_GATEWAY = "__produce_gateway";
protected static final String ATTRIBUTE_CONSUME_GATEWAY = "__consume_gateway";
protected final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
protected final ClusterRuntimeRegistry clusterRuntimeRegistry;
protected final ApplicationStore applicationStore;
private final TopicProducerCache topicProducerCache;

public AbstractHandler(
ApplicationStore applicationStore,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
this.applicationStore = applicationStore;
this.topicProducerCache = topicProducerCache;
}
Expand Down Expand Up @@ -187,11 +194,20 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont

topicConnectionsRuntime.init(streamingCluster);

TopicDefinition topicDefinition =
context.application().resolveTopic(gateway.getEventsTopic());
StreamingClusterRuntime streamingClusterRuntime =
new ClusterRuntimeRegistry().getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();

try (final TopicProducer producer =
topicConnectionsRuntime.createProducer(
"langstream-events",
streamingCluster,
Map.of("topic", gateway.getEventsTopic()))) {
Map.of("topic", resolvedTopicName))) {
producer.start();

final EventSources.GatewaySource source =
Expand Down Expand Up @@ -246,7 +262,8 @@ protected void setupReader(
List<Function<Record, Boolean>> filters,
AuthenticatedGatewayRequestContext context)
throws Exception {
final ConsumeGateway consumeGateway = new ConsumeGateway(topicConnectionsRuntimeRegistry);
final ConsumeGateway consumeGateway =
new ConsumeGateway(topicConnectionsRuntimeRegistry, clusterRuntimeRegistry);
try {
consumeGateway.setup(topic, filters, context);
} catch (Exception ex) {
Expand All @@ -261,7 +278,10 @@ protected void setupProducer(
String topic, List<Header> commonHeaders, AuthenticatedGatewayRequestContext context)
throws Exception {
final ProduceGateway produceGateway =
new ProduceGateway(topicConnectionsRuntimeRegistry, topicProducerCache);
new ProduceGateway(
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache);

try {
produceGateway.start(topic, commonHeaders, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
Expand All @@ -47,8 +48,13 @@ public ChatHandler(
ApplicationStore applicationStore,
ExecutorService executor,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
super(applicationStore, topicConnectionsRuntimeRegistry, topicProducerCache);
super(
applicationStore,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache);
this.executor = executor;
}

Expand Down
Loading
Loading