Skip to content

Commit

Permalink
Putting a hold on this -- going to wait for the "init" command PR to …
Browse files Browse the repository at this point in the history
…be pushed.

TODO:
[] - tests for different sentence transformers models
[] - tests for embedding model use from OpenAI endpoint
  • Loading branch information
glennga committed Jan 30, 2025
1 parent f535ce3 commit c1eaac0
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 72 deletions.
41 changes: 31 additions & 10 deletions docs/source/env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Mandatory Environment Variables
The username of the account/database access key used to access your Couchbase cluster.

``AGENT_CATALOG_PASSWORD``
The password of the account/database access key used to access your Couchbase cluster
The password of the account/database access key used to access your Couchbase cluster.

``AGENT_CATALOG_BUCKET``
The name of the bucket where your catalog and all audit logs are/will be stored.
Expand All @@ -34,9 +34,8 @@ Optional Environment Variables

``AGENT_CATALOG_CONN_ROOT_CERTIFICATE``
Path to the `TLS <https://en.wikipedia.org/wiki/Transport_Layer_Security>`_ Root Certificate associated with your
Couchbase cluster for secure connection establishment.

Instructions for Couchbase Server certificates can be found `here <https://docs.couchbase.com/server/current/learn/security/certificates.html>`_.
Couchbase cluster.
More information about Couchbase Server certificates can be found `here <https://docs.couchbase.com/server/current/learn/security/certificates.html>`_.

``AGENT_CATALOG_ACTIVITY``
The location on your filesystem that denotes where the local audit logs are stored.
Expand Down Expand Up @@ -68,17 +67,39 @@ Optional Environment Variables
The location + filename of the audit logs that the :python:`agentc.Auditor` will write to.
By default, the :python:`agentc.Auditor` class will write and rotate logs in the :file:`./agent-activity` directory.

``AGENT_CATALOG_EMBEDDING_MODEL``
``AGENT_CATALOG_EMBEDDING_MODEL_NAME``
The embedding model that Agent Catalog will use when indexing and querying tools and prompts.
This *must* be a valid embedding model that is supported by the :python:`sentence_transformers.SentenceTransformer`
class.
class *or* the name of a model that can be used from the endpoint specified in the environment variable
``AGENT_CATALOG_EMBEDDING_MODEL_URL``.
By default, the ``sentence-transformers/all-MiniLM-L12-v2`` model is used.

``AGENT_CATALOG_EMBEDDING_MODEL_URL``
An OpenAI-standard client base URL whose ``/embeddings`` endpoint will be used to generate embeddings for Agent
Catalog tools and prompts.
The specified endpoint *must* host the embedding model given in ``AGENT_CATALOG_EMBEDDING_MODEL_NAME``.
If this variable is specified, Agent Catalog will assume the model given in ``AGENT_CATALOG_EMBEDDING_MODEL_NAME``
should be accessed through an OpenAI-standard interface.
This variable *must* be specified with ``AGENT_CATALOG_EMBEDDING_MODEL_AUTH``.
By default, this variable is not set (thus, a locally hosted SentenceTransformers is used).

``AGENT_CATALOG_EMBEDDING_MODEL_AUTH``
The field used in the authorization header of all OpenAI-standard client embedding requests.
For embedding models hosted by OpenAI, this field refers to the API key.
For embedding models hosted by Capella, this field refers to the Base64-encoded value of
``MY_USERNAME.MY_PASSWORD``.
If this variable is specified, Agent Catalog will assume the model given in ``AGENT_CATALOG_EMBEDDING_MODEL_NAME``
should be accessed through an OpenAI-standard interface.
This variable *must* be specified with ``AGENT_CATALOG_EMBEDDING_MODEL_URL``.
By default, this variable is not set (thus, a locally hosted SentenceTransformers is used).

``AGENT_CATALOG_INDEX_PARTITION``
Required for advanced vector index definition. This is an integer that defines the number of index partitions on your node.
If not set, this value is ``2 * number of nodes with 'search' service`` on your cluster.
The number of index partitions associated with your cluster.
This variable is used during the creation of vector indexes for semantic catalog search.
By default, this value is set to ``2 * number of nodes with 'search' service on your cluster``.
More information on index partitioning can be found `here <https://docs.couchbase.com/server/current/n1ql/n1ql-language-reference/index-partitioning.html>`_.

