Skip to content

Commit

Permalink
[python agents] Allow to access resources configured in configuration…
Browse files Browse the repository at this point in the history
….yaml
  • Loading branch information
eolivelli committed Nov 6, 2023
1 parent 6ca69e9 commit 69b6f5a
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 19 deletions.
3 changes: 3 additions & 0 deletions examples/applications/langchain-chat/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
configuration:
resources:
- type: "open-ai-configuration"
id: "openai"
name: "OpenAI Azure configuration"
configuration:
url: "${secrets.open-ai.url}"
access-key: "${secrets.open-ai.access-key}"
provider: "${secrets.open-ai.provider}"
- type: "datasource"
id: "database"
name: "AstraDatasource"
configuration:
service: "astra"
clientId: "${secrets.astra-langchain.clientId}"
secret: "${secrets.astra-langchain.secret}"
database: "${secrets.astra-langchain.database}"
database-id: "${secrets.astra-langchain.database-id}"
token: "${secrets.astra-langchain.token}"
environment: "PROD"
2 changes: 1 addition & 1 deletion examples/applications/langchain-chat/crawler.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pipeline:
resources:
size: 2
configuration:
datasource: "AstraDatasource"
datasource: "database"
table-name: "documents"
keyspace: "documents"
mapping: "row_id=value.row_id, body_blob=value.text, vector=value.embeddings_vector"
11 changes: 6 additions & 5 deletions examples/applications/langchain-chat/python/langchain_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,13 @@ def create_chain(
class LangChainChat(Processor):

def init(self, config):
self.openai_key = config.get("openai-key", "")
self.astra_db_token = config.get("astra-db-token", "")
self.astra_db_id = config.get("astra-db-id", "")
self.astra_keyspace = config.get("astra-db-keyspace", "")
self.astra_table_name = config.get("astra-db-table", "")

# the values are configured in the resources section in configuration.yaml
self.openai_key = config.resources.openai.get("access-key", "");
self.astra_db_token = config.resources.database.get("token", "")
self.astra_db_id = config.resources.database.get("database-id", "")
self.astra_keyspace = config.resources.database.get("keyspace", "")
self.astra_table_name = config.resources.database.get("table", "")

cassio.init(token=self.astra_db_token, database_id=self.astra_db_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class CassandraDataSource implements QueryStepDataSource {
String astraEnvironment;

String astraDatabase;
String astraDatabaseId;
Map<String, PreparedStatement> statements = new ConcurrentHashMap<>();

private static final DefaultCodecRegistry CODEC_REGISTRY =
Expand Down Expand Up @@ -81,6 +82,7 @@ public void initialize(Map<String, Object> dataSourceConfig) {
this.astraEnvironment =
ConfigurationUtils.getString("environment", "PROD", dataSourceConfig);
this.astraDatabase = ConfigurationUtils.getString("database", "", dataSourceConfig);
this.astraDatabaseId = ConfigurationUtils.getString("database-id", "", dataSourceConfig);
this.session = buildCqlSession(dataSourceConfig);
}

Expand Down Expand Up @@ -227,10 +229,16 @@ private CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {
secureBundleDecoded = Base64.getDecoder().decode(secureBundle);
} else if (!astraDatabase.isEmpty() && !astraToken.isEmpty()) {
log.info(
"Automatically downloading the secure bundle for database {} from AstraDB",
"Automatically downloading the secure bundle for database name {} from AstraDB",
astraDatabase);
DatabaseClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else if (!astraDatabaseId.isEmpty() && !astraToken.isEmpty()) {
log.info(
"Automatically downloading the secure bundle for database id {} from AstraDB",
astraDatabaseId);
DatabaseClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else {
log.info("No secure bundle provided, using the default CQL driver for Cassandra");
}
Expand Down Expand Up @@ -264,16 +272,24 @@ public CqlSession getSession() {
}

public DatabaseClient buildAstraClient() {
return buildAstraClient(astraToken, astraDatabase, astraEnvironment);
return buildAstraClient(astraToken, astraDatabase, astraDatabaseId, astraEnvironment);
}

public static DatabaseClient buildAstraClient(
String astraToken, String astraDatabase, String astraEnvironment) {
if (astraToken.isEmpty() || astraDatabase.isEmpty()) {
throw new IllegalArgumentException("You must configure both token and database");
String astraToken, String astraDatabase, String astraDatabaseId, String astraEnvironment) {
if (astraToken.isEmpty()) {
throw new IllegalArgumentException("You must configure the AstraDB token");
}
AstraDbClient astraDbClient = new AstraDbClient(astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment));
if (!astraDatabase.isEmpty()) {
return astraDbClient
.databaseByName(astraDatabase);
} else if (!astraDatabaseId.isEmpty()) {
return astraDbClient
.database(astraDatabaseId);
} else {
throw new IllegalArgumentException("You must configure the database name or the database id");
}
return new AstraDbClient(astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment))
.databaseByName(astraDatabase);
}

public static byte[] downloadSecureBundle(DatabaseClient databaseClient) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,24 @@ public void initialise(Map<String, Object> agentConfiguration) {
configuration.put(
"cloud.secureConnectBundle", secureBundleString);
} else {
// AstraDB, token/database
// AstraDB, token/database/databaseId
String token =
ConfigurationUtils.getString(
"token", "", datasource);
String database =
ConfigurationUtils.getString(
"database", "", datasource);
String databaseId =
ConfigurationUtils.getString(
"database-id", "", datasource);
String environment =
ConfigurationUtils.getString(
"environment", "PROD", datasource);
if (!token.isEmpty() && !database.isEmpty()) {
if (!token.isEmpty() && (!database.isEmpty() || !databaseId.isEmpty())) {
DatabaseClient databaseClient =
CassandraDataSource.buildAstraClient(
token, database, environment);
log.info(
"Automatically downloading the secure bundle for database {} from AstraDB",
database);
token, database, databaseId, environment);
log.info("Automatically downloading the secure bundle from AstraDB");
byte[] secureBundle =
CassandraDataSource.downloadSecureBundle(
databaseClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import ai.langstream.impl.uti.ClassConfigValidator;
import java.util.Base64;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

@Data
Expand Down Expand Up @@ -93,6 +95,14 @@ Astra Token (AstraCS:xxx) for connecting to the database. If secureBundle is pro
""")
private String database;

@ConfigProperty(
description =
"""
Astra Database ID name to connect to. If secureBundle is provided, this field is ignored.
""")
@JsonProperty("database-id")
private String databaseId;

@ConfigProperty(
description =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@
import ai.langstream.api.doc.AgentConfig;
import ai.langstream.api.doc.ConfigProperty;
import ai.langstream.api.model.AgentConfiguration;
import ai.langstream.api.model.Module;
import ai.langstream.api.model.Pipeline;
import ai.langstream.api.model.Resource;
import ai.langstream.api.runtime.ComponentType;
import ai.langstream.api.runtime.ComputeClusterRuntime;
import ai.langstream.api.runtime.ExecutionPlan;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.impl.agents.AbstractComposableAgentProvider;
import ai.langstream.runtime.impl.k8s.KubernetesClusterRuntime;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -40,6 +49,35 @@ public PythonCodeAgentProvider() {
List.of(KubernetesClusterRuntime.CLUSTER_TYPE, "none"));
}

@Override
protected Map<String, Object> computeAgentConfiguration(AgentConfiguration agentConfiguration,
Module module,
Pipeline pipeline,
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
Map<String, Object> copy =
super.computeAgentConfiguration(agentConfiguration, module, pipeline, executionPlan, clusterRuntime, pluginsRegistry);


Map<String, Object> resources = new HashMap<>();

// with this trick Python agents can access the resources configuration
Map<String, Resource> resourcesDef = executionPlan.getApplication().getResources();
if (resourcesDef != null) {
resourcesDef.forEach((key, r) -> {
String id = r.id();
Map<String, Object> data = r.configuration();
log.info("Passing resource configuration to Python agent: {} -> {}", id, data.keySet());
resources.put(id, data);
});
}

copy.put("resources", resources);

return copy;
}

@Override
protected final ComponentType getComponentType(AgentConfiguration agentConfiguration) {
return switch (agentConfiguration.getType()) {
Expand Down

0 comments on commit 69b6f5a

Please sign in to comment.