-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
233 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
138 changes: 70 additions & 68 deletions
138
integrations/langchain/src/databricks_langchain/vector_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,81 @@ | ||
from databricks_langchain import DatabricksVectorSearch | ||
from typing import ( | ||
List, | ||
Optional | ||
) | ||
from langchain_core.embeddings import Embeddings | ||
from langchain.tools.retriever import create_retriever_tool | ||
from typing import Any, Dict, List, Optional | ||
|
||
from typing import Any, Dict, Optional, Type | ||
from pydantic import BaseModel, Field, model_validator, PrivateAttr | ||
|
||
from langchain.callbacks.manager import ( | ||
AsyncCallbackManagerForToolRun, | ||
CallbackManagerForToolRun, | ||
) | ||
from databricks_langchain import DatabricksVectorSearch | ||
from databricks_langchain.utils import IndexDetails | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.tools import BaseTool | ||
from pydantic import BaseModel, Field | ||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForToolRun, | ||
CallbackManagerForToolRun, | ||
) | ||
|
||
class VectorSearchRetrieverToolInput(BaseModel): | ||
query: str = Field(description="query used to search the index") | ||
|
||
class VectorSearchRetrieverTool(): | ||
class VectorSearchRetrieverTool(BaseTool): | ||
""" | ||
A utility class to create a vector search-based retrieval tool for querying indexed embeddings. | ||
This class integrates with Databricks Vector Search and provides a convenient interface | ||
for building a retriever tool for agents. | ||
Parameters: | ||
index_name (str): | ||
The name of the index to use. Format: “catalog.schema.index”. | ||
num_results (int): | ||
The number of results to return. Defaults to 10. | ||
columns (Optional[List[str]]): | ||
The list of column names to get when doing the search. Defaults to [primary_key, text_column]. | ||
filters (Optional[Dict[str, Any]]): | ||
The filters to apply to the search. Defaults to None. | ||
query_type (str): | ||
The type of query to run. Defaults to "ANN". | ||
tool_name (str): | ||
The name of the retrieval tool to be created. This will be passed to the language model, | ||
so should be unique and somewhat descriptive. | ||
tool_description (str): | ||
A description of the tool's functionality. This will be passed to the language model, | ||
so should be descriptive. | ||
""" | ||
name: str = Field(description="The name of the tool") | ||
description: str = Field(description="The description of the tool") | ||
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput | ||
def __init__( | ||
self, | ||
index_name: str, | ||
num_results: int = 10, | ||
*, | ||
columns: Optional[List[str]] = None, | ||
filters: Optional[Dict[str, Any]] = None, | ||
query_type: str = "ANN", | ||
tool_name: Optional[str] = None, | ||
tool_description: Optional[str], # TODO: By default the UC metadata for description, how do I get this info? Call using client? | ||
): | ||
# Use the index name as the tool name if no tool name is provided | ||
self.name = index_name | ||
if tool_name: | ||
self.name = tool_name | ||
self.num_results = num_results | ||
self.columns = columns | ||
self.filters = filters | ||
self.query_type = query_type | ||
self.description = tool_description | ||
self.vector_store = DatabricksVectorSearch(index_name=index_name) | ||
|
||
def _run( | ||
self, | ||
query: str | ||
) -> str: | ||
"""Use the tool.""" | ||
self.vector_store.similarity_search(query, self.num_results, self.columns, self.filters, self.query_type) | ||
index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.") | ||
num_results: int = Field(10, description="The number of results to return.") | ||
columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.") | ||
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.") | ||
query_type: str = Field("ANN", description="The type of query to run.") | ||
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.") | ||
tool_description: Optional[str] = Field(None, description="A description of the tool.") | ||
# TODO: Confirm whether we can add these two fields to the API to support direct-access indexes or delta-sync indexes with self-managed embeddings. | ||
text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.") | ||
embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.") | ||
# TODO: Confirm if we can add this endpoint field | ||
endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.") | ||
|
||
# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs() | ||
name: str = Field(default="", description="The name of the tool") | ||
description: str = Field(default="", description="The description of the tool") | ||
|
||
_vector_store: DatabricksVectorSearch = PrivateAttr() | ||
|
||
@model_validator(mode='after') | ||
def validate_tool_inputs(self): | ||
# Construct the vector store using provided params | ||
kwargs = { | ||
"index_name": self.index_name, | ||
"endpoint": self.endpoint, | ||
"embedding": self.embedding, | ||
"text_column": self.text_column, | ||
"columns": self.columns, | ||
} | ||
dbvs = DatabricksVectorSearch(**kwargs) | ||
self._vector_store = dbvs | ||
|
||
def get_tool_description(): | ||
default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings." | ||
index_details = IndexDetails(dbvs.index) | ||
if index_details.is_delta_sync_index(): | ||
from databricks.sdk import WorkspaceClient | ||
|
||
source_table = index_details.index_spec.get('source_table', "") | ||
w = WorkspaceClient() | ||
source_table_comment = w.tables.get(full_name=source_table).comment | ||
if source_table_comment: | ||
return ( | ||
default_tool_description + | ||
f" The queried index uses the source table {source_table} with the description: " + | ||
source_table_comment | ||
) | ||
else: | ||
return default_tool_description + f" The queried index uses the source table {source_table}" | ||
return default_tool_description | ||
|
||
self.name = self.tool_name or self.index_name | ||
self.description = self.tool_description or get_tool_description() | ||
|
||
return self | ||
|
||
|
||
def _run(self, query: str) -> str: | ||
return self._vector_store.similarity_search( | ||
query, | ||
k = self.num_results, | ||
filter = self.filters, | ||
query_type = self.query_type | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.