diff --git a/examples/applications/langchain-chat/configuration.yaml b/examples/applications/langchain-chat/configuration.yaml index b87a824d2..c82f1a536 100644 --- a/examples/applications/langchain-chat/configuration.yaml +++ b/examples/applications/langchain-chat/configuration.yaml @@ -18,17 +18,20 @@ configuration: resources: - type: "open-ai-configuration" + id: "openai" name: "OpenAI Azure configuration" configuration: url: "${secrets.open-ai.url}" access-key: "${secrets.open-ai.access-key}" provider: "${secrets.open-ai.provider}" - type: "datasource" + id: "database" name: "AstraDatasource" configuration: service: "astra" clientId: "${secrets.astra-langchain.clientId}" secret: "${secrets.astra-langchain.secret}" database: "${secrets.astra-langchain.database}" + database-id: "${secrets.astra-langchain.database-id}" token: "${secrets.astra-langchain.token}" environment: "PROD" \ No newline at end of file diff --git a/examples/applications/langchain-chat/crawler.yaml b/examples/applications/langchain-chat/crawler.yaml index 2d7069984..d2685fbae 100644 --- a/examples/applications/langchain-chat/crawler.yaml +++ b/examples/applications/langchain-chat/crawler.yaml @@ -99,7 +99,7 @@ pipeline: resources: size: 2 configuration: - datasource: "AstraDatasource" + datasource: "database" table-name: "documents" keyspace: "documents" mapping: "row_id=value.row_id, body_blob=value.text, vector=value.embeddings_vector" \ No newline at end of file diff --git a/examples/applications/langchain-chat/python/langchain_chat.py b/examples/applications/langchain-chat/python/langchain_chat.py index 584c831a9..fbfa34324 100644 --- a/examples/applications/langchain-chat/python/langchain_chat.py +++ b/examples/applications/langchain-chat/python/langchain_chat.py @@ -201,12 +201,13 @@ def create_chain( class LangChainChat(Processor): def init(self, config): - self.openai_key = config.get("openai-key", "") - self.astra_db_token = config.get("astra-db-token", "") - self.astra_db_id = 
config.get("astra-db-id", "") - self.astra_keyspace = config.get("astra-db-keyspace", "") - self.astra_table_name = config.get("astra-db-table", "") + # credentials come from the resources section in configuration.yaml (exposed as a plain dict under "resources"); keyspace/table stay agent-level settings + self.openai_key = config["resources"]["openai"].get("access-key", "") + self.astra_db_token = config["resources"]["database"].get("token", "") + self.astra_db_id = config["resources"]["database"].get("database-id", "") + self.astra_keyspace = config.get("astra-db-keyspace", "") + self.astra_table_name = config.get("astra-db-table", "") cassio.init(token=self.astra_db_token, database_id=self.astra_db_id) diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/datasource/CassandraDataSource.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/datasource/CassandraDataSource.java index 62801c836..cd343192a 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/datasource/CassandraDataSource.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/datasource/CassandraDataSource.java @@ -54,6 +54,7 @@ public class CassandraDataSource implements QueryStepDataSource { String astraEnvironment; String astraDatabase; + String astraDatabaseId; Map statements = new ConcurrentHashMap<>(); private static final DefaultCodecRegistry CODEC_REGISTRY = @@ -81,6 +82,7 @@ public void initialize(Map dataSourceConfig) { this.astraEnvironment = ConfigurationUtils.getString("environment", "PROD", dataSourceConfig); this.astraDatabase = ConfigurationUtils.getString("database", "", dataSourceConfig); + this.astraDatabaseId = ConfigurationUtils.getString("database-id", "", dataSourceConfig); this.session = buildCqlSession(dataSourceConfig); } @@ -227,10 +229,16 @@ private CqlSession buildCqlSession(Map dataSourceConfig) { secureBundleDecoded = Base64.getDecoder().decode(secureBundle); } else if
(!astraDatabase.isEmpty() && !astraToken.isEmpty()) { log.info( - "Automatically downloading the secure bundle for database {} from AstraDB", + "Automatically downloading the secure bundle for database name {} from AstraDB", astraDatabase); DatabaseClient databaseClient = this.buildAstraClient(); secureBundleDecoded = downloadSecureBundle(databaseClient); + } else if (!astraDatabaseId.isEmpty() && !astraToken.isEmpty()) { + log.info( + "Automatically downloading the secure bundle for database id {} from AstraDB", + astraDatabaseId); + DatabaseClient databaseClient = this.buildAstraClient(); + secureBundleDecoded = downloadSecureBundle(databaseClient); } else { log.info("No secure bundle provided, using the default CQL driver for Cassandra"); } @@ -264,16 +272,24 @@ public CqlSession getSession() { } public DatabaseClient buildAstraClient() { - return buildAstraClient(astraToken, astraDatabase, astraEnvironment); + return buildAstraClient(astraToken, astraDatabase, astraDatabaseId, astraEnvironment); } public static DatabaseClient buildAstraClient( - String astraToken, String astraDatabase, String astraEnvironment) { - if (astraToken.isEmpty() || astraDatabase.isEmpty()) { - throw new IllegalArgumentException("You must configure both token and database"); + String astraToken, String astraDatabase, String astraDatabaseId, String astraEnvironment) { + if (astraToken.isEmpty()) { + throw new IllegalArgumentException("You must configure the AstraDB token"); + } + AstraDbClient astraDbClient = new AstraDbClient(astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment)); + if (!astraDatabase.isEmpty()) { + return astraDbClient + .databaseByName(astraDatabase); + } else if (!astraDatabaseId.isEmpty()) { + return astraDbClient + .database(astraDatabaseId); + } else { + throw new IllegalArgumentException("You must configure the database name or the database id"); } - return new AstraDbClient(astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment)) - 
.databaseByName(astraDatabase); } public static byte[] downloadSecureBundle(DatabaseClient databaseClient) { diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java index d3e41252c..33ca526f4 100644 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/cassandra/CassandraWriter.java @@ -106,23 +106,24 @@ public void initialise(Map agentConfiguration) { configuration.put( "cloud.secureConnectBundle", secureBundleString); } else { - // AstraDB, token/database + // AstraDB, token/database/databaseId String token = ConfigurationUtils.getString( "token", "", datasource); String database = ConfigurationUtils.getString( "database", "", datasource); + String databaseId = + ConfigurationUtils.getString( + "database-id", "", datasource); String environment = ConfigurationUtils.getString( "environment", "PROD", datasource); - if (!token.isEmpty() && !database.isEmpty()) { + if (!token.isEmpty() && (!database.isEmpty() || !databaseId.isEmpty())) { DatabaseClient databaseClient = CassandraDataSource.buildAstraClient( - token, database, environment); - log.info( - "Automatically downloading the secure bundle for database {} from AstraDB", - database); + token, database, databaseId, environment); + log.info("Automatically downloading the secure bundle from AstraDB"); byte[] secureBundle = CassandraDataSource.downloadSecureBundle( databaseClient); diff --git a/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java index b144e4190..d1331ae8d 100644 --- 
a/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java +++ b/langstream-core/src/main/java/ai/langstream/impl/resources/datasource/AstraDatasourceConfig.java @@ -23,6 +23,8 @@ import ai.langstream.impl.uti.ClassConfigValidator; import java.util.Base64; import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; import lombok.Data; @Data @@ -93,6 +95,14 @@ Astra Token (AstraCS:xxx) for connecting to the database. If secureBundle is pro """) private String database; + @ConfigProperty( + description = + """ + Astra Database ID to connect to. If secureBundle is provided, this field is ignored. + """) + @JsonProperty("database-id") + private String databaseId; + @ConfigProperty( description = """ diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PythonCodeAgentProvider.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PythonCodeAgentProvider.java index db9c0a9fd..54daf1493 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PythonCodeAgentProvider.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PythonCodeAgentProvider.java @@ -18,10 +18,19 @@ import ai.langstream.api.doc.AgentConfig; import ai.langstream.api.doc.ConfigProperty; import ai.langstream.api.model.AgentConfiguration; +import ai.langstream.api.model.Module; +import ai.langstream.api.model.Pipeline; +import ai.langstream.api.model.Resource; import ai.langstream.api.runtime.ComponentType; +import ai.langstream.api.runtime.ComputeClusterRuntime; +import ai.langstream.api.runtime.ExecutionPlan; +import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.impl.agents.AbstractComposableAgentProvider; import ai.langstream.runtime.impl.k8s.KubernetesClusterRuntime; + +import
java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import lombok.extern.slf4j.Slf4j; @@ -40,6 +49,35 @@ public PythonCodeAgentProvider() { List.of(KubernetesClusterRuntime.CLUSTER_TYPE, "none")); } + @Override + protected Map computeAgentConfiguration(AgentConfiguration agentConfiguration, + Module module, + Pipeline pipeline, + ExecutionPlan executionPlan, + ComputeClusterRuntime clusterRuntime, + PluginsRegistry pluginsRegistry) { + Map copy = + super.computeAgentConfiguration(agentConfiguration, module, pipeline, executionPlan, clusterRuntime, pluginsRegistry); + + + Map resources = new HashMap<>(); + + // with this trick Python agents can access the resources configuration (keyed by resource id, or by the map key when no id is declared) + Map resourcesDef = executionPlan.getApplication().getResources(); + if (resourcesDef != null) { + resourcesDef.forEach((key, r) -> { + String id = r.id() != null ? r.id() : String.valueOf(key); + Map data = r.configuration() != null ? r.configuration() : Map.of(); + log.info("Passing resource configuration to Python agent: {} -> {}", id, data.keySet()); + resources.put(id, data); + }); + } + + copy.put("resources", resources); + + return copy; + } + @Override protected final ComponentType getComponentType(AgentConfiguration agentConfiguration) { return switch (agentConfiguration.getType()) {