From 2c816ede963ca7933413bf341e02910a0670b8d2 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Fri, 27 Dec 2024 00:02:09 -0500 Subject: [PATCH 01/14] Create VectorSearchRetrieverTool class for OpenAI --- .../src/databricks_langchain/utils.py | 57 ------------ .../vector_search_retriever_tool.py | 69 ++------------ integrations/openai/README.md | 0 integrations/openai/__init__.py | 0 integrations/openai/pyproject.toml | 71 ++++++++++++++ integrations/openai/src/__init__.py | 0 .../openai/src/databricks_openai/__init__.py | 6 ++ .../vector_search_retriever_tool.py | 93 +++++++++++++++++++ .../utils/vector_search.py | 55 +++++++++++ .../vector_search_retriever_tool.py | 68 ++++++++++++++ 10 files changed, 303 insertions(+), 116 deletions(-) create mode 100644 integrations/openai/README.md create mode 100644 integrations/openai/__init__.py create mode 100644 integrations/openai/pyproject.toml create mode 100644 integrations/openai/src/__init__.py create mode 100644 integrations/openai/src/databricks_openai/__init__.py create mode 100644 integrations/openai/src/databricks_openai/vector_search_retriever_tool.py create mode 100644 src/databricks_ai_bridge/utils/vector_search.py create mode 100644 src/databricks_ai_bridge/vector_search_retriever_tool.py diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py index 8218ab9..b35359f 100644 --- a/integrations/langchain/src/databricks_langchain/utils.py +++ b/integrations/langchain/src/databricks_langchain/utils.py @@ -97,60 +97,3 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity - - -class IndexType(str, Enum): - DIRECT_ACCESS = "DIRECT_ACCESS" - DELTA_SYNC = "DELTA_SYNC" - - -class IndexDetails: - """An utility class to store the configuration details of an index.""" - - def __init__(self, index: Any): - self._index_details = index.describe() - - @property - def name(self) -> str: - return self._index_details["name"] - - @property - def schema(self) -> Optional[Dict]: - if self.is_direct_access_index(): - schema_json = self.index_spec.get("schema_json") - if schema_json is not None: - return json.loads(schema_json) - return None - - @property - def primary_key(self) -> str: - return self._index_details["primary_key"] - - @property - def index_spec(self) -> Dict: - return ( - self._index_details.get("delta_sync_index_spec", {}) - if self.is_delta_sync_index() - else self._index_details.get("direct_access_index_spec", {}) - ) - - @property - def embedding_vector_column(self) -> Dict: - if vector_columns := self.index_spec.get("embedding_vector_columns"): - return vector_columns[0] - return {} - - @property - def embedding_source_column(self) -> Dict: - if source_columns := self.index_spec.get("embedding_source_columns"): - return source_columns[0] - return {} - - def is_delta_sync_index(self) -> bool: - return self._index_details["index_type"] == IndexType.DELTA_SYNC.value - - def is_direct_access_index(self) -> bool: - return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value - - def is_databricks_managed_embeddings(self) -> bool: - return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index e21d5de..62c21ee 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -4,56 +4,31 @@ from langchain_core.tools import BaseTool from pydantic import BaseModel, Field, PrivateAttr, model_validator -from databricks_langchain.utils import IndexDetails from databricks_langchain.vectorstores import DatabricksVectorSearch - -class VectorSearchRetrieverToolInput(BaseModel): - query: str = Field( - description="The string used to query the index with and identify the most similar " - "vectors and return the associated documents." - ) +from databricks_ai_bridge.vector_search_retriever_tool import BaseVectorSearchRetrieverTool, VectorSearchRetrieverToolInput -class VectorSearchRetrieverTool(BaseTool): +class VectorSearchRetrieverTool(BaseTool, BaseVectorSearchRetrieverTool): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. - This class integrates with a Databricks Vector Search and provides a convenient interface + This class integrates with Databricks Vector Search and provides a convenient interface for building a retriever tool for agents. """ - index_name: str = Field( - ..., description="The name of the index to use, format: 'catalog.schema.index'." - ) - num_results: int = Field(10, description="The number of results to return.") - columns: Optional[List[str]] = Field( - None, description="Columns to return when doing the search." - ) - filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.") - query_type: str = Field( - "ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'." - ) - tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.") - tool_description: Optional[str] = Field(None, description="A description of the tool.") - text_column: Optional[str] = Field( - None, - description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", - ) - embedding: Optional[Embeddings] = Field( - None, description="Embedding model for self-managed embeddings." - ) - # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs() name: str = Field(default="", description="The name of the tool") description: str = Field(default="", description="The description of the tool") args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput + embedding: Optional[Embeddings] = Field( + None, description="Embedding model for self-managed embeddings." + ) + _vector_store: DatabricksVectorSearch = PrivateAttr() @model_validator(mode="after") - def validate_tool_inputs(self): + def _validate_tool_inputs(self): kwargs = { "index_name": self.index_name, "embedding": self.embedding, @@ -63,36 +38,12 @@ def validate_tool_inputs(self): dbvs = DatabricksVectorSearch(**kwargs) self._vector_store = dbvs - def get_tool_description(): - default_tool_description = ( - "A vector search-based retrieval tool for querying indexed embeddings." - ) - index_details = IndexDetails(dbvs.index) - if index_details.is_delta_sync_index(): - from databricks.sdk import WorkspaceClient - - source_table = index_details.index_spec.get("source_table", "") - w = WorkspaceClient() - source_table_comment = w.tables.get(full_name=source_table).comment - if source_table_comment: - return ( - default_tool_description - + f" The queried index uses the source table {source_table} with the description: " - + source_table_comment - ) - else: - return ( - default_tool_description - + f" The queried index uses the source table {source_table}" - ) - return default_tool_description - self.name = self.tool_name or self.index_name - self.description = self.tool_description or get_tool_description() + self.description = self.tool_description or self._get_default_tool_description(dbvs) return self def _run(self, query: str) -> str: return self._vector_store.similarity_search( query, k=self.num_results, filter=self.filters, query_type=self.query_type - ) + ) \ No newline at end of file diff --git a/integrations/openai/README.md b/integrations/openai/README.md new file mode 100644 index 0000000..e69de29 diff --git a/integrations/openai/__init__.py b/integrations/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml new file mode 100644 index 0000000..b70f1d9 --- /dev/null +++ b/integrations/openai/pyproject.toml @@ -0,0 +1,71 @@ +[project] +name = "databricks-openai" +version = "0.1.0" +description = "Support for Databricks AI support with OpenAI" +authors = [ + { name="Leon Bi", email="leon.bi@databricks.com" }, +] +readme = "README.md" +license = { text="Apache-2.0" } +requires-python = ">=3.9" +dependencies = [ + "databricks-vectorsearch>=0.40", + "databricks-ai-bridge>=0.1.0", + "openai>=1.46.1", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "typing_extensions", + "databricks-sdk>=0.34.0", + "ruff==0.6.4", +] + +integration = [ + "pytest-timeout>=2.3.1", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = [ + "src/databricks_openai/*" +] + +[tool.hatch.build.targets.wheel] +packages = ["src/databricks_openai"] + +[tool.ruff] +line-length = 100 +target-version = "py39" + +[tool.ruff.lint] +select = [ + # isort + "I", + # bugbear rules + "B", + # remove unused imports + "F401", + # bare except statements + "E722", + # print statements + "T201", + "T203", + # misuse of typing.TYPE_CHECKING + "TCH004", + # import rules + "TID251", + # undefined-local-with-import-star + "F403", +] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 88 + +[tool.ruff.lint.pydocstyle] +convention = "google" \ No newline at end of file diff --git a/integrations/openai/src/__init__.py b/integrations/openai/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/integrations/openai/src/databricks_openai/__init__.py b/integrations/openai/src/databricks_openai/__init__.py new file mode 100644 index 0000000..015814c --- /dev/null +++ b/integrations/openai/src/databricks_openai/__init__.py @@ -0,0 +1,6 @@ +from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool + +# Expose all integrations to users under databricks-langchain +__all__ = [ + "VectorSearchRetrieverTool", +] diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py new file mode 100644 index 0000000..7c7763e --- /dev/null +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -0,0 +1,93 @@ +from typing import Any, Dict, List, Optional, Type + +from openai import pydantic_function_tool +from openai.types import Embeddings +from openai.types.chat import ChatCompletionToolParam +from openai.types.chat import ChatCompletion + +from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model + +from databricks_ai_bridge.vector_search_retriever_tool import BaseVectorSearchRetrieverTool, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION +from databricks_ai_bridge.vectorstores import IndexDetails + +class VectorSearchRetrieverTool(BaseVectorSearchRetrieverTool): + """ + A utility class to create a vector search-based retrieval tool for querying indexed embeddings. + This class integrates with Databricks Vector Search and provides a convenient interface + for tool calling using the OpenAI SDK. + + Example: + dbvs_tool = VectorSearchRetrieverTool("index_name") + tools = [dbvs_tool.tool, ...] + response = openai.chat.completions.create( + model="gpt-4o", + messages=initial_messages, + tools=tools, + ) + new_messages = dbvs_tool.execute_retriever_call(response) + final_response = openai.chat.completions.create( + model="gpt-4o", + messages=initial_messages + new_messages, + tools=tools, + ) + final_response.choices[0].message.content + """ + + embedding: Optional[Embeddings] = Field( + None, description="Embedding model for self-managed embeddings." + ) + tool: ChatCompletionToolParam = Field( + ..., description="The tool input used in the OpenAI chat completion SDK" + ) + _vector_store: DatabricksVectorSearch = PrivateAttr() + + @model_validator(mode="after") + def _validate_tool_inputs(self): + kwargs = { + "index_name": self.index_name, + "embedding": self.embedding, + "text_column": self.text_column, + "columns": self.columns, + } + dbvs = DatabricksVectorSearch(**kwargs) + self._vector_store = dbvs + self.tool = pydantic_function_tool( + VectorSearchRetrieverToolInput, + name=self.tool_name or self.index_name, + description=self.tool_description or self._get_default_tool_description(), + ) + return self + + def execute_retriever_call(self, + response: ChatCompletion, + choice_index: int = 0) -> List[Dict[str, Any]]: + """ + Generate tool call messages from the response. + + Args: + response: The chat completion response object returned by the OpenAI API. + choice_index: The index of the choice to process. Defaults to 0. Note that multiple + choices are not supported yet. + + Returns: + A list of messages containing the assistant message and the function call results. + """ + pass + + # client = validate_or_set_default_client(client) + # message = response.choices[choice_index].message + # tool_calls = message.tool_calls + # function_calls = [] + # if tool_calls: + # for tool_call in tool_calls: + # arguments = json.loads(tool_call.function.arguments) + # func_name = construct_original_function_name(tool_call.function.name) + # result = client.execute_function(func_name, arguments) + # function_call_result_message = { + # "role": "tool", + # "content": json.dumps({"content": result.value}), + # "tool_call_id": tool_call.id, + # } + # function_calls.append(function_call_result_message) + # assistant_message = message.to_dict() + # return [assistant_message, *function_calls] diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py new file mode 100644 index 0000000..38a6946 --- /dev/null +++ b/src/databricks_ai_bridge/utils/vector_search.py @@ -0,0 +1,55 @@ +class IndexType(str, Enum): + DIRECT_ACCESS = "DIRECT_ACCESS" + DELTA_SYNC = "DELTA_SYNC" + + +class IndexDetails: + """An utility class to store the configuration details of an index.""" + + def __init__(self, index: Any): + self._index_details = index.describe() + + @property + def name(self) -> str: + return self._index_details["name"] + + @property + def schema(self) -> Optional[Dict]: + if self.is_direct_access_index(): + schema_json = self.index_spec.get("schema_json") + if schema_json is not None: + return json.loads(schema_json) + return None + + @property + def primary_key(self) -> str: + return self._index_details["primary_key"] + + @property + def index_spec(self) -> Dict: + return ( + self._index_details.get("delta_sync_index_spec", {}) + if self.is_delta_sync_index() + else self._index_details.get("direct_access_index_spec", {}) + ) + + @property + def embedding_vector_column(self) -> Dict: + if vector_columns := self.index_spec.get("embedding_vector_columns"): + return vector_columns[0] + return {} + + @property + def embedding_source_column(self) -> Dict: + if source_columns := self.index_spec.get("embedding_source_columns"): + return source_columns[0] + return {} + + def is_delta_sync_index(self) -> bool: + return self._index_details["index_type"] == IndexType.DELTA_SYNC.value + + def is_direct_access_index(self) -> bool: + return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value + + def is_databricks_managed_embeddings(self) -> bool: + return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py new file mode 100644 index 0000000..23aa1d1 --- /dev/null +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, List, Optional +from abc import abstractmethod +from pydantic import BaseModel, Field +from databricks_ai_bridge.utils.vector_search import IndexDetails + +DEFAULT_TOOL_DESCRIPTION = "A vector search-based retrieval tool for querying indexed embeddings." + +class VectorSearchRetrieverToolInput(BaseModel): + query: str = Field( + description="The string used to query the index with and identify the most similar " + "vectors and return the associated documents." + ) + +class BaseVectorSearchRetrieverTool(BaseModel): + """ + Abstract base class for Databricks Vector Search retrieval tools. + This class provides the common structure and interface that framework-specific + implementations should follow. + """ + + index_name: str = Field( + ..., description="The name of the index to use, format: 'catalog.schema.index'." + ) + num_results: int = Field(10, description="The number of results to return.") + columns: Optional[List[str]] = Field( + None, description="Columns to return when doing the search." + ) + filters: Optional[Dict[str, Any]] = Field( + None, description="Filters to apply to the search." + ) + query_type: str = Field( + "ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'." + ) + tool_name: Optional[str] = Field( + None, description="The name of the retrieval tool." + ) + tool_description: Optional[str] = Field( + None, description="A description of the tool." + ) + text_column: Optional[str] = Field( + None, + description="The name of the text column to use for the embeddings. " + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", + ) + # TODO see if we can make an abstract Field here for embeddings + + # TODO: figure dbvs type (prob base class of DatabricksVectorSearch) + def _get_default_tool_description(self, dbvs: Any ) -> str: + index_details = IndexDetails(dbvs.index) + if index_details.is_delta_sync_index(): + from databricks.sdk import WorkspaceClient + + source_table = index_details.index_spec.get("source_table", "") + w = WorkspaceClient() + source_table_comment = w.tables.get(full_name=source_table).comment + if source_table_comment: + return ( + DEFAULT_TOOL_DESCRIPTION + + f" The queried index uses the source table {source_table} with the description: " + + source_table_comment + ) + else: + return ( + DEFAULT_TOOL_DESCRIPTION + + f" The queried index uses the source table {source_table}" + ) + return DEFAULT_TOOL_DESCRIPTION From c27a66521e65e4c05d8236c25ce36d03244e8c65 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Sun, 29 Dec 2024 04:03:21 -0800 Subject: [PATCH 02/14] Intermediate commit --- .../src/databricks_langchain/utils.py | 2 - .../vector_search_retriever_tool.py | 13 +- .../src/databricks_langchain/vectorstores.py | 44 +----- .../vector_search_retriever_tool.py | 134 ++++++++++++------ .../utils/vector_search.py | 48 ++++++- .../vector_search_retriever_tool.py | 16 +-- 6 files changed, 160 insertions(+), 97 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py index b35359f..750d26c 100644 --- a/integrations/langchain/src/databricks_langchain/utils.py +++ b/integrations/langchain/src/databricks_langchain/utils.py @@ -1,5 +1,3 @@ -import json -from enum import Enum from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 62c21ee..020c079 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -6,10 +6,11 @@ from databricks_langchain.vectorstores import DatabricksVectorSearch -from databricks_ai_bridge.vector_search_retriever_tool import BaseVectorSearchRetrieverTool, VectorSearchRetrieverToolInput +from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput +from databricks_ai_bridge.utils.vector_search import IndexDetails -class VectorSearchRetrieverTool(BaseTool, BaseVectorSearchRetrieverTool): +class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. This class integrates with Databricks Vector Search and provides a convenient interface @@ -24,6 +25,12 @@ class VectorSearchRetrieverTool(BaseTool, BaseVectorSearchRetrieverTool): embedding: Optional[Embeddings] = Field( None, description="Embedding model for self-managed embeddings." ) + text_column: Optional[str] = Field( + None, + description="The name of the text column to use for the embeddings. " + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", + ) _vector_store: DatabricksVectorSearch = PrivateAttr() @@ -39,7 +46,7 @@ def _validate_tool_inputs(self): self._vector_store = dbvs self.name = self.tool_name or self.index_name - self.description = self.tool_description or self._get_default_tool_description(dbvs) + self.description = self.tool_description or self._get_default_tool_description(IndexDetails(dbvs.index)) return self diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index db7f04c..e94f540 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -21,7 +21,8 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VST, VectorStore -from databricks_langchain.utils import IndexDetails, maximal_marginal_relevance +from databricks_langchain.utils import maximal_marginal_relevance +from databricks_ai_bridge.utils.vector_search import IndexDetails, DatabricksVectorSearchMixin, _validate_and_get_text_column logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ _INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$") -class DatabricksVectorSearch(VectorStore): +class DatabricksVectorSearch(VectorStore, DatabricksVectorSearchMixin): """Databricks vector store integration. Setup: @@ -690,43 +691,8 @@ async def amax_marginal_relevance_search_by_vector( def _parse_search_response( self, search_resp: Dict, ignore_cols: Optional[List[str]] = None ) -> List[Tuple[Document, float]]: - """Parse the search response into a list of Documents with score.""" - if ignore_cols is None: - ignore_cols = [] - - columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] - docs_with_score = [] - for result in search_resp.get("result", dict()).get("data_array", []): - doc_id = result[columns.index(self._primary_key)] - text_content = result[columns.index(self._text_column)] - ignore_cols = [self._primary_key, self._text_column] + ignore_cols - metadata = { - col: value - for col, value in zip(columns[:-1], result[:-1]) - if col not in ignore_cols - } - metadata[self._primary_key] = doc_id - score = result[-1] - doc = Document(page_content=text_content, metadata=metadata) - docs_with_score.append((doc, score)) - return docs_with_score - - -def _validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: - if index_details.is_databricks_managed_embeddings(): - index_source_column: str = index_details.embedding_source_column["name"] - # check if input text column matches the source column of the index - if text_column is not None: - raise ValueError( - f"The index '{index_details.name}' has the source column configured as " - f"'{index_source_column}'. Do not pass the `text_column` parameter." - ) - return index_source_column - else: - if text_column is None: - raise ValueError("The `text_column` parameter is required for this index.") - return text_column - + """Parse the search response into a list of Documents with score using VectorSearchRetrieverToolMixin function.""" + return self.parse_vector_search_response(search_resp, ignore_cols=ignore_cols, document_class=Document) def _validate_and_get_return_columns( columns: List[str], text_column: str, index_details: IndexDetails diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 7c7763e..62651ee 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,16 +1,17 @@ from typing import Any, Dict, List, Optional, Type from openai import pydantic_function_tool -from openai.types import Embeddings from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletion from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model -from databricks_ai_bridge.vector_search_retriever_tool import BaseVectorSearchRetrieverTool, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION -from databricks_ai_bridge.vectorstores import IndexDetails +from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION +from databricks_ai_bridge.utils.vector_search import IndexDetails, DatabricksVectorSearchMixin, _validate_and_get_text_column +from databricks.vector_search.client import VectorSearchClient, VectorSearchIndex +import json -class VectorSearchRetrieverTool(BaseVectorSearchRetrieverTool): +class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin, DatabricksVectorSearchMixin): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. This class integrates with Databricks Vector Search and provides a convenient interface @@ -24,45 +25,56 @@ class VectorSearchRetrieverTool(BaseVectorSearchRetrieverTool): messages=initial_messages, tools=tools, ) - new_messages = dbvs_tool.execute_retriever_call(response) + retriever_call_message = dbvs_tool.execute_retriever_calls(response) + + ### If needed, execute potential remaining tool calls here ### + remaining_tool_call_messages = execute_remaining_tool_calls(response) + final_response = openai.chat.completions.create( model="gpt-4o", - messages=initial_messages + new_messages, + messages=initial_messages + retriever_call_message + remaining_tool_call_messages tools=tools, ) final_response.choices[0].message.content """ embedding: Optional[Embeddings] = Field( - None, description="Embedding model for self-managed embeddings." + None, description="Embedding model for self-managed embeddings. Used for direct " + "access indexes or delta-sync indexes with self-managed embeddings" + ) + text_column: Optional[str] = Field( + None, + description="The name of the text column to use for the embeddings. " + "Required for direct-access index or delta-sync index with " + "self-managed embeddings. Used for direct access indexes or " + "delta-sync indexes with self-managed embeddings", ) + tool: ChatCompletionToolParam = Field( ..., description="The tool input used in the OpenAI chat completion SDK" ) - _vector_store: DatabricksVectorSearch = PrivateAttr() + _index: VectorSearchIndex = PrivateAttr() + _index_details: IndexDetails = PrivateAttr() @model_validator(mode="after") def _validate_tool_inputs(self): - kwargs = { - "index_name": self.index_name, - "embedding": self.embedding, - "text_column": self.text_column, - "columns": self.columns, - } - dbvs = DatabricksVectorSearch(**kwargs) - self._vector_store = dbvs - self.tool = pydantic_function_tool( + self._index = VectorSearchClient().get_index(index_name=self.index_name) + self._index_details = IndexDetails(self._index) + _validate_and_get_text_column(self.text_column, self._index_details) + + self.tool: ChatCompletionToolParam = pydantic_function_tool( VectorSearchRetrieverToolInput, name=self.tool_name or self.index_name, - description=self.tool_description or self._get_default_tool_description(), + description=self.tool_description or self._get_default_tool_description(self._index_details) ) return self - def execute_retriever_call(self, + def execute_retriever_calls(self, response: ChatCompletion, choice_index: int = 0) -> List[Dict[str, Any]]: """ - Generate tool call messages from the response. + Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the + self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. Args: response: The chat completion response object returned by the OpenAI API. @@ -70,24 +82,66 @@ def execute_retriever_call(self, choices are not supported yet. Returns: - A list of messages containing the assistant message and the function call results. + A list of messages containing the assistant message and the retriever call results + that correspond to the self.tool VectorSearchRetrieverToolInput. """ - pass - - # client = validate_or_set_default_client(client) - # message = response.choices[choice_index].message - # tool_calls = message.tool_calls - # function_calls = [] - # if tool_calls: - # for tool_call in tool_calls: - # arguments = json.loads(tool_call.function.arguments) - # func_name = construct_original_function_name(tool_call.function.name) - # result = client.execute_function(func_name, arguments) - # function_call_result_message = { - # "role": "tool", - # "content": json.dumps({"content": result.value}), - # "tool_call_id": tool_call.id, - # } - # function_calls.append(function_call_result_message) - # assistant_message = message.to_dict() - # return [assistant_message, *function_calls] + + class Document: + def __init__(self, page_content, metadata): + self.page_content = page_content + self.metadata = metadata + + def get_query_text_vector(query): + if self._index_details.is_databricks_managed_embeddings(): + return query, None + + # For non-Databricks-managed embeddings + text = query if query_type and query_type.upper() == "HYBRID" else None + vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + return text, vector + + def do_arguments_match(llm_call_arguments: Dict[str, Any]): + # TODO: Remove print statement + print("llm_call_arguments_json: ", llm_call_arguments) + print("retriever_tool_params: ", self.tool.function.parameters) + retriever_tool_params: Dict[str, Any] = self.tool.function.parameters + return set(llm_call_arguments.keys()) == set(retriever_tool_params.keys()) + + message = response.choices[choice_index].message + llm_tool_calls = message.tool_calls + function_calls = [] + if llm_tool_calls: + for llm_tool_call in llm_tool_calls: + # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput + llm_call_arguments_json: Dict[str, Any] = json.loads(llm_tool_call.function.arguments) + if llm_tool_call.function.name != self.tool.function.name \ + or not do_arguments_match(llm_call_arguments_json): + continue + + query_text, query_vector = get_query_text_vector(llm_call_arguments_json["query"]) + search_resp = self._index.similarity_search( + columns=self._columns, + query_text=query_text, + query_vector=query_vector, + filters=filter, + num_results=k, + query_type=query_type, + ) + embedding_column = self._index_details.embedding_vector_column["name"] + # Don't show the embedding column in the response + ignore_cols: List = [embedding_column] if embedding_column not in self.columns else [] + docs_with_score: List[Tuple[Document, float]] = \ + self.parse_vector_search_response( # from DatabricksVectorSearchMixin + search_resp, + ignore_cols=ignore_cols, + document_class=Document + ) + + function_call_result_message = { + "role": "tool", + "content": json.dumps({"content": docs_with_score}), + "tool_call_id": llm_tool_call.id, + } + function_calls.append(function_call_result_message) + assistant_message = message.to_dict() + return [assistant_message, *function_calls] diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py index 38a6946..ddd2867 100644 --- a/src/databricks_ai_bridge/utils/vector_search.py +++ b/src/databricks_ai_bridge/utils/vector_search.py @@ -1,8 +1,11 @@ +import json +from enum import Enum +from typing import Any, Dict, Optional + class IndexType(str, Enum): DIRECT_ACCESS = "DIRECT_ACCESS" DELTA_SYNC = "DELTA_SYNC" - class IndexDetails: """An utility class to store the configuration details of an index.""" @@ -53,3 +56,46 @@ def is_direct_access_index(self) -> bool: def is_databricks_managed_embeddings(self) -> bool: return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None + +class DatabricksVectorSearchMixin: + def parse_vector_search_response( + self, search_resp: Dict, ignore_cols: Optional[List[str]] = None, document_class: Any = dict + ) -> List[Tuple[Dict, float]]: + """ + Parse the search response into a list of Documents with score. + The document_class parameter is used to specify the class of the document to be created. + """ + if ignore_cols is None: + ignore_cols = [] + + columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] + docs_with_score = [] + for result in search_resp.get("result", dict()).get("data_array", []): + doc_id = result[columns.index(self._index_details.primary_key)] + text_content = result[columns.index(self._text_column)] + ignore_cols = [self._primary_key, self._text_column] + ignore_cols + metadata = { + col: value + for col, value in zip(columns[:-1], result[:-1]) + if col not in ignore_cols + } + metadata[self._primary_key] = doc_id + score = result[-1] + doc = document_class(page_content=text_content, metadata=metadata) + docs_with_score.append((doc, score)) + return docs_with_score + +def _validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: + if index_details.is_databricks_managed_embeddings(): + index_source_column: str = index_details.embedding_source_column["name"] + # check if input text column matches the source column of the index + if text_column is not None: + raise ValueError( + f"The index '{index_details.name}' has the source column configured as " + f"'{index_source_column}'. Do not pass the `text_column` parameter." + ) + return index_source_column + else: + if text_column is None: + raise ValueError("The `text_column` parameter is required for this index.") + return text_column \ No newline at end of file diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 23aa1d1..4057df9 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -2,6 +2,7 @@ from abc import abstractmethod from pydantic import BaseModel, Field from databricks_ai_bridge.utils.vector_search import IndexDetails +from databricks.vector_search.client import VectorSearchIndex DEFAULT_TOOL_DESCRIPTION = "A vector search-based retrieval tool for querying indexed embeddings." @@ -11,9 +12,9 @@ class VectorSearchRetrieverToolInput(BaseModel): "vectors and return the associated documents." ) -class BaseVectorSearchRetrieverTool(BaseModel): +class VectorSearchRetrieverToolMixin(BaseModel): """ - Abstract base class for Databricks Vector Search retrieval tools. + Mixin class for Databricks Vector Search retrieval tools. This class provides the common structure and interface that framework-specific implementations should follow. """ @@ -37,17 +38,8 @@ class BaseVectorSearchRetrieverTool(BaseModel): tool_description: Optional[str] = Field( None, description="A description of the tool." ) - text_column: Optional[str] = Field( - None, - description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", - ) - # TODO see if we can make an abstract Field here for embeddings - # TODO: figure dbvs type (prob base class of DatabricksVectorSearch) - def _get_default_tool_description(self, dbvs: Any ) -> str: - index_details = IndexDetails(dbvs.index) + def _get_default_tool_description(self, index_details: IndexDetails) -> str: if index_details.is_delta_sync_index(): from databricks.sdk import WorkspaceClient From db697823a2d2ccd4ddc1d3c76d07e95de30b1f5d Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Sun, 29 Dec 2024 22:04:11 -0800 Subject: [PATCH 03/14] Initial implementation --- .../src/databricks_langchain/vectorstores.py | 59 ++++------ .../vector_search_retriever_tool.py | 66 +++++++----- .../test_vector_search_retriever_tool.py | 102 ++++++++++++++++++ .../utils/vector_search.py | 89 +++++++++------ tests/utils/vector_search.py | 0 5 files changed, 224 insertions(+), 92 deletions(-) create mode 100644 integrations/openai/tests/test_vector_search_retriever_tool.py create mode 100644 tests/utils/vector_search.py diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index e94f540..ab2f029 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -22,7 +22,7 @@ from langchain_core.vectorstores import VST, VectorStore from databricks_langchain.utils import maximal_marginal_relevance -from databricks_ai_bridge.utils.vector_search import IndexDetails, DatabricksVectorSearchMixin, _validate_and_get_text_column +from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ _INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$") -class DatabricksVectorSearch(VectorStore, DatabricksVectorSearchMixin): +class DatabricksVectorSearch(VectorStore): """Databricks vector store integration. Setup: @@ -247,8 +247,8 @@ def __init__( _validate_embedding(embedding, self._index_details) self._embeddings = embedding - self._text_column = _validate_and_get_text_column(text_column, self._index_details) - self._columns = _validate_and_get_return_columns( + self._text_column = validate_and_get_text_column(text_column, self._index_details) + self._columns = validate_and_get_return_columns( columns or [], self._text_column, self._index_details ) self._primary_key = self._index_details.primary_key @@ -431,7 +431,12 @@ def similarity_search_with_score( num_results=k, query_type=query_type, ) - return self._parse_search_response(search_resp) + return parse_vector_search_response( + search_resp, + self._index_details, + self._text_column, + document_class=Document + ) def _select_relevance_score_fn(self) -> Callable[[float], float]: """ @@ -541,7 +546,12 @@ def similarity_search_by_vector_with_score( num_results=k, query_type=query_type, ) - return self._parse_search_response(search_resp) + return parse_vector_search_response( + search_resp, + self._index_details, + self._text_column, + document_class=Document + ) def max_marginal_relevance_search( self, @@ -674,7 +684,13 @@ def max_marginal_relevance_search_by_vector( ) ignore_cols: List = [embedding_column] if embedding_column not in self._columns else [] - candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols) + candidates = parse_vector_search_response( + search_resp, + self._index_details, + self._text_column, + ignore_cols=ignore_cols, + document_class=Document + ) selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected] return selected_results @@ -688,35 +704,6 @@ async def amax_marginal_relevance_search_by_vector( ) -> List[Document]: raise NotImplementedError - def _parse_search_response( - self, search_resp: Dict, ignore_cols: Optional[List[str]] = None - ) -> List[Tuple[Document, float]]: - """Parse the search response into a list of Documents with score using VectorSearchRetrieverToolMixin function.""" - return self.parse_vector_search_response(search_resp, ignore_cols=ignore_cols, document_class=Document) - -def _validate_and_get_return_columns( - columns: List[str], text_column: str, index_details: IndexDetails -) -> List[str]: - """ - Get a list of columns to retrieve from the index. - - If the index is direct-access index, validate the given columns against the schema. - """ - # add primary key column and source column if not in columns - if index_details.primary_key not in columns: - columns.append(index_details.primary_key) - if text_column and text_column not in columns: - columns.append(text_column) - - # Validate specified columns are in the index - if index_details.is_direct_access_index() and (index_schema := index_details.schema): - if missing_columns := [c for c in columns if c not in index_schema]: - raise ValueError( - "Some columns specified in `columns` are not " - f"in the index schema: {missing_columns}" - ) - return columns - def _validate_embedding(embedding: Optional[Embeddings], index_details: IndexDetails) -> None: if index_details.is_databricks_managed_embeddings(): diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 62651ee..5ab203d 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Tuple +from openai import OpenAI from openai import pydantic_function_tool from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletion @@ -7,11 +8,11 @@ from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION -from databricks_ai_bridge.utils.vector_search import IndexDetails, DatabricksVectorSearchMixin, _validate_and_get_text_column +from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns from databricks.vector_search.client import VectorSearchClient, VectorSearchIndex import json -class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin, DatabricksVectorSearchMixin): +class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. This class integrates with Databricks Vector Search and provides a convenient interface @@ -38,10 +39,6 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin, DatabricksVector final_response.choices[0].message.content """ - embedding: Optional[Embeddings] = Field( - None, description="Embedding model for self-managed embeddings. Used for direct " - "access indexes or delta-sync indexes with self-managed embeddings" - ) text_column: Optional[str] = Field( None, description="The name of the text column to use for the embeddings. " @@ -60,9 +57,10 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin, DatabricksVector def _validate_tool_inputs(self): self._index = VectorSearchClient().get_index(index_name=self.index_name) self._index_details = IndexDetails(self._index) - _validate_and_get_text_column(self.text_column, self._index_details) + self.text_column = validate_and_get_text_column(self.text_column, self._index_details) + self.columns = validate_and_get_return_columns(self.columns, self.text_column, self._index_details) - self.tool: ChatCompletionToolParam = pydantic_function_tool( + self.tool = pydantic_function_tool( VectorSearchRetrieverToolInput, name=self.tool_name or self.index_name, description=self.tool_description or self._get_default_tool_description(self._index_details) @@ -71,7 +69,9 @@ def _validate_tool_inputs(self): def execute_retriever_calls(self, response: ChatCompletion, - choice_index: int = 0) -> List[Dict[str, Any]]: + choice_index: int = 0, + open_ai_client: OpenAI = None, + embedding_model_name: str = None) -> List[Dict[str, Any]]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. @@ -80,24 +80,38 @@ def execute_retriever_calls(self, response: The chat completion response object returned by the OpenAI API. choice_index: The index of the choice to process. Defaults to 0. Note that multiple choices are not supported yet. + open_ai_client: The OpenAI client object to use for generating embeddings. Required for + direct access indexes or delta-sync indexes with self-managed embeddings. + embedding_model_name: The name of the embedding model to use for embedding the query text. + Required for direct access indexes or delta-sync indexes with self-managed embeddings. Returns: A list of messages containing the assistant message and the retriever call results that correspond to the self.tool VectorSearchRetrieverToolInput. """ - - class Document: - def __init__(self, page_content, metadata): - self.page_content = page_content - self.metadata = metadata - def get_query_text_vector(query): if self._index_details.is_databricks_managed_embeddings(): + if open_ai_client or embedding_model_name: + raise ValueError( + f"The index '{self._index_details.name}' uses Databricks-managed embeddings. " + "Do not pass the `open_ai_client` or `embedding_model_name` parameters when executing retriever calls." + ) return query, None # For non-Databricks-managed embeddings - text = query if query_type and query_type.upper() == "HYBRID" else None - vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + if not open_ai_client or not embedding_model_name: + raise ValueError("OpenAI client and embedding model name are required for non-Databricks-managed " + "embeddings Vector Search indexes in order to generate embeddings for retrieval queries.") + text = query if self.query_type and self.query_type.upper() == "HYBRID" else None + vector = open_ai_client.embeddings.create( + input=query, + model=embedding_model_name + )['data'][0]['embedding'] + if (index_embedding_dimension := index_details.embedding_vector_column.get("embedding_dimension")) and \ + len(vector) != index_embedding_dimension: + raise ValueError( + f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}" + ) return text, vector def do_arguments_match(llm_call_arguments: Dict[str, Any]): @@ -120,21 +134,23 @@ def do_arguments_match(llm_call_arguments: Dict[str, Any]): query_text, query_vector = get_query_text_vector(llm_call_arguments_json["query"]) search_resp = self._index.similarity_search( - columns=self._columns, + columns=self.columns, query_text=query_text, query_vector=query_vector, - filters=filter, - num_results=k, - query_type=query_type, + filters=self.filters, + num_results=self.num_results, + query_type=self.query_type, ) embedding_column = self._index_details.embedding_vector_column["name"] # Don't show the embedding column in the response ignore_cols: List = [embedding_column] if embedding_column not in self.columns else [] - docs_with_score: List[Tuple[Document, float]] = \ - self.parse_vector_search_response( # from DatabricksVectorSearchMixin + docs_with_score: List[Tuple[Dict, float]] = \ + parse_vector_search_response( search_resp, + self._index_details, + self.text_column, ignore_cols=ignore_cols, - document_class=Document + document_class=dict ) function_call_result_message = { diff --git a/integrations/openai/tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/test_vector_search_retriever_tool.py new file mode 100644 index 0000000..a12448d --- /dev/null +++ b/integrations/openai/tests/test_vector_search_retriever_tool.py @@ -0,0 +1,102 @@ +from typing import Any, Dict, List, Optional + +import pytest +from langchain_core.embeddings import Embeddings +from langchain_core.tools import BaseTool + +from tests.utils.chat_models import llm, mock_client # noqa: F401 +from tests.utils.vector_search import ( # noqa: F401 + ALL_INDEX_NAMES, + DELTA_SYNC_INDEX, + EMBEDDING_MODEL, + mock_vs_client, + mock_workspace_client, +) + + +def init_vector_search_tool( + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, +) -> VectorSearchRetrieverTool: + kwargs: Dict[str, Any] = { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + } + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "embedding": EMBEDDING_MODEL, + "text_column": "text", + } + ) + return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_init(index_name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert isinstance(vector_search_tool, BaseTool) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: + from langchain_core.messages import AIMessage + + vector_search_tool = init_vector_search_tool(index_name) + llm_with_tools = llm.bind_tools([vector_search_tool]) + response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") + assert isinstance(response, AIMessage) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("columns", [None, ["id", "text"]]) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) +@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) +@pytest.mark.parametrize("text_column", [None, "text"]) +def test_vector_search_retriever_tool_combinations( + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str], + embedding: Optional[Any], + text_column: Optional[str], +) -> None: + if index_name == DELTA_SYNC_INDEX: + embedding = None + text_column = None + + vector_search_tool = init_vector_search_tool( + index_name=index_name, + columns=columns, + tool_name=tool_name, + tool_description=tool_description, + embedding=embedding, + text_column=text_column, + ) + assert isinstance(vector_search_tool, BaseTool) + result = vector_search_tool.invoke("Databricks Agent Framework") + assert result is not None + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_vector_search_retriever_tool_description_generation(index_name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert vector_search_tool.name != "" + assert vector_search_tool.description != "" + assert vector_search_tool.name == index_name + assert ( + "A vector search-based retrieval tool for querying indexed embeddings." + in vector_search_tool.description + ) + assert vector_search_tool.args_schema.model_fields["query"] is not None + assert vector_search_tool.args_schema.model_fields["query"].description == ( + "The string used to query the index with and identify the most similar " + "vectors and return the associated documents." + ) diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py index ddd2867..9b967e0 100644 --- a/src/databricks_ai_bridge/utils/vector_search.py +++ b/src/databricks_ai_bridge/utils/vector_search.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List, Tuple class IndexType(str, Enum): DIRECT_ACCESS = "DIRECT_ACCESS" @@ -57,35 +57,39 @@ def is_direct_access_index(self) -> bool: def is_databricks_managed_embeddings(self) -> bool: return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None -class DatabricksVectorSearchMixin: - def parse_vector_search_response( - self, search_resp: Dict, ignore_cols: Optional[List[str]] = None, document_class: Any = dict - ) -> List[Tuple[Dict, float]]: - """ - Parse the search response into a list of Documents with score. - The document_class parameter is used to specify the class of the document to be created. - """ - if ignore_cols is None: - ignore_cols = [] - - columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] - docs_with_score = [] - for result in search_resp.get("result", dict()).get("data_array", []): - doc_id = result[columns.index(self._index_details.primary_key)] - text_content = result[columns.index(self._text_column)] - ignore_cols = [self._primary_key, self._text_column] + ignore_cols - metadata = { - col: value - for col, value in zip(columns[:-1], result[:-1]) - if col not in ignore_cols - } - metadata[self._primary_key] = doc_id - score = result[-1] - doc = document_class(page_content=text_content, metadata=metadata) - docs_with_score.append((doc, score)) - return docs_with_score - -def _validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: + +def parse_vector_search_response( + search_resp: Dict, + index_details: IndexDetails, + text_column: str, + ignore_cols: Optional[List[str]] = None, + document_class: Any = dict +) -> List[Tuple[Dict, float]]: + """ + Parse the search response into a list of Documents with score. + The document_class parameter is used to specify the class of the document to be created. + """ + if ignore_cols is None: + ignore_cols = [] + + columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] + docs_with_score = [] + for result in search_resp.get("result", dict()).get("data_array", []): + doc_id = result[columns.index(index_details.primary_key)] + text_content = result[columns.index(text_column)] + ignore_cols = [index_details.primary_key, text_column] + ignore_cols + metadata = { + col: value + for col, value in zip(columns[:-1], result[:-1]) + if col not in ignore_cols + } + metadata[index_details.primary_key] = doc_id + score = result[-1] + doc = document_class(page_content=text_content, metadata=metadata) + docs_with_score.append((doc, score)) + return docs_with_score + +def validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: if index_details.is_databricks_managed_embeddings(): index_source_column: str = index_details.embedding_source_column["name"] # check if input text column matches the source column of the index @@ -98,4 +102,27 @@ def _validate_and_get_text_column(text_column: Optional[str], index_details: Ind else: if text_column is None: raise ValueError("The `text_column` parameter is required for this index.") - return text_column \ No newline at end of file + return text_column + +def _validate_and_get_return_columns( + columns: List[str], text_column: str, index_details: IndexDetails +) -> List[str]: + """ + Get a list of columns to retrieve from the index. + + If the index is direct-access index, validate the given columns against the schema. + """ + # add primary key column and source column if not in columns + if index_details.primary_key not in columns: + columns.append(index_details.primary_key) + if text_column and text_column not in columns: + columns.append(text_column) + + # Validate specified columns are in the index + if index_details.is_direct_access_index() and (index_schema := index_details.schema): + if missing_columns := [c for c in columns if c not in index_schema]: + raise ValueError( + "Some columns specified in `columns` are not " + f"in the index schema: {missing_columns}" + ) + return columns \ No newline at end of file diff --git a/tests/utils/vector_search.py b/tests/utils/vector_search.py new file mode 100644 index 0000000..e69de29 From 005a5333a7f2b61d4d495cbad6df9db0469e5f2b Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Mon, 30 Dec 2024 16:54:41 -0800 Subject: [PATCH 04/14] Working e2e delta sync index happy case --- .gitignore | 1 + .../vector_search_retriever_tool.py | 45 ++++++++++--------- .../utils/vector_search.py | 2 +- tests/utils/vector_search.py | 0 4 files changed, 25 insertions(+), 23 deletions(-) delete mode 100644 tests/utils/vector_search.py diff --git a/.gitignore b/.gitignore index 40b3857..dfa90d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # test .pytest_cache/ +mlruns/ # Byte-compiled files __pycache__ diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 5ab203d..7c44bb8 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -2,12 +2,12 @@ from openai import OpenAI from openai import pydantic_function_tool -from openai.types.chat import ChatCompletionToolParam +from openai.types.chat import ChatCompletionToolParam, ChatCompletionMessageToolCall from openai.types.chat import ChatCompletion from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model -from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION +from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns from databricks.vector_search.client import VectorSearchClient, VectorSearchIndex import json @@ -33,7 +33,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): final_response = openai.chat.completions.create( model="gpt-4o", - messages=initial_messages + retriever_call_message + remaining_tool_call_messages + messages=initial_messages + retriever_call_message + remaining_tool_call_messages, tools=tools, ) final_response.choices[0].message.content @@ -48,7 +48,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): ) tool: ChatCompletionToolParam = Field( - ..., description="The tool input used in the OpenAI chat completion SDK" + None, description="The tool input used in the OpenAI chat completion SDK" ) _index: VectorSearchIndex = PrivateAttr() _index_details: IndexDetails = PrivateAttr() @@ -58,11 +58,16 @@ def _validate_tool_inputs(self): self._index = VectorSearchClient().get_index(index_name=self.index_name) self._index_details = IndexDetails(self._index) self.text_column = validate_and_get_text_column(self.text_column, self._index_details) - self.columns = validate_and_get_return_columns(self.columns, self.text_column, self._index_details) + self.columns = validate_and_get_return_columns(self.columns or [], self.text_column, self._index_details) + + # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'." + # The '.' from the index name are not allowed + def rewrite_index_name(index_name: str): + return index_name.split(".")[-1] self.tool = pydantic_function_tool( VectorSearchRetrieverToolInput, - name=self.tool_name or self.index_name, + name=self.tool_name or rewrite_index_name(self.index_name), description=self.tool_description or self._get_default_tool_description(self._index_details) ) return self @@ -89,7 +94,9 @@ def execute_retriever_calls(self, A list of messages containing the assistant message and the retriever call results that correspond to the self.tool VectorSearchRetrieverToolInput. """ - def get_query_text_vector(query): + + def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Optional[str], Optional[List[float]]]: + query = json.loads(tool_call.function.arguments)["query"] if self._index_details.is_databricks_managed_embeddings(): if open_ai_client or embedding_model_name: raise ValueError( @@ -107,19 +114,18 @@ def get_query_text_vector(query): input=query, model=embedding_model_name )['data'][0]['embedding'] - if (index_embedding_dimension := index_details.embedding_vector_column.get("embedding_dimension")) and \ + if (index_embedding_dimension := self._index_details.embedding_vector_column.get("embedding_dimension")) and \ len(vector) != index_embedding_dimension: raise ValueError( f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}" ) return text, vector - def do_arguments_match(llm_call_arguments: Dict[str, Any]): - # TODO: Remove print statement - print("llm_call_arguments_json: ", llm_call_arguments) - print("retriever_tool_params: ", self.tool.function.parameters) - retriever_tool_params: Dict[str, Any] = self.tool.function.parameters - return set(llm_call_arguments.keys()) == set(retriever_tool_params.keys()) + def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: + tool_call_arguments: Set[str] = set(json.loads(tool_call.function.arguments).keys()) + vs_index_arguments: Set[str] = set(self.tool["function"]["parameters"]["properties"].keys()) + return tool_call.function.name == self.tool["function"]["name"] and \ + tool_call_arguments == vs_index_arguments message = response.choices[choice_index].message llm_tool_calls = message.tool_calls @@ -127,12 +133,10 @@ def do_arguments_match(llm_call_arguments: Dict[str, Any]): if llm_tool_calls: for llm_tool_call in llm_tool_calls: # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput - llm_call_arguments_json: Dict[str, Any] = json.loads(llm_tool_call.function.arguments) - if llm_tool_call.function.name != self.tool.function.name \ - or not do_arguments_match(llm_call_arguments_json): + if not is_tool_call_for_index(llm_tool_call): continue - query_text, query_vector = get_query_text_vector(llm_call_arguments_json["query"]) + query_text, query_vector = get_query_text_vector(llm_tool_call) search_resp = self._index.similarity_search( columns=self.columns, query_text=query_text, @@ -141,15 +145,12 @@ def do_arguments_match(llm_call_arguments: Dict[str, Any]): num_results=self.num_results, query_type=self.query_type, ) - embedding_column = self._index_details.embedding_vector_column["name"] - # Don't show the embedding column in the response - ignore_cols: List = [embedding_column] if embedding_column not in self.columns else [] docs_with_score: List[Tuple[Dict, float]] = \ parse_vector_search_response( search_resp, self._index_details, self.text_column, - ignore_cols=ignore_cols, + ignore_cols=[], document_class=dict ) diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py index 9b967e0..e607866 100644 --- a/src/databricks_ai_bridge/utils/vector_search.py +++ b/src/databricks_ai_bridge/utils/vector_search.py @@ -104,7 +104,7 @@ def validate_and_get_text_column(text_column: Optional[str], index_details: Inde raise ValueError("The `text_column` parameter is required for this index.") return text_column -def _validate_and_get_return_columns( +def validate_and_get_return_columns( columns: List[str], text_column: str, index_details: IndexDetails ) -> List[str]: """ diff --git a/tests/utils/vector_search.py b/tests/utils/vector_search.py deleted file mode 100644 index e69de29..0000000 From ff1866307f94330eadec6aed990fe8a9432fc92f Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 2 Jan 2025 20:53:54 -0800 Subject: [PATCH 05/14] Add unit tests and some validations --- .../test_vector_search_retriever_tool.py | 6 +- .../tests/unit_tests/test_vectorstores.py | 9 +- .../langchain/tests/utils/vector_search.py | 144 +---------------- .../vector_search_retriever_tool.py | 30 ++-- .../test_vector_search_retriever_tool.py | 102 ------------ .../test_vector_search_retriever_tool.py | 151 ++++++++++++++++++ .../test_utils/vector_search.py | 146 +++++++++++++++++ 7 files changed, 328 insertions(+), 260 deletions(-) delete mode 100644 integrations/openai/tests/test_vector_search_retriever_tool.py create mode 100644 integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py create mode 100644 src/databricks_ai_bridge/test_utils/vector_search.py diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 3dd1855..90c7efe 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -6,12 +6,12 @@ from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool from tests.utils.chat_models import llm, mock_client # noqa: F401 -from tests.utils.vector_search import ( # noqa: F401 +from tests.utils.vector_search import EMBEDDING_MODEL +from databricks_ai_bridge.test_utils.vector_search import( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, - EMBEDDING_MODEL, mock_vs_client, - mock_workspace_client, + mock_workspace_client ) diff --git a/integrations/langchain/tests/unit_tests/test_vectorstores.py b/integrations/langchain/tests/unit_tests/test_vectorstores.py index 29ead86..6aca450 100644 --- a/integrations/langchain/tests/unit_tests/test_vectorstores.py +++ b/integrations/langchain/tests/unit_tests/test_vectorstores.py @@ -7,15 +7,18 @@ from databricks_langchain.vectorstores import DatabricksVectorSearch from tests.utils.vector_search import ( + EMBEDDING_MODEL, + FakeEmbeddings, +) +from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, DIRECT_ACCESS_INDEX, - EMBEDDING_MODEL, ENDPOINT_NAME, INDEX_DETAILS, INPUT_TEXTS, - FakeEmbeddings, - mock_vs_client, # noqa: F401 + mock_workspace_client, + mock_vs_client ) diff --git a/integrations/langchain/tests/utils/vector_search.py b/integrations/langchain/tests/utils/vector_search.py index a57c14c..b751ec3 100644 --- a/integrations/langchain/tests/utils/vector_search.py +++ b/integrations/langchain/tests/utils/vector_search.py @@ -1,15 +1,6 @@ -import uuid -from typing import Generator, List, Optional -from unittest import mock -from unittest.mock import MagicMock, patch - -import pytest -from databricks.vector_search.client import VectorSearchIndex # type: ignore +from typing import List from langchain_core.embeddings import Embeddings - -INPUT_TEXTS = ["foo", "bar", "baz"] -DEFAULT_VECTOR_DIMENSION = 4 - +from databricks_ai_bridge.test_utils.vector_search import DEFAULT_VECTOR_DIMENSION class FakeEmbeddings(Embeddings): """Fake embeddings functionality for testing.""" @@ -30,134 +21,3 @@ def embed_query(self, text: str) -> List[float]: EMBEDDING_MODEL = FakeEmbeddings() - - -### Dummy similarity_search() Response ### -EXAMPLE_SEARCH_RESPONSE = { - "manifest": { - "column_count": 3, - "columns": [ - {"name": "id"}, - {"name": "text"}, - {"name": "text_vector"}, - {"name": "score"}, - ], - }, - "result": { - "row_count": len(INPUT_TEXTS), - "data_array": sorted( - [ - [str(uuid.uuid4()), s, e, 0.5] - for s, e in zip(INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)) - ], - key=lambda x: x[2], # type: ignore - reverse=True, - ), - }, - "next_page_token": "", -} - - -### Dummy Indices #### - -ENDPOINT_NAME = "test-endpoint" -DIRECT_ACCESS_INDEX = "test.direct_access.index" -DELTA_SYNC_INDEX = "test.delta_sync.index" -DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test.delta_sync_self_managed.index" -ALL_INDEX_NAMES = { - DIRECT_ACCESS_INDEX, - DELTA_SYNC_INDEX, - DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, -} - -INDEX_DETAILS = { - DELTA_SYNC_INDEX: { - "name": DELTA_SYNC_INDEX, - "endpoint_name": ENDPOINT_NAME, - "index_type": "DELTA_SYNC", - "primary_key": "id", - "delta_sync_index_spec": { - "source_table": "ml.llm.source_table", - "pipeline_type": "CONTINUOUS", - "embedding_source_columns": [ - { - "name": "text", - "embedding_model_endpoint_name": "openai-text-embedding", - } - ], - }, - }, - DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: { - "name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, - "endpoint_name": ENDPOINT_NAME, - "index_type": "DELTA_SYNC", - "primary_key": "id", - "delta_sync_index_spec": { - "source_table": "ml.llm.source_table", - "pipeline_type": "CONTINUOUS", - "embedding_vector_columns": [ - { - "name": "text_vector", - "embedding_dimension": DEFAULT_VECTOR_DIMENSION, - } - ], - }, - }, - DIRECT_ACCESS_INDEX: { - "name": DIRECT_ACCESS_INDEX, - "endpoint_name": ENDPOINT_NAME, - "index_type": "DIRECT_ACCESS", - "primary_key": "id", - "direct_access_index_spec": { - "embedding_vector_columns": [ - { - "name": "text_vector", - "embedding_dimension": DEFAULT_VECTOR_DIMENSION, - } - ], - "schema_json": f"{{" - f'"{"id"}": "int", ' - f'"feat1": "str", ' - f'"feat2": "float", ' - f'"text": "string", ' - f'"{"text_vector"}": "array"' - f"}}", - }, - }, -} - - -@pytest.fixture(autouse=True) -def mock_vs_client() -> Generator: - def _get_index( - endpoint_name: Optional[str] = None, - index_name: str = None, # type: ignore - ) -> MagicMock: - index = MagicMock(spec=VectorSearchIndex) - index.describe.return_value = INDEX_DETAILS[index_name] - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - return index - - mock_client = MagicMock() - mock_client.get_index.side_effect = _get_index - with mock.patch( - "databricks.vector_search.client.VectorSearchClient", - return_value=mock_client, - ): - yield - - -@pytest.fixture(autouse=True) -def mock_workspace_client() -> Generator: - def _get_table_comment(full_name: str) -> MagicMock: - table = MagicMock() - table.comment = "Mocked table comment" - return table - - mock_client = MagicMock() - mock_client.tables.get.side_effect = _get_table_comment - with patch( - "databricks.sdk.WorkspaceClient", - return_value=mock_client, - ): - yield diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 7c44bb8..8a9ad08 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -9,7 +9,7 @@ from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns -from databricks.vector_search.client import VectorSearchClient, VectorSearchIndex +from databricks.vector_search.client import VectorSearchIndex import json class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): @@ -55,6 +55,8 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): @model_validator(mode="after") def _validate_tool_inputs(self): + from databricks.vector_search.client import VectorSearchClient # import here so we can mock in tests + self._index = VectorSearchClient().get_index(index_name=self.index_name) self._index_details = IndexDetails(self._index) self.text_column = validate_and_get_text_column(self.text_column, self._index_details) @@ -75,8 +77,8 @@ def rewrite_index_name(index_name: str): def execute_retriever_calls(self, response: ChatCompletion, choice_index: int = 0, - open_ai_client: OpenAI = None, - embedding_model_name: str = None) -> List[Dict[str, Any]]: + embedding_model_name: str = None, + openai_client: OpenAI = None) -> List[Dict[str, Any]]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. @@ -85,10 +87,10 @@ def execute_retriever_calls(self, response: The chat completion response object returned by the OpenAI API. choice_index: The index of the choice to process. Defaults to 0. Note that multiple choices are not supported yet. - open_ai_client: The OpenAI client object to use for generating embeddings. Required for - direct access indexes or delta-sync indexes with self-managed embeddings. embedding_model_name: The name of the embedding model to use for embedding the query text. Required for direct access indexes or delta-sync indexes with self-managed embeddings. + openai_client: The OpenAI client object used to generate embeddings for retrieval queries. If not provided, + the default OpenAI client in the current environment will be used. Returns: A list of messages containing the assistant message and the retriever call results @@ -98,19 +100,24 @@ def execute_retriever_calls(self, def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Optional[str], Optional[List[float]]]: query = json.loads(tool_call.function.arguments)["query"] if self._index_details.is_databricks_managed_embeddings(): - if open_ai_client or embedding_model_name: + if embedding_model_name: raise ValueError( f"The index '{self._index_details.name}' uses Databricks-managed embeddings. " - "Do not pass the `open_ai_client` or `embedding_model_name` parameters when executing retriever calls." + "Do not pass the `embedding_model_name` parameter when executing retriever calls." ) return query, None # For non-Databricks-managed embeddings - if not open_ai_client or not embedding_model_name: - raise ValueError("OpenAI client and embedding model name are required for non-Databricks-managed " + from openai import OpenAI + oai_client = openai_client or OpenAI() + if not oai_client.api_key: + raise ValueError("OpenAI API key is required to generate embeddings for retrieval queries.") + if not embedding_model_name: + raise ValueError("The embedding model name is required for non-Databricks-managed " "embeddings Vector Search indexes in order to generate embeddings for retrieval queries.") + text = query if self.query_type and self.query_type.upper() == "HYBRID" else None - vector = open_ai_client.embeddings.create( + vector = oai_client.embeddings.create( input=query, model=embedding_model_name )['data'][0]['embedding'] @@ -127,6 +134,8 @@ def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: return tool_call.function.name == self.tool["function"]["name"] and \ tool_call_arguments == vs_index_arguments + if type(response) is not ChatCompletion: + raise ValueError("response must be an instance of ChatCompletion") message = response.choices[choice_index].message llm_tool_calls = message.tool_calls function_calls = [] @@ -134,6 +143,7 @@ def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: for llm_tool_call in llm_tool_calls: # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput if not is_tool_call_for_index(llm_tool_call): + raise ValueError("The tool call does not correspond to the VectorSearchRetrieverToolInput.") continue query_text, query_vector = get_query_text_vector(llm_tool_call) diff --git a/integrations/openai/tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/test_vector_search_retriever_tool.py deleted file mode 100644 index a12448d..0000000 --- a/integrations/openai/tests/test_vector_search_retriever_tool.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Any, Dict, List, Optional - -import pytest -from langchain_core.embeddings import Embeddings -from langchain_core.tools import BaseTool - -from tests.utils.chat_models import llm, mock_client # noqa: F401 -from tests.utils.vector_search import ( # noqa: F401 - ALL_INDEX_NAMES, - DELTA_SYNC_INDEX, - EMBEDDING_MODEL, - mock_vs_client, - mock_workspace_client, -) - - -def init_vector_search_tool( - index_name: str, - columns: Optional[List[str]] = None, - tool_name: Optional[str] = None, - tool_description: Optional[str] = None, - embedding: Optional[Embeddings] = None, - text_column: Optional[str] = None, -) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - } - if index_name != DELTA_SYNC_INDEX: - kwargs.update( - { - "embedding": EMBEDDING_MODEL, - "text_column": "text", - } - ) - return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_init(index_name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert isinstance(vector_search_tool, BaseTool) - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: - from langchain_core.messages import AIMessage - - vector_search_tool = init_vector_search_tool(index_name) - llm_with_tools = llm.bind_tools([vector_search_tool]) - response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") - assert isinstance(response, AIMessage) - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -@pytest.mark.parametrize("columns", [None, ["id", "text"]]) -@pytest.mark.parametrize("tool_name", [None, "test_tool"]) -@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) -@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) -@pytest.mark.parametrize("text_column", [None, "text"]) -def test_vector_search_retriever_tool_combinations( - index_name: str, - columns: Optional[List[str]], - tool_name: Optional[str], - tool_description: Optional[str], - embedding: Optional[Any], - text_column: Optional[str], -) -> None: - if index_name == DELTA_SYNC_INDEX: - embedding = None - text_column = None - - vector_search_tool = init_vector_search_tool( - index_name=index_name, - columns=columns, - tool_name=tool_name, - tool_description=tool_description, - embedding=embedding, - text_column=text_column, - ) - assert isinstance(vector_search_tool, BaseTool) - result = vector_search_tool.invoke("Databricks Agent Framework") - assert result is not None - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_vector_search_retriever_tool_description_generation(index_name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert vector_search_tool.name != "" - assert vector_search_tool.description != "" - assert vector_search_tool.name == index_name - assert ( - "A vector search-based retrieval tool for querying indexed embeddings." - in vector_search_tool.description - ) - assert vector_search_tool.args_schema.model_fields["query"] is not None - assert vector_search_tool.args_schema.model_fields["query"].description == ( - "The string used to query the index with and identify the most similar " - "vectors and return the associated documents." - ) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py new file mode 100644 index 0000000..a9c6bc9 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, Generator, List, Optional + +import pytest + +from unittest import mock +from unittest.mock import MagicMock, patch +from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 + ALL_INDEX_NAMES, + DELTA_SYNC_INDEX, + DIRECT_ACCESS_INDEX, + mock_vs_client, + mock_workspace_client, +) +from databricks_openai import VectorSearchRetrieverTool +from pydantic import BaseModel +from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call_param import Function +import os + +@pytest.fixture(autouse=True) +def mock_openai_client(): + mock_client = MagicMock() + mock_client.api_key = "fake_api_key" + mock_client.embeddings.create.return_value = { + "data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}] + } + with patch("openai.OpenAI", return_value=mock_client): + yield mock_client + +def get_chat_completion_response(tool_name: str, index_name: str): + return ChatCompletion( + id='chatcmpl-AlSTQf3qIjeEOdoagPXUYhuWZkwme', + choices=[ + Choice( + finish_reason='tool_calls', + index=0, + logprobs=None, + message=ChatCompletionMessage( + content=None, + refusal=None, + role='assistant', + audio=None, + function_call=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id='call_VtmBTsVM2zQ3yL5GzddMgWb0', + function=Function( + arguments='{"query":"Databricks Agent Framework"}', + name=tool_name or index_name.split(".")[-1] # see rewrite_index_name() in VectorSearchRetrieverTool + ), + type='function' + ) + ] + ) + ) + ], + created=1735874232, + model='gpt-4o-mini-2024-07-18', + object='chat.completion', + ) + +def init_vector_search_tool( + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + text_column: Optional[str] = None, +) -> VectorSearchRetrieverTool: + kwargs: Dict[str, Any] = { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "text_column": text_column, + } + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "text_column": "text", + } + ) + return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] + +class SelfManagedEmbeddingsTest: + def __init__(self, text_column=None, embedding_model_name=None, open_ai_client=None): + self.text_column = text_column + self.embedding_model_name = embedding_model_name + self.open_ai_client = open_ai_client + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("columns", [None, ["id", "text"]]) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) +def test_vector_search_retriever_tool_init( + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str] +) -> None: + if index_name == DELTA_SYNC_INDEX: + self_managed_embeddings_test = SelfManagedEmbeddingsTest() + else: + from openai import OpenAI + self_managed_embeddings_test = SelfManagedEmbeddingsTest("text", "text-embedding-3-small", OpenAI(api_key="your-api-key")) + + vector_search_tool = init_vector_search_tool( + index_name=index_name, + columns=columns, + tool_name=tool_name, + tool_description=tool_description, + text_column=self_managed_embeddings_test.text_column, + ) + assert isinstance(vector_search_tool, BaseModel) + # simulate call to openai.chat.completions.create + chat_completion_resp = get_chat_completion_response(tool_name, index_name) + response = vector_search_tool.execute_retriever_calls( + chat_completion_resp, + embedding_model_name=self_managed_embeddings_test.embedding_model_name, + openai_client=self_managed_embeddings_test.open_ai_client + ) + assert response is not None + + +@pytest.mark.parametrize("columns", [None, ["id", "text"]]) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) +def test_open_ai_client_from_env( + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str] +) -> None: + self_managed_embeddings_test = SelfManagedEmbeddingsTest("text", "text-embedding-3-small", None) + os.environ["OPENAI_API_KEY"] = "your-api-key" + + vector_search_tool = init_vector_search_tool( + index_name=DIRECT_ACCESS_INDEX, + columns=columns, + tool_name=tool_name, + tool_description=tool_description, + text_column=self_managed_embeddings_test.text_column, + ) + assert isinstance(vector_search_tool, BaseModel) + # simulate call to openai.chat.completions.create + chat_completion_resp = get_chat_completion_response(tool_name, DIRECT_ACCESS_INDEX) + response = vector_search_tool.execute_retriever_calls( + chat_completion_resp, + embedding_model_name=self_managed_embeddings_test.embedding_model_name, + openai_client=self_managed_embeddings_test.open_ai_client + ) + assert response is not None diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py new file mode 100644 index 0000000..d4ee1db --- /dev/null +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -0,0 +1,146 @@ +import uuid +from typing import Generator, List, Optional +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest +from databricks.vector_search.client import VectorSearchIndex # type: ignore + +INPUT_TEXTS = ["foo", "bar", "baz"] +DEFAULT_VECTOR_DIMENSION = 4 + +def embed_documents(embedding_texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [ + [float(1.0)] * (DEFAULT_VECTOR_DIMENSION - 1) + [float(i)] for i in range(len(embedding_texts)) + ] + +### Dummy similarity_search() Response ### +EXAMPLE_SEARCH_RESPONSE = { + "manifest": { + "column_count": 3, + "columns": [ + {"name": "id"}, + {"name": "text"}, + {"name": "text_vector"}, + {"name": "score"}, + ], + }, + "result": { + "row_count": len(INPUT_TEXTS), + "data_array": sorted( + [ + [str(uuid.uuid4()), s, e, 0.5] + for s, e in zip(INPUT_TEXTS, embed_documents(INPUT_TEXTS)) + ], + key=lambda x: x[2], # type: ignore + reverse=True, + ), + }, + "next_page_token": "", +} + + +### Dummy Indices #### + +ENDPOINT_NAME = "test-endpoint" +DIRECT_ACCESS_INDEX = "test.direct_access.index" +DELTA_SYNC_INDEX = "test.delta_sync.index" +DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test.delta_sync_self_managed.index" +ALL_INDEX_NAMES = { + DIRECT_ACCESS_INDEX, + DELTA_SYNC_INDEX, + DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, +} + +INDEX_DETAILS = { + DELTA_SYNC_INDEX: { + "name": DELTA_SYNC_INDEX, + "endpoint_name": ENDPOINT_NAME, + "index_type": "DELTA_SYNC", + "primary_key": "id", + "delta_sync_index_spec": { + "source_table": "ml.llm.source_table", + "pipeline_type": "CONTINUOUS", + "embedding_source_columns": [ + { + "name": "text", + "embedding_model_endpoint_name": "openai-text-embedding", + } + ], + }, + }, + DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: { + "name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX, + "endpoint_name": ENDPOINT_NAME, + "index_type": "DELTA_SYNC", + "primary_key": "id", + "delta_sync_index_spec": { + "source_table": "ml.llm.source_table", + "pipeline_type": "CONTINUOUS", + "embedding_vector_columns": [ + { + "name": "text_vector", + "embedding_dimension": DEFAULT_VECTOR_DIMENSION, + } + ], + }, + }, + DIRECT_ACCESS_INDEX: { + "name": DIRECT_ACCESS_INDEX, + "endpoint_name": ENDPOINT_NAME, + "index_type": "DIRECT_ACCESS", + "primary_key": "id", + "direct_access_index_spec": { + "embedding_vector_columns": [ + { + "name": "text_vector", + "embedding_dimension": DEFAULT_VECTOR_DIMENSION, + } + ], + "schema_json": f"{{" + f'"{"id"}": "int", ' + f'"feat1": "str", ' + f'"feat2": "float", ' + f'"text": "string", ' + f'"{"text_vector"}": "array"' + f"}}", + }, + }, +} + + +@pytest.fixture(autouse=True) +def mock_vs_client() -> Generator: + def _get_index( + endpoint_name: Optional[str] = None, + index_name: str = None, # type: ignore + ) -> MagicMock: + index = MagicMock(spec=VectorSearchIndex) + index.describe.return_value = INDEX_DETAILS[index_name] + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + return index + + mock_client = MagicMock() + mock_client.get_index.side_effect = _get_index + with mock.patch( + "databricks.vector_search.client.VectorSearchClient", + return_value=mock_client, + ): + yield + + +@pytest.fixture(autouse=True) +def mock_workspace_client() -> Generator: + def _get_table_comment(full_name: str) -> MagicMock: + table = MagicMock() + table.comment = "Mocked table comment" + return table + + mock_client = MagicMock() + mock_client.tables.get.side_effect = _get_table_comment + with patch( + "databricks.sdk.WorkspaceClient", + return_value=mock_client, + ): + yield From d27c68fac3937ffe4725d8b8e486b79d50e4203f Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 2 Jan 2025 20:54:19 -0800 Subject: [PATCH 06/14] Undo line --- .../openai/src/databricks_openai/vector_search_retriever_tool.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 8a9ad08..5d4b9d7 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -143,7 +143,6 @@ def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: for llm_tool_call in llm_tool_calls: # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput if not is_tool_call_for_index(llm_tool_call): - raise ValueError("The tool call does not correspond to the VectorSearchRetrieverToolInput.") continue query_text, query_vector = get_query_text_vector(llm_tool_call) From 5cc9b469f3d3e1b2fc784eb6ec965150468599b7 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Wed, 8 Jan 2025 14:59:51 -0800 Subject: [PATCH 07/14] Remove extra changes --- .../src/databricks_langchain/vector_search_retriever_tool.py | 3 ++- integrations/langchain/tests/utils/vector_search.py | 1 + integrations/openai/pyproject.toml | 2 +- src/databricks_ai_bridge/test_utils/vector_search.py | 2 ++ src/databricks_ai_bridge/utils/vector_search.py | 4 ++++ src/databricks_ai_bridge/vector_search_retriever_tool.py | 2 ++ 6 files changed, 12 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 62873a0..0a3020c 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -11,6 +11,7 @@ from databricks_langchain.vectorstores import DatabricksVectorSearch + class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. @@ -66,4 +67,4 @@ def _validate_tool_inputs(self): def _run(self, query: str) -> str: return self._vector_store.similarity_search( query, k=self.num_results, filter=self.filters, query_type=self.query_type - ) \ No newline at end of file + ) diff --git a/integrations/langchain/tests/utils/vector_search.py b/integrations/langchain/tests/utils/vector_search.py index 5993001..91b8802 100644 --- a/integrations/langchain/tests/utils/vector_search.py +++ b/integrations/langchain/tests/utils/vector_search.py @@ -3,6 +3,7 @@ from databricks_ai_bridge.test_utils.vector_search import DEFAULT_VECTOR_DIMENSION from langchain_core.embeddings import Embeddings + class FakeEmbeddings(Embeddings): """Fake embeddings functionality for testing.""" diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b70f1d9..d4d2122 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -68,4 +68,4 @@ docstring-code-format = true docstring-code-line-length = 88 [tool.ruff.lint.pydocstyle] -convention = "google" \ No newline at end of file +convention = "google" diff --git a/src/databricks_ai_bridge/test_utils/vector_search.py b/src/databricks_ai_bridge/test_utils/vector_search.py index 42aff51..368e359 100644 --- a/src/databricks_ai_bridge/test_utils/vector_search.py +++ b/src/databricks_ai_bridge/test_utils/vector_search.py @@ -9,6 +9,7 @@ INPUT_TEXTS = ["foo", "bar", "baz"] DEFAULT_VECTOR_DIMENSION = 4 + def embed_documents(embedding_texts: List[str]) -> List[List[float]]: """Return simple embeddings.""" return [ @@ -16,6 +17,7 @@ def embed_documents(embedding_texts: List[str]) -> List[List[float]]: for i in range(len(embedding_texts)) ] + ### Dummy similarity_search() Response ### EXAMPLE_SEARCH_RESPONSE = { "manifest": { diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py index 3fd3c32..b572449 100644 --- a/src/databricks_ai_bridge/utils/vector_search.py +++ b/src/databricks_ai_bridge/utils/vector_search.py @@ -2,10 +2,12 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple + class IndexType(str, Enum): DIRECT_ACCESS = "DIRECT_ACCESS" DELTA_SYNC = "DELTA_SYNC" + class IndexDetails: """An utility class to store the configuration details of an index.""" @@ -87,6 +89,7 @@ def parse_vector_search_response( docs_with_score.append((doc, score)) return docs_with_score + def validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: if index_details.is_databricks_managed_embeddings(): index_source_column: str = index_details.embedding_source_column["name"] @@ -102,6 +105,7 @@ def validate_and_get_text_column(text_column: Optional[str], index_details: Inde raise ValueError("The `text_column` parameter is required for this index.") return text_column + def validate_and_get_return_columns( columns: List[str], text_column: str, index_details: IndexDetails ) -> List[str]: diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 36b7eac..a8f2a58 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional + from pydantic import BaseModel, Field from databricks_ai_bridge.utils.vector_search import IndexDetails @@ -12,6 +13,7 @@ class VectorSearchRetrieverToolInput(BaseModel): "vectors and return the associated documents." ) + class VectorSearchRetrieverToolMixin(BaseModel): """ Mixin class for Databricks Vector Search retrieval tools. From 88ae6048635dff045025d5290fef853f363a6a80 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 9 Jan 2025 10:34:38 -0800 Subject: [PATCH 08/14] Fix embedding --- .../src/databricks_openai/vector_search_retriever_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 5d4b9d7..d591755 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -120,7 +120,7 @@ def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Opt vector = oai_client.embeddings.create( input=query, model=embedding_model_name - )['data'][0]['embedding'] + ).data[0].embedding if (index_embedding_dimension := self._index_details.embedding_vector_column.get("embedding_dimension")) and \ len(vector) != index_embedding_dimension: raise ValueError( From bb9f1092abc89db1378a5da8f6c8129a549ec59b Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 9 Jan 2025 12:19:19 -0800 Subject: [PATCH 09/14] Remove double field --- .../vector_search_retriever_tool.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 0a3020c..1f17ba1 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -34,16 +34,6 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): description: str = Field(default="", description="The description of the tool") args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput - embedding: Optional[Embeddings] = Field( - None, description="Embedding model for self-managed embeddings." - ) - text_column: Optional[str] = Field( - None, - description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", - ) - _vector_store: DatabricksVectorSearch = PrivateAttr() @model_validator(mode="after") From 9041017fd089f10950e846b5b9ca808e1e871a2c Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 9 Jan 2025 12:21:49 -0800 Subject: [PATCH 10/14] Lint --- .../vector_search_retriever_tool.py | 113 +++++++++++------- .../test_vector_search_retriever_tool.py | 76 ++++++------ 2 files changed, 111 insertions(+), 78 deletions(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index d591755..ad1e008 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,16 +1,22 @@ -from typing import Any, Dict, List, Optional, Type, Tuple - -from openai import OpenAI -from openai import pydantic_function_tool -from openai.types.chat import ChatCompletionToolParam, ChatCompletionMessageToolCall -from openai.types.chat import ChatCompletion - -from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model +import json +from typing import Any, Dict, List, Optional, Tuple -from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput -from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns from databricks.vector_search.client import VectorSearchIndex -import json +from databricks_ai_bridge.utils.vector_search import ( + IndexDetails, + parse_vector_search_response, + validate_and_get_return_columns, + validate_and_get_text_column, +) +from databricks_ai_bridge.vector_search_retriever_tool import ( + VectorSearchRetrieverToolInput, + VectorSearchRetrieverToolMixin, +) +from pydantic import Field, PrivateAttr, model_validator + +from openai import OpenAI, pydantic_function_tool +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall, ChatCompletionToolParam + class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): """ @@ -42,9 +48,9 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): text_column: Optional[str] = Field( None, description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings. Used for direct access indexes or " - "delta-sync indexes with self-managed embeddings", + "Required for direct-access index or delta-sync index with " + "self-managed embeddings. Used for direct access indexes or " + "delta-sync indexes with self-managed embeddings", ) tool: ChatCompletionToolParam = Field( @@ -55,12 +61,16 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): @model_validator(mode="after") def _validate_tool_inputs(self): - from databricks.vector_search.client import VectorSearchClient # import here so we can mock in tests + from databricks.vector_search.client import ( + VectorSearchClient, # import here so we can mock in tests + ) self._index = VectorSearchClient().get_index(index_name=self.index_name) self._index_details = IndexDetails(self._index) self.text_column = validate_and_get_text_column(self.text_column, self._index_details) - self.columns = validate_and_get_return_columns(self.columns or [], self.text_column, self._index_details) + self.columns = validate_and_get_return_columns( + self.columns or [], self.text_column, self._index_details + ) # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'." # The '.' from the index name are not allowed @@ -70,15 +80,18 @@ def rewrite_index_name(index_name: str): self.tool = pydantic_function_tool( VectorSearchRetrieverToolInput, name=self.tool_name or rewrite_index_name(self.index_name), - description=self.tool_description or self._get_default_tool_description(self._index_details) + description=self.tool_description + or self._get_default_tool_description(self._index_details), ) return self - def execute_retriever_calls(self, - response: ChatCompletion, - choice_index: int = 0, - embedding_model_name: str = None, - openai_client: OpenAI = None) -> List[Dict[str, Any]]: + def execute_retriever_calls( + self, + response: ChatCompletion, + choice_index: int = 0, + embedding_model_name: str = None, + openai_client: OpenAI = None, + ) -> List[Dict[str, Any]]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. @@ -97,7 +110,9 @@ def execute_retriever_calls(self, that correspond to the self.tool VectorSearchRetrieverToolInput. """ - def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Optional[str], Optional[List[float]]]: + def get_query_text_vector( + tool_call: ChatCompletionMessageToolCall, + ) -> Tuple[Optional[str], Optional[List[float]]]: query = json.loads(tool_call.function.arguments)["query"] if self._index_details.is_databricks_managed_embeddings(): if embedding_model_name: @@ -109,20 +124,29 @@ def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Opt # For non-Databricks-managed embeddings from openai import OpenAI + oai_client = openai_client or OpenAI() if not oai_client.api_key: - raise ValueError("OpenAI API key is required to generate embeddings for retrieval queries.") + raise ValueError( + "OpenAI API key is required to generate embeddings for retrieval queries." + ) if not embedding_model_name: - raise ValueError("The embedding model name is required for non-Databricks-managed " - "embeddings Vector Search indexes in order to generate embeddings for retrieval queries.") + raise ValueError( + "The embedding model name is required for non-Databricks-managed " + "embeddings Vector Search indexes in order to generate embeddings for retrieval queries." + ) text = query if self.query_type and self.query_type.upper() == "HYBRID" else None - vector = oai_client.embeddings.create( - input=query, - model=embedding_model_name - ).data[0].embedding - if (index_embedding_dimension := self._index_details.embedding_vector_column.get("embedding_dimension")) and \ - len(vector) != index_embedding_dimension: + vector = ( + oai_client.embeddings.create(input=query, model=embedding_model_name) + .data[0] + .embedding + ) + if ( + index_embedding_dimension := self._index_details.embedding_vector_column.get( + "embedding_dimension" + ) + ) and len(vector) != index_embedding_dimension: raise ValueError( f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}" ) @@ -130,9 +154,13 @@ def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Opt def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: tool_call_arguments: Set[str] = set(json.loads(tool_call.function.arguments).keys()) - vs_index_arguments: Set[str] = set(self.tool["function"]["parameters"]["properties"].keys()) - return tool_call.function.name == self.tool["function"]["name"] and \ - tool_call_arguments == vs_index_arguments + vs_index_arguments: Set[str] = set( + self.tool["function"]["parameters"]["properties"].keys() + ) + return ( + tool_call.function.name == self.tool["function"]["name"] + and tool_call_arguments == vs_index_arguments + ) if type(response) is not ChatCompletion: raise ValueError("response must be an instance of ChatCompletion") @@ -154,14 +182,13 @@ def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: num_results=self.num_results, query_type=self.query_type, ) - docs_with_score: List[Tuple[Dict, float]] = \ - parse_vector_search_response( - search_resp, - self._index_details, - self.text_column, - ignore_cols=[], - document_class=dict - ) + docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( + search_resp, + self._index_details, + self.text_column, + ignore_cols=[], + document_class=dict, + ) function_call_result_message = { "role": "tool", diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index a9c6bc9..9fc4cff 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,9 +1,8 @@ -from typing import Any, Dict, Generator, List, Optional +import os +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch import pytest - -from unittest import mock -from unittest.mock import MagicMock, patch from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 ALL_INDEX_NAMES, DELTA_SYNC_INDEX, @@ -11,61 +10,65 @@ mock_vs_client, mock_workspace_client, ) -from databricks_openai import VectorSearchRetrieverTool -from pydantic import BaseModel from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call_param import Function -import os +from pydantic import BaseModel + +from databricks_openai import VectorSearchRetrieverTool + @pytest.fixture(autouse=True) def mock_openai_client(): mock_client = MagicMock() mock_client.api_key = "fake_api_key" - mock_client.embeddings.create.return_value = { - "data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}] - } + mock_client.embeddings.create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}]} with patch("openai.OpenAI", return_value=mock_client): yield mock_client + def get_chat_completion_response(tool_name: str, index_name: str): return ChatCompletion( - id='chatcmpl-AlSTQf3qIjeEOdoagPXUYhuWZkwme', + id="chatcmpl-AlSTQf3qIjeEOdoagPXUYhuWZkwme", choices=[ Choice( - finish_reason='tool_calls', + finish_reason="tool_calls", index=0, logprobs=None, message=ChatCompletionMessage( content=None, refusal=None, - role='assistant', + role="assistant", audio=None, function_call=None, tool_calls=[ ChatCompletionMessageToolCall( - id='call_VtmBTsVM2zQ3yL5GzddMgWb0', + id="call_VtmBTsVM2zQ3yL5GzddMgWb0", function=Function( arguments='{"query":"Databricks Agent Framework"}', - name=tool_name or index_name.split(".")[-1] # see rewrite_index_name() in VectorSearchRetrieverTool + name=tool_name + or index_name.split(".")[ + -1 + ], # see rewrite_index_name() in VectorSearchRetrieverTool ), - type='function' + type="function", ) - ] - ) + ], + ), ) ], created=1735874232, - model='gpt-4o-mini-2024-07-18', - object='chat.completion', + model="gpt-4o-mini-2024-07-18", + object="chat.completion", ) + def init_vector_search_tool( - index_name: str, - columns: Optional[List[str]] = None, - tool_name: Optional[str] = None, - tool_description: Optional[str] = None, - text_column: Optional[str] = None, + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + text_column: Optional[str] = None, ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { "index_name": index_name, @@ -82,27 +85,32 @@ def init_vector_search_tool( ) return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] + class SelfManagedEmbeddingsTest: def __init__(self, text_column=None, embedding_model_name=None, open_ai_client=None): self.text_column = text_column self.embedding_model_name = embedding_model_name self.open_ai_client = open_ai_client + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("columns", [None, ["id", "text"]]) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) @pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) def test_vector_search_retriever_tool_init( - index_name: str, - columns: Optional[List[str]], - tool_name: Optional[str], - tool_description: Optional[str] + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str], ) -> None: if index_name == DELTA_SYNC_INDEX: self_managed_embeddings_test = SelfManagedEmbeddingsTest() else: from openai import OpenAI - self_managed_embeddings_test = SelfManagedEmbeddingsTest("text", "text-embedding-3-small", OpenAI(api_key="your-api-key")) + + self_managed_embeddings_test = SelfManagedEmbeddingsTest( + "text", "text-embedding-3-small", OpenAI(api_key="your-api-key") + ) vector_search_tool = init_vector_search_tool( index_name=index_name, @@ -117,7 +125,7 @@ def test_vector_search_retriever_tool_init( response = vector_search_tool.execute_retriever_calls( chat_completion_resp, embedding_model_name=self_managed_embeddings_test.embedding_model_name, - openai_client=self_managed_embeddings_test.open_ai_client + openai_client=self_managed_embeddings_test.open_ai_client, ) assert response is not None @@ -126,9 +134,7 @@ def test_vector_search_retriever_tool_init( @pytest.mark.parametrize("tool_name", [None, "test_tool"]) @pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) def test_open_ai_client_from_env( - columns: Optional[List[str]], - tool_name: Optional[str], - tool_description: Optional[str] + columns: Optional[List[str]], tool_name: Optional[str], tool_description: Optional[str] ) -> None: self_managed_embeddings_test = SelfManagedEmbeddingsTest("text", "text-embedding-3-small", None) os.environ["OPENAI_API_KEY"] = "your-api-key" @@ -146,6 +152,6 @@ def test_open_ai_client_from_env( response = vector_search_tool.execute_retriever_calls( chat_completion_resp, embedding_model_name=self_managed_embeddings_test.embedding_model_name, - openai_client=self_managed_embeddings_test.open_ai_client + openai_client=self_managed_embeddings_test.open_ai_client, ) assert response is not None From e954e9e9cfab0661b5ea137ee2caa92ec016e070 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Thu, 9 Jan 2025 12:39:41 -0800 Subject: [PATCH 11/14] Minor cleanup --- .../databricks_openai/vector_search_retriever_tool.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index ad1e008..afb4749 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -75,7 +75,7 @@ def _validate_tool_inputs(self): # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'." # The '.' from the index name are not allowed def rewrite_index_name(index_name: str): - return index_name.split(".")[-1] + return index_name.replace(".", "_") self.tool = pydantic_function_tool( VectorSearchRetrieverToolInput, @@ -183,10 +183,9 @@ def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool: query_type=self.query_type, ) docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( - search_resp, - self._index_details, - self.text_column, - ignore_cols=[], + search_resp=search_resp, + index_details=self._index_details, + text_column=self.text_column, document_class=dict, ) From e94cc9e35b21e21208ee67188b0bc32a6b93e78e Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Mon, 13 Jan 2025 10:42:55 -0800 Subject: [PATCH 12/14] PR feedback --- .../openai/src/databricks_openai/__init__.py | 2 +- .../vector_search_retriever_tool.py | 5 ++--- .../test_vector_search_retriever_tool.py | 20 ++++++++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/integrations/openai/src/databricks_openai/__init__.py b/integrations/openai/src/databricks_openai/__init__.py index 015814c..ff3512a 100644 --- a/integrations/openai/src/databricks_openai/__init__.py +++ b/integrations/openai/src/databricks_openai/__init__.py @@ -1,6 +1,6 @@ from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool -# Expose all integrations to users under databricks-langchain +# Expose all integrations to users under databricks-openai __all__ = [ "VectorSearchRetrieverTool", ] diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index afb4749..7b59e1e 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -49,8 +49,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): None, description="The name of the text column to use for the embeddings. " "Required for direct-access index or delta-sync index with " - "self-managed embeddings. Used for direct access indexes or " - "delta-sync indexes with self-managed embeddings", + "self-managed embeddings.", ) tool: ChatCompletionToolParam = Field( @@ -94,7 +93,7 @@ def execute_retriever_calls( ) -> List[Dict[str, Any]]: """ Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the - self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into toll call messages. + self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. Args: response: The chat completion response object returned by the OpenAI API. diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 9fc4cff..174cdff 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,6 +1,6 @@ import os from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock import pytest from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 @@ -10,10 +10,10 @@ mock_vs_client, mock_workspace_client, ) -from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall +from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionMessageParam from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call_param import Function -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from databricks_openai import VectorSearchRetrieverTool @@ -22,7 +22,9 @@ def mock_openai_client(): mock_client = MagicMock() mock_client.api_key = "fake_api_key" - mock_client.embeddings.create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3, 0.4]}]} + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3, 0.4])] + mock_client.embeddings.create.return_value = mock_response with patch("openai.OpenAI", return_value=mock_client): yield mock_client @@ -127,7 +129,15 @@ def test_vector_search_retriever_tool_init( embedding_model_name=self_managed_embeddings_test.embedding_model_name, openai_client=self_managed_embeddings_test.open_ai_client, ) - assert response is not None + assert isinstance(response, list) + + # ChatCompletionMessageParam is a union of different ChatCompletionMessage types so we check that each + # element in the list is a union member + adapter = TypeAdapter(List[ChatCompletionMessageParam]) + parsed_list = adapter.validate_python(response) + + # parsed_list is now a list of union members + assert len(parsed_list) == len(response) @pytest.mark.parametrize("columns", [None, ["id", "text"]]) From 1fcc397c9e78b348eaf8cb8b5125512ba1b37e93 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Mon, 13 Jan 2025 10:44:01 -0800 Subject: [PATCH 13/14] Lint --- .../unit_tests/test_vector_search_retriever_tool.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 174cdff..51be437 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,6 +1,6 @@ import os from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, Mock, patch import pytest from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 @@ -10,7 +10,12 @@ mock_vs_client, mock_workspace_client, ) -from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionMessageParam +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, +) from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call_param import Function from pydantic import BaseModel, TypeAdapter From 2b3b9d6f9f43c3090f01633658ae7b341155c894 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Mon, 13 Jan 2025 17:05:17 -0800 Subject: [PATCH 14/14] Rename tool call --- .../src/databricks_openai/vector_search_retriever_tool.py | 4 ++-- .../tests/unit_tests/test_vector_search_retriever_tool.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 7b59e1e..efb2990 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -32,7 +32,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): messages=initial_messages, tools=tools, ) - retriever_call_message = dbvs_tool.execute_retriever_calls(response) + retriever_call_message = dbvs_tool.execute_calls(response) ### If needed, execute potential remaining tool calls here ### remaining_tool_call_messages = execute_remaining_tool_calls(response) @@ -84,7 +84,7 @@ def rewrite_index_name(index_name: str): ) return self - def execute_retriever_calls( + def execute_calls( self, response: ChatCompletion, choice_index: int = 0, diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 51be437..c80fd20 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -129,7 +129,7 @@ def test_vector_search_retriever_tool_init( assert isinstance(vector_search_tool, BaseModel) # simulate call to openai.chat.completions.create chat_completion_resp = get_chat_completion_response(tool_name, index_name) - response = vector_search_tool.execute_retriever_calls( + response = vector_search_tool.execute_calls( chat_completion_resp, embedding_model_name=self_managed_embeddings_test.embedding_model_name, openai_client=self_managed_embeddings_test.open_ai_client, @@ -164,7 +164,7 @@ def test_open_ai_client_from_env( assert isinstance(vector_search_tool, BaseModel) # simulate call to openai.chat.completions.create chat_completion_resp = get_chat_completion_response(tool_name, DIRECT_ACCESS_INDEX) - response = vector_search_tool.execute_retriever_calls( + response = vector_search_tool.execute_calls( chat_completion_resp, embedding_model_name=self_managed_embeddings_test.embedding_model_name, openai_client=self_managed_embeddings_test.open_ai_client,