diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py index e94c2bc..692e13a 100644 --- a/integrations/langchain/src/databricks_langchain/utils.py +++ b/integrations/langchain/src/databricks_langchain/utils.py @@ -2,7 +2,13 @@ from urllib.parse import urlparse import numpy as np +from enum import Enum +import json +from typing import ( + Dict, + Optional +) def get_deployment_client(target_uri: str) -> Any: if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"): @@ -95,3 +101,58 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity + +class IndexType(str, Enum): + DIRECT_ACCESS = "DIRECT_ACCESS" + DELTA_SYNC = "DELTA_SYNC" + +class IndexDetails: + """An utility class to store the configuration details of an index.""" + + def __init__(self, index: Any): + self._index_details = index.describe() + + @property + def name(self) -> str: + return self._index_details["name"] + + @property + def schema(self) -> Optional[Dict]: + if self.is_direct_access_index(): + schema_json = self.index_spec.get("schema_json") + if schema_json is not None: + return json.loads(schema_json) + return None + + @property + def primary_key(self) -> str: + return self._index_details["primary_key"] + + @property + def index_spec(self) -> Dict: + return ( + self._index_details.get("delta_sync_index_spec", {}) + if self.is_delta_sync_index() + else self._index_details.get("direct_access_index_spec", {}) + ) + + @property + def embedding_vector_column(self) -> Dict: + if vector_columns := self.index_spec.get("embedding_vector_columns"): + return vector_columns[0] + return {} + + @property + def embedding_source_column(self) -> Dict: + if source_columns := self.index_spec.get("embedding_source_columns"): + return source_columns[0] + return {} + + def is_delta_sync_index(self) -> bool: + return self._index_details["index_type"] == IndexType.DELTA_SYNC.value + + def is_direct_access_index(self) -> bool: + return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value + + def is_databricks_managed_embeddings(self) -> bool: + return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None diff --git a/integrations/langchain/src/databricks_langchain/vector_search.py b/integrations/langchain/src/databricks_langchain/vector_search.py index 8871086..2e02ec6 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search.py +++ b/integrations/langchain/src/databricks_langchain/vector_search.py @@ -1,79 +1,81 @@ -from databricks_langchain import DatabricksVectorSearch -from typing import ( - List, - Optional -) -from langchain_core.embeddings import Embeddings -from langchain.tools.retriever import create_retriever_tool +from typing import Any, Dict, List, Optional -from typing import Any, Dict, Optional, Type +from pydantic import BaseModel, Field, model_validator, PrivateAttr -from langchain.callbacks.manager import ( - AsyncCallbackManagerForToolRun, - CallbackManagerForToolRun, -) +from databricks_langchain import DatabricksVectorSearch +from databricks_langchain.utils import IndexDetails +from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field -from langchain_core.callbacks import ( - AsyncCallbackManagerForToolRun, - CallbackManagerForToolRun, -) -class VectorSearchRetrieverToolInput(BaseModel): - query: str = Field(description="query used to search the index") -class VectorSearchRetrieverTool(): +class VectorSearchRetrieverTool(BaseTool): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. This class integrates with a Databricks Vector Search and provides a convenient interface for building a retriever tool for agents. - - Parameters: - index_name (str): - The name of the index to use. Format: “catalog.schema.index”. endpoint: - num_results (int): - The number of results to return. Defaults to 10. - columns (Optional[List[str]]): - The list of column names to get when doing the search. Defaults to [primary_key, text_column]. - filters (Optional[Dict[str, Any]]): - The filters to apply to the search. Defaults to None. - query_type (str): - The type of query to run. Defaults to "ANN". - tool_name (str): - The name of the retrieval tool to be created. This will be passed to the language model, - so should be unique and somewhat descriptive. - tool_description (str): - A description of the tool's functionality. This will be passed to the language model, - so should be descriptive. """ - name: str = Field(description="The name of the tool") - description: str = Field(description="The description of the tool") - args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput - def __init__( - self, - index_name: str, - num_results: int = 10, - *, - columns: Optional[List[str]] = None, - filters: Optional[Dict[str, Any]] = None, - query_type: str = "ANN", - tool_name: Optional[str] = None, - tool_description: Optional[str], # TODO: By default the UC metadata for description, how do I get this info? Call using client? - ): - # Use the index name as the tool name if no tool name is provided - self.name = index_name - if tool_name: - self.name = tool_name - self.num_results = num_results - self.columns = columns - self.filters = filters - self.query_type = query_type - self.description = tool_description - self.vector_store = DatabricksVectorSearch(index_name=index_name) - def _run( - self, - query: str - ) -> str: - """Use the tool.""" - self.vector_store.similarity_search(query, self.num_results, self.columns, self.filters, self.query_type) + index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.") + num_results: int = Field(10, description="The number of results to return.") + columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.") + filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.") + query_type: str = Field("ANN", description="The type of query to run.") + tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.") + tool_description: Optional[str] = Field(None, description="A description of the tool.") + # TODO: Confirm if we can add these two to the API to support direct-access indexes or a delta-sync indexes with self-managed embeddings, + text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.") + embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.") + # TODO: Confirm if we can add this endpoint field + endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.") + + # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs() + name: str = Field(default="", description="The name of the tool") + description: str = Field(default="", description="The description of the tool") + + _vector_store: DatabricksVectorSearch = PrivateAttr() + + @model_validator(mode='after') + def validate_tool_inputs(self): + # Construct the vector store using provided params + kwargs = { + "index_name": self.index_name, + "endpoint": self.endpoint, + "embedding": self.embedding, + "text_column": self.text_column, + "columns": self.columns, + } + dbvs = DatabricksVectorSearch(**kwargs) + self._vector_store = dbvs + + def get_tool_description(): + default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings." + index_details = IndexDetails(dbvs.index) + if index_details.is_delta_sync_index(): + from databricks.sdk import WorkspaceClient + + source_table = index_details.index_spec.get('source_table', "") + w = WorkspaceClient() + source_table_comment = w.tables.get(full_name=source_table).comment + if source_table_comment: + return ( + default_tool_description + + f" The queried index uses the source table {source_table} with the description: " + + source_table_comment + ) + else: + return default_tool_description + f" The queried index uses the source table {source_table}" + return default_tool_description + + self.name = self.tool_name or self.index_name + self.description = self.tool_description or get_tool_description() + + return self + + + def _run(self, query: str) -> str: + return self._vector_store.similarity_search( + query, + k = self.num_results, + filter = self.filters, + query_type = self.query_type + ) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index e67f315..c6b124c 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -import json import logging import re import uuid -from enum import Enum from functools import partial from typing import ( Any, @@ -23,16 +21,11 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VST, VectorStore -from databricks_langchain.utils import maximal_marginal_relevance +from databricks_langchain.utils import maximal_marginal_relevance, IndexDetails logger = logging.getLogger(__name__) -class IndexType(str, Enum): - DIRECT_ACCESS = "DIRECT_ACCESS" - DELTA_SYNC = "DELTA_SYNC" - - _DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index." _NON_MANAGED_EMB_ONLY_MSG = "`%s` is not supported for index with Databricks-managed embeddings." _INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$") @@ -783,54 +776,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe f"not match with the index configuration '{index_embedding_dimension}'." ) - -class IndexDetails: - """An utility class to store the configuration details of an index.""" - - def __init__(self, index: Any): - self._index_details = index.describe() - - @property - def name(self) -> str: - return self._index_details["name"] - - @property - def schema(self) -> Optional[Dict]: - if self.is_direct_access_index(): - schema_json = self.index_spec.get("schema_json") - if schema_json is not None: - return json.loads(schema_json) - return None - - @property - def primary_key(self) -> str: - return self._index_details["primary_key"] - - @property - def index_spec(self) -> Dict: - return ( - self._index_details.get("delta_sync_index_spec", {}) - if self.is_delta_sync_index() - else self._index_details.get("direct_access_index_spec", {}) - ) - - @property - def embedding_vector_column(self) -> Dict: - if vector_columns := self.index_spec.get("embedding_vector_columns"): - return vector_columns[0] - return {} - - @property - def embedding_source_column(self) -> Dict: - if source_columns := self.index_spec.get("embedding_source_columns"): - return source_columns[0] - return {} - - def is_delta_sync_index(self) -> bool: - return self._index_details["index_type"] == IndexType.DELTA_SYNC.value - - def is_direct_access_index(self) -> bool: - return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value - - def is_databricks_managed_embeddings(self) -> bool: - return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None diff --git a/integrations/langchain/tests/test_vector_search.py b/integrations/langchain/tests/test_vector_search.py deleted file mode 100644 index b3a4892..0000000 --- a/integrations/langchain/tests/test_vector_search.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Any, Dict, Generator, List, Optional, Set - -import pytest -from databricks.vector_search.client import VectorSearchIndex # type: ignore - -from databricks_langchain import VectorSearchRetrieverTool, ChatDatabricks -from tests.utils.vector_search import EMBEDDING_MODEL, DELTA_SYNC_INDEX, ALL_INDEX_NAMES, mock_vs_client -from tests.utils.chat_models import mock_client, llm -from langchain_core.tools import BaseTool - -def init_vector_search_tool( - index_name: str, columns: Optional[List[str]] = None -) -> VectorSearchRetrieverTool: - kwargs: Dict[str, Any] = { - "index_name": index_name, - "columns": columns, - "tool_name": "test_tool", - "tool_description": "Test tool for vector search", - } - if index_name != DELTA_SYNC_INDEX: - kwargs.update( - { - "embedding": EMBEDDING_MODEL, - "text_column": "text", - } - ) - return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_init(index_name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert isinstance(vector_search_tool, BaseTool) - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: - from langchain_core.messages import AIMessage - - vector_search_tool = init_vector_search_tool(index_name) - llm_with_tools = llm.bind_tools([vector_search_tool]) - response = llm_with_tools.invoke( - "Which city is hotter today and which is bigger: LA or NY?" - ) - assert isinstance(response, AIMessage) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search.py b/integrations/langchain/tests/unit_tests/test_vector_search.py new file mode 100644 index 0000000..e0a1535 --- /dev/null +++ b/integrations/langchain/tests/unit_tests/test_vector_search.py @@ -0,0 +1,86 @@ +from typing import Any, Dict, Generator, List, Optional, Set + +import pytest +from databricks.vector_search.client import VectorSearchIndex # type: ignore + +from databricks_langchain import VectorSearchRetrieverTool, ChatDatabricks +from tests.utils.vector_search import EMBEDDING_MODEL, DELTA_SYNC_INDEX, ALL_INDEX_NAMES, mock_vs_client, mock_workspace_client, mock_workspace_client +from tests.utils.chat_models import mock_client, llm +from langchain_core.tools import BaseTool +from langchain_core.embeddings import Embeddings + +def init_vector_search_tool( + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, + endpoint: Optional[str] = None +) -> VectorSearchRetrieverTool: + kwargs: Dict[str, Any] = { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + "endpoint": endpoint, + } + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "embedding": EMBEDDING_MODEL, + "text_column": "text", + } + ) + return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_init(index_name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert isinstance(vector_search_tool, BaseTool) + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: + from langchain_core.messages import AIMessage + + vector_search_tool = init_vector_search_tool(index_name) + llm_with_tools = llm.bind_tools([vector_search_tool]) + response = llm_with_tools.invoke( + "Which city is hotter today and which is bigger: LA or NY?" + ) + assert isinstance(response, AIMessage) + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("columns", [None, ["id", "text"]]) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) +@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) +@pytest.mark.parametrize("text_column", [None, "text"]) +@pytest.mark.parametrize("endpoint", [None, "test_endpoint"]) +def test_vector_search_retriever_tool_combinations( + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str], + embedding: Optional[Any], + text_column: Optional[str], + endpoint: Optional[str] +) -> None: + if index_name == DELTA_SYNC_INDEX: + embedding = None + text_column = None + + vector_search_tool = init_vector_search_tool( + index_name=index_name, + columns=columns, + tool_name=tool_name, + tool_description=tool_description, + embedding=embedding, + text_column=text_column, + endpoint=endpoint + ) + assert isinstance(vector_search_tool, BaseTool) + result = vector_search_tool.invoke("Databricks Agent Framework") + assert result is not None \ No newline at end of file diff --git a/integrations/langchain/tests/utils/vector_search.py b/integrations/langchain/tests/utils/vector_search.py index 4a12fac..557ad4c 100644 --- a/integrations/langchain/tests/utils/vector_search.py +++ b/integrations/langchain/tests/utils/vector_search.py @@ -146,5 +146,20 @@ def _get_index( with mock.patch( "databricks.vector_search.client.VectorSearchClient", return_value=mock_client, + ): + yield + +@pytest.fixture(autouse=True) +def mock_workspace_client() -> Generator: + def _get_table_comment(full_name: str) -> MagicMock: + table = MagicMock() + table.comment = "Mocked table comment" + return table + + mock_client = MagicMock() + mock_client.tables.get.side_effect = _get_table_comment + with patch( + "databricks.sdk.WorkspaceClient", + return_value=mock_client, ): yield \ No newline at end of file