Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
leonbi100 committed Dec 20, 2024
1 parent 9166b5a commit 53f9924
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool"
"VectorSearchRetrieverTool",
]
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity


class IndexType(str, Enum):
DIRECT_ACCESS = "DIRECT_ACCESS"
DELTA_SYNC = "DELTA_SYNC"


class IndexDetails:
"""An utility class to store the configuration details of an index."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@


class VectorSearchRetrieverToolInput(BaseModel):
query: str = Field(description="The string used to query the index with and identify the most similar "
"vectors and return the associated documents.")
query: str = Field(
description="The string used to query the index with and identify the most similar "
"vectors and return the associated documents."
)


class VectorSearchRetrieverTool(BaseTool):
"""
Expand All @@ -19,17 +22,28 @@ class VectorSearchRetrieverTool(BaseTool):
for building a retriever tool for agents.
"""

index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.")
index_name: str = Field(
..., description="The name of the index to use, format: 'catalog.schema.index'."
)
num_results: int = Field(10, description="The number of results to return.")
columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.")
columns: Optional[List[str]] = Field(
None, description="Columns to return when doing the search."
)
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
query_type: str = Field("ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'.")
query_type: str = Field(
"ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'."
)
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
tool_description: Optional[str] = Field(None, description="A description of the tool.")
text_column: Optional[str] = Field(None, description="The name of the text column to use for the embeddings. "
"Required for direct-access index or delta-sync index with "
"self-managed embeddings.")
embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.")
text_column: Optional[str] = Field(
None,
description="The name of the text column to use for the embeddings. "
"Required for direct-access index or delta-sync index with "
"self-managed embeddings.",
)
embedding: Optional[Embeddings] = Field(
None, description="Embedding model for self-managed embeddings."
)

# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
name: str = Field(default="", description="The name of the tool")
Expand All @@ -38,7 +52,7 @@ class VectorSearchRetrieverTool(BaseTool):

_vector_store: DatabricksVectorSearch = PrivateAttr()

@model_validator(mode='after')
@model_validator(mode="after")
def validate_tool_inputs(self):
kwargs = {
"index_name": self.index_name,
Expand All @@ -50,34 +64,35 @@ def validate_tool_inputs(self):
self._vector_store = dbvs

def get_tool_description():
default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings."
default_tool_description = (
"A vector search-based retrieval tool for querying indexed embeddings."
)
index_details = IndexDetails(dbvs.index)
if index_details.is_delta_sync_index():
from databricks.sdk import WorkspaceClient

source_table = index_details.index_spec.get('source_table', "")
source_table = index_details.index_spec.get("source_table", "")
w = WorkspaceClient()
source_table_comment = w.tables.get(full_name=source_table).comment
if source_table_comment:
return (
default_tool_description +
f" The queried index uses the source table {source_table} with the description: " +
source_table_comment
default_tool_description
+ f" The queried index uses the source table {source_table} with the description: "
+ source_table_comment
)
else:
return default_tool_description + f" The queried index uses the source table {source_table}"
return (
default_tool_description
+ f" The queried index uses the source table {source_table}"
)
return default_tool_description

self.name = self.tool_name or self.index_name
self.description = self.tool_description or get_tool_description()

return self


def _run(self, query: str) -> str:
return self._vector_store.similarity_search(
query,
k = self.num_results,
filter = self.filters,
query_type = self.query_type
query, k=self.num_results, filter=self.filters, query_type=self.query_type
)
Original file line number Diff line number Diff line change
Expand Up @@ -775,4 +775,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe
f"The specified embedding model's dimension '{actual_dimension}' does "
f"not match with the index configuration '{index_embedding_dimension}'."
)

Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@


def init_vector_search_tool(
index_name: str,
columns: Optional[List[str]] = None,
tool_name: Optional[str] = None,
tool_description: Optional[str] = None,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
endpoint: Optional[str] = None
index_name: str,
columns: Optional[List[str]] = None,
tool_name: Optional[str] = None,
tool_description: Optional[str] = None,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
endpoint: Optional[str] = None,
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"index_name": index_name,
Expand All @@ -42,22 +42,23 @@ def init_vector_search_tool(
)
return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type]


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_init(index_name: str) -> None:
vector_search_tool = init_vector_search_tool(index_name)
assert isinstance(vector_search_tool, BaseTool)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None:
from langchain_core.messages import AIMessage

vector_search_tool = init_vector_search_tool(index_name)
llm_with_tools = llm.bind_tools([vector_search_tool])
response = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
assert isinstance(response, AIMessage)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("columns", [None, ["id", "text"]])
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
Expand All @@ -66,13 +67,13 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None:
@pytest.mark.parametrize("text_column", [None, "text"])
@pytest.mark.parametrize("endpoint", [None, "test_endpoint"])
def test_vector_search_retriever_tool_combinations(
index_name: str,
columns: Optional[List[str]],
tool_name: Optional[str],
tool_description: Optional[str],
embedding: Optional[Any],
text_column: Optional[str],
endpoint: Optional[str]
index_name: str,
columns: Optional[List[str]],
tool_name: Optional[str],
tool_description: Optional[str],
embedding: Optional[Any],
text_column: Optional[str],
endpoint: Optional[str],
) -> None:
if index_name == DELTA_SYNC_INDEX:
embedding = None
Expand All @@ -85,8 +86,8 @@ def test_vector_search_retriever_tool_combinations(
tool_description=tool_description,
embedding=embedding,
text_column=text_column,
endpoint=endpoint
endpoint=endpoint,
)
assert isinstance(vector_search_tool, BaseTool)
result = vector_search_tool.invoke("Databricks Agent Framework")
assert result is not None
assert result is not None
6 changes: 2 additions & 4 deletions integrations/langchain/tests/utils/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"message": {
"role": "assistant",
"content": "To calculate the result of 36939 multiplied by 8922.4, "
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
},
"finish_reason": "stop",
"logprobs": None,
Expand Down Expand Up @@ -130,6 +130,4 @@ def mock_client() -> Generator:

@pytest.fixture
def llm() -> ChatDatabricks:
return ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)
return ChatDatabricks(endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks")
34 changes: 16 additions & 18 deletions integrations/langchain/tests/utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (self.dimension - 1) + [float(i)]
for i in range(len(embedding_texts))
[float(1.0)] * (self.dimension - 1) + [float(i)] for i in range(len(embedding_texts))
]

def embed_query(self, text: str) -> List[float]:
Expand All @@ -49,9 +48,7 @@ def embed_query(self, text: str) -> List[float]:
"data_array": sorted(
[
[str(uuid.uuid4()), s, e, 0.5]
for s, e in zip(
INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
)
for s, e in zip(INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS))
],
key=lambda x: x[2], # type: ignore
reverse=True,
Expand Down Expand Up @@ -119,12 +116,12 @@ def embed_query(self, text: str) -> List[float]:
}
],
"schema_json": f"{{"
f'"{"id"}": "int", '
f'"feat1": "str", '
f'"feat2": "float", '
f'"text": "string", '
f'"{"text_vector"}": "array<float>"'
f"}}",
f'"{"id"}": "int", '
f'"feat1": "str", '
f'"feat2": "float", '
f'"text": "string", '
f'"{"text_vector"}": "array<float>"'
f"}}",
},
},
}
Expand All @@ -133,8 +130,8 @@ def embed_query(self, text: str) -> List[float]:
@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
def _get_index(
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
) -> MagicMock:
index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
Expand All @@ -144,11 +141,12 @@ def _get_index(
mock_client = MagicMock()
mock_client.get_index.side_effect = _get_index
with mock.patch(
"databricks.vector_search.client.VectorSearchClient",
return_value=mock_client,
"databricks.vector_search.client.VectorSearchClient",
return_value=mock_client,
):
yield


@pytest.fixture(autouse=True)
def mock_workspace_client() -> Generator:
def _get_table_comment(full_name: str) -> MagicMock:
Expand All @@ -159,7 +157,7 @@ def _get_table_comment(full_name: str) -> MagicMock:
mock_client = MagicMock()
mock_client.tables.get.side_effect = _get_table_comment
with patch(
"databricks.sdk.WorkspaceClient",
return_value=mock_client,
"databricks.sdk.WorkspaceClient",
return_value=mock_client,
):
yield
yield

0 comments on commit 53f9924

Please sign in to comment.