``AGENT_CATALOG_MAX_SOURCE_PARTITION``
Required for advanced vector index definition. This is an integer that defines the maximum number of source partitions.
If not set, this value is 1024.
The maximum number of source partitions associated with your cluster.
This variable is used during the creation of vector indexes for semantic catalog search.
By default, this value is set to 1024.
4 changes: 3 additions & 1 deletion libs/agentc_cli/agentc_cli/cmds/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def cmd_env(ctx: Context = None):
"AGENT_CATALOG_SNAPSHOT": os.getenv("AGENT_CATALOG_SNAPSHOT", LATEST_SNAPSHOT_VERSION),
"AGENT_CATALOG_PROVIDER_OUTPUT": os.getenv("AGENT_CATALOG_PROVIDER_OUTPUT", None),
"AGENT_CATALOG_AUDITOR_OUTPUT": os.getenv("AGENT_CATALOG_AUDITOR_OUTPUT", None),
"AGENT_CATALOG_EMBEDDING_MODEL": os.getenv("AGENT_CATALOG_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL),
"AGENT_CATALOG_EMBEDDING_MODEL_NAME": os.getenv("AGENT_CATALOG_EMBEDDING_MODEL_NAME", DEFAULT_EMBEDDING_MODEL),
"AGENT_CATALOG_EMBEDDING_MODEL_URL": os.getenv("AGENT_CATALOG_EMBEDDING_MODEL_URL", None),
"AGENT_CATALOG_EMBEDDING_MODEL_AUTH": os.getenv("AGENT_CATALOG_EMBEDDING_MODEL_AUTH", None),
}
for line in json.dumps(environment_dict, indent=4).split("\n"):
if re.match(r'\s*"AGENT_CATALOG_.*": (?!null)', line):
Expand Down
2 changes: 2 additions & 0 deletions libs/agentc_cli/agentc_cli/cmds/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def cmd_index(
source_dirs: list[str | os.PathLike],
kinds: list[typing.Literal["tool", "prompt"]],
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
embedding_model_url: str = None,
dry_run: bool = False,
ctx: Context = None,
**_,
Expand All @@ -54,6 +55,7 @@ def cmd_index(
repo, get_path_version = load_repository(pathlib.Path(os.getcwd()))
embedding_model = EmbeddingModel(
embedding_model_name=embedding_model_name,
embedding_model_url=embedding_model_url,
catalog_path=pathlib.Path(ctx.catalog),
)

Expand Down
18 changes: 13 additions & 5 deletions libs/agentc_cli/agentc_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def resolve_command(self, ctx, args):

@click.group(
cls=AliasedGroup,
epilog="See: https://docs.couchbase.com or https://couchbaselabs.github.io/agent-catalog/index.html# for more information.",
epilog="See: https://docs.couchbase.com or "
"https://couchbaselabs.github.io/agent-catalog/index.html# for more information.",
context_settings=dict(max_content_width=800),
)
@click.option(
Expand Down Expand Up @@ -450,9 +451,15 @@ def find(
)
@click.option(
"-em",
"--embedding-model",
"--embedding-model-name",
default=DEFAULT_EMBEDDING_MODEL,
help="Embedding model used when indexing source files into the local catalog.",
help="Name of the embedding model used when indexing source files into the local catalog.",
show_default=True,
)
@click.option(
"--embedding-model-url",
default=None,
help="Base URL of an OpenAI-standard endpoint that exposes an embedding model.",
show_default=True,
)
@click.option(
Expand All @@ -463,7 +470,7 @@ def find(
show_default=True,
)
@click.pass_context
def index(ctx, source_dirs, tools, prompts, embedding_model, dry_run):
def index(ctx, source_dirs, tools, prompts, embedding_model_name, embedding_model_url, dry_run):
"""Walk the source directory trees (SOURCE_DIRS) to index source files into the local catalog.
Source files that will be scanned include *.py, *.sqlpp, *.yaml, etc."""

Expand All @@ -488,7 +495,8 @@ def index(ctx, source_dirs, tools, prompts, embedding_model, dry_run):
ctx=ctx.obj,
source_dirs=source_dirs,
kinds=kinds,
embedding_model_name=embedding_model,
embedding_model_name=embedding_model_name,
embedding_model_url=embedding_model_url,
dry_run=dry_run,
)

Expand Down
7 changes: 3 additions & 4 deletions libs/agentc_core/agentc_core/catalog/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..tool.descriptor.models import SemanticSearchToolDescriptor
from ..tool.descriptor.models import SQLPPQueryToolDescriptor
from ..version import VersionDescriptor
from agentc_core.learned.model import EmbeddingModel


class CatalogKind(enum.StrEnum):
Expand Down Expand Up @@ -47,10 +48,8 @@ class CatalogDescriptor(pydantic.BaseModel):

kind: CatalogKind = pydantic.Field(description="The type of items within the catalog.")

embedding_model: str = pydantic.Field(
description="The sentence-transformers embedding model used to generate the vector representations "
"of each catalog entry.",
examples=["sentence-transformers/all-MiniLM-L12-v2"],
embedding_model: EmbeddingModel = pydantic.Field(
description="Embedding model used for all descriptions in the catalog.",
)

version: VersionDescriptor = pydantic.Field(
Expand Down
112 changes: 67 additions & 45 deletions libs/agentc_core/agentc_core/learned/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import couchbase.cluster
import couchbase.exceptions
import logging
import os
import pathlib
import pydantic
import typing
Expand All @@ -12,6 +13,7 @@
from agentc_core.defaults import DEFAULT_MODEL_CACHE_FOLDER
from agentc_core.defaults import DEFAULT_PROMPT_CATALOG_NAME
from agentc_core.defaults import DEFAULT_TOOL_CATALOG_NAME
from agentc_core.learned.model import EmbeddingModel as PydanticEmbeddingModel

logger = logging.getLogger(__name__)

Expand All @@ -21,6 +23,7 @@ class EmbeddingModel(pydantic.BaseModel):

# Embedding models are defined in three distinct ways: explicitly (by name)...
embedding_model_name: typing.Optional[str] = DEFAULT_EMBEDDING_MODEL
embedding_model_url: typing.Optional[str] = None

# ...or implicitly (by path)...
catalog_path: typing.Optional[pathlib.Path] = None
Expand All @@ -29,8 +32,8 @@ class EmbeddingModel(pydantic.BaseModel):
cb_bucket: typing.Optional[str] = None
cb_cluster: typing.Optional[couchbase.cluster.Cluster] = None

# The actual embedding model object (we won't type this to avoid the sentence transformers import).
_embedding_model: None
# The actual embedding model object (we won't type this to avoid the potential sentence transformers import).
_embedding_model: None = None

@pydantic.model_validator(mode="after")
def validate_bucket_cluster(self) -> "EmbeddingModel":
Expand All @@ -46,34 +49,32 @@ def validate_embedding_model(self) -> "EmbeddingModel":
if self.embedding_model_name is None and self.catalog_path is None and self.cb_cluster is None:
raise ValueError("embedding_model_name, catalog_path, or cb_cluster must be specified.")

from_catalog_embedding_model_name = None
from_catalog_embedding_model = None
if self.catalog_path is not None:
collected_embedding_model_names = set()
collected_embedding_models = set()

# Grab our local tool embedding model...
local_tool_catalog_path = self.catalog_path / DEFAULT_TOOL_CATALOG_NAME
if local_tool_catalog_path.exists():
with local_tool_catalog_path.open("r") as fp:
local_tool_catalog = CatalogDescriptor.model_validate_json(fp.read())
collected_embedding_model_names.add(local_tool_catalog.embedding_model)
collected_embedding_models.add(local_tool_catalog.embedding_model)

# ...and now our local prompt embedding model.
local_prompt_catalog_path = self.catalog_path / DEFAULT_PROMPT_CATALOG_NAME
if local_prompt_catalog_path.exists():
with local_prompt_catalog_path.open("r") as fp:
local_prompt_catalog = CatalogDescriptor.model_validate_json(fp.read())
collected_embedding_model_names.add(local_prompt_catalog.embedding_model)
collected_embedding_models.add(local_prompt_catalog.embedding_model)

if len(collected_embedding_model_names) > 1:
raise ValueError(
f"Multiple embedding models found in local catalogs: " f"{collected_embedding_model_names}"
)
elif len(collected_embedding_model_names) == 1:
from_catalog_embedding_model_name = collected_embedding_model_names.pop()
logger.debug("Found embedding model %s in local catalogs.", from_catalog_embedding_model_name)
if len(collected_embedding_models) > 1:
raise ValueError(f"Multiple embedding models found in local catalogs: " f"{collected_embedding_models}")
elif len(collected_embedding_models) == 1:
from_catalog_embedding_model = collected_embedding_models.pop()
logger.debug("Found embedding model %s in local catalogs.", from_catalog_embedding_model)

if self.cb_cluster is not None:
collected_embedding_model_names = set()
collected_embedding_models = set()

# TODO (GLENN): There is probably a cleaner way to do this (but this is Pythonic, so...).
union_subqueries = []
Expand All @@ -83,7 +84,8 @@ def validate_embedding_model(self) -> "EmbeddingModel":
f"`{self.cb_bucket}`.`{DEFAULT_CATALOG_SCOPE}`.`{kind}{DEFAULT_META_COLLECTION_NAME}`"
)
self.cb_cluster.query(f"""
FROM {qualified_collection_name} AS mc
FROM
{qualified_collection_name} AS mc
SELECT *
LIMIT 1
""").execute()
Expand Down Expand Up @@ -115,37 +117,32 @@ def validate_embedding_model(self) -> "EmbeddingModel":

if metadata_query is not None:
for row in metadata_query:
collected_embedding_model_names.add(row)
collected_embedding_models.add(PydanticEmbeddingModel(**row))

if len(collected_embedding_model_names) > 1:
if len(collected_embedding_models) > 1:
raise ValueError(
f"Multiple embedding models found in remote catalogs: " f"{collected_embedding_model_names}"
f"Multiple embedding models found in remote catalogs: " f"{collected_embedding_models}"
)
elif len(collected_embedding_model_names) == 1:
remote_embedding_model_name = collected_embedding_model_names.pop()
logger.debug("Found embedding model %s in remote catalogs.", remote_embedding_model_name)
if (
from_catalog_embedding_model_name is not None
and from_catalog_embedding_model_name != remote_embedding_model_name
):
elif len(collected_embedding_models) == 1:
remote_embedding_model = collected_embedding_models.pop()
logger.debug("Found embedding model %s in remote catalogs.", remote_embedding_model)
if from_catalog_embedding_model is not None and from_catalog_embedding_model != remote_embedding_model:
raise ValueError(
f"Local embedding model {from_catalog_embedding_model_name} does not match "
f"remote embedding model {remote_embedding_model_name}!"
f"Local embedding model {from_catalog_embedding_model} does not match "
f"remote embedding model {remote_embedding_model}!"
)
elif from_catalog_embedding_model_name is None:
from_catalog_embedding_model_name = remote_embedding_model_name
elif from_catalog_embedding_model is None:
from_catalog_embedding_model = remote_embedding_model

if self.embedding_model_name is None:
self.embedding_model_name = from_catalog_embedding_model_name
elif (
from_catalog_embedding_model_name is not None
and self.embedding_model_name != from_catalog_embedding_model_name
):
self.embedding_model_name = from_catalog_embedding_model.name
self.embedding_model_url = from_catalog_embedding_model.base_url
elif from_catalog_embedding_model is not None and self.embedding_model_name != from_catalog_embedding_model:
raise ValueError(
f"Local embedding model {from_catalog_embedding_model_name} does not match "
f"Local embedding model {from_catalog_embedding_model.name} does not match "
f"specified embedding model {self.embedding_model_name}!"
)
elif self.embedding_model_name is None and from_catalog_embedding_model_name is None:
elif self.embedding_model_name is None and from_catalog_embedding_model is None:
raise ValueError("No embedding model found (run 'agentc index' to download one).")

# Note: we won't validate the embedding model name because sentence_transformers takes a while to import.
Expand All @@ -157,15 +154,40 @@ def name(self) -> str:
return self.embedding_model_name

def encode(self, text: str) -> list[float]:
# Lazily-load the embedding model.
if self._embedding_model is None:
import sentence_transformers
if self.embedding_model_name.startswith("https://") or self.embedding_model_name.startswith("http://"):
import openai

self._embedding_model = sentence_transformers.SentenceTransformer(
self.embedding_model_name,
tokenizer_kwargs={"clean_up_tokenization_spaces": True},
cache_folder=DEFAULT_MODEL_CACHE_FOLDER,
local_files_only=False,
)
open_ai_client = openai.OpenAI(
base_url=self.embedding_model_url, api_key=os.getenv("AGENT_CATALOG_EMBEDDING_MODEL_AUTH")
)

def _encode(_text: str) -> list[float]:
return (
open_ai_client.embeddings.create(
model=self.embedding_model_name, input=text, encoding_format="float"
)
.data[0]
.embedding
)

self._embedding_model = _encode

else:
import sentence_transformers

sentence_transformers_model = sentence_transformers.SentenceTransformer(
self.embedding_model_name,
tokenizer_kwargs={"clean_up_tokenization_spaces": True},
cache_folder=DEFAULT_MODEL_CACHE_FOLDER,
local_files_only=False,
)

def _encode(_text: str) -> list[float]:
return sentence_transformers_model.encode(_text, convert_to_tensor=False).tolist()

self._embedding_model = _encode

# Normalize embeddings to unit length (only dot-product is computed with Couchbase, so...).
return self._embedding_model.encode(text, normalize_embeddings=True).tolist()
# Invoke our model.
return self._embedding_model(text)
17 changes: 17 additions & 0 deletions libs/agentc_core/agentc_core/learned/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pydantic
import typing


class EmbeddingModel(pydantic.BaseModel):
kind: typing.Literal["sentence-transformers", "openai"] = pydantic.Field(
description="The type of embedding model being used."
)
name: str = pydantic.Field(
description="The name of the embedding model being used.",
examples=["all-MiniLM-L12-v2", "https://12fs345d.apps.cloud.couchbase.com"],
)
base_url: typing.Optional[str] = pydantic.Field(
description="The base URL of the embedding model."
"This field must be specified is using a non-SentenceTransformers-based model.",
examples=["https://12fs345d.apps.cloud.couchbase.com"],
)
Loading

0 comments on commit c1eaac0

Please sign in to comment.