From 53f992484615a25d824c4b570f0d9ec6ab8e92a8 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Fri, 20 Dec 2024 12:20:58 -0500 Subject: [PATCH] Fix lint --- .../src/databricks_langchain/__init__.py | 2 +- .../src/databricks_langchain/utils.py | 2 + .../vector_search_retriever_tool.py | 57 ++++++++++++------- .../src/databricks_langchain/vectorstores.py | 1 - .../test_vector_search_retriever_tool.py | 39 ++++++------- .../langchain/tests/utils/chat_models.py | 6 +- .../langchain/tests/utils/vector_search.py | 34 ++++++----- 7 files changed, 77 insertions(+), 64 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 6a8da26..acc056e 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -10,5 +10,5 @@ "DatabricksEmbeddings", "DatabricksVectorSearch", "GenieAgent", - "VectorSearchRetrieverTool" + "VectorSearchRetrieverTool", ] diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py index 3e42d05..8218ab9 100644 --- a/integrations/langchain/src/databricks_langchain/utils.py +++ b/integrations/langchain/src/databricks_langchain/utils.py @@ -98,10 +98,12 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity + class IndexType(str, Enum): DIRECT_ACCESS = "DIRECT_ACCESS" DELTA_SYNC = "DELTA_SYNC" + class IndexDetails: """An utility class to store the configuration details of an index.""" diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index f999faa..e21d5de 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -9,8 +9,11 @@ class VectorSearchRetrieverToolInput(BaseModel): - query: str = Field(description="The string used to query the index with and identify the most similar " - "vectors and return the associated documents.") + query: str = Field( + description="The string used to query the index with and identify the most similar " + "vectors and return the associated documents." + ) + class VectorSearchRetrieverTool(BaseTool): """ @@ -19,17 +22,28 @@ class VectorSearchRetrieverTool(BaseTool): for building a retriever tool for agents. """ - index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.") + index_name: str = Field( + ..., description="The name of the index to use, format: 'catalog.schema.index'." + ) num_results: int = Field(10, description="The number of results to return.") - columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.") + columns: Optional[List[str]] = Field( + None, description="Columns to return when doing the search." + ) filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.") - query_type: str = Field("ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'.") + query_type: str = Field( + "ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'." 
+ ) tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.") tool_description: Optional[str] = Field(None, description="A description of the tool.") - text_column: Optional[str] = Field(None, description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.") - embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.") + text_column: Optional[str] = Field( + None, + description="The name of the text column to use for the embeddings. " + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", + ) + embedding: Optional[Embeddings] = Field( + None, description="Embedding model for self-managed embeddings." + ) # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs() name: str = Field(default="", description="The name of the tool") @@ -38,7 +52,7 @@ class VectorSearchRetrieverTool(BaseTool): _vector_store: DatabricksVectorSearch = PrivateAttr() - @model_validator(mode='after') + @model_validator(mode="after") def validate_tool_inputs(self): kwargs = { "index_name": self.index_name, @@ -50,22 +64,27 @@ def validate_tool_inputs(self): self._vector_store = dbvs def get_tool_description(): - default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings." + default_tool_description = ( + "A vector search-based retrieval tool for querying indexed embeddings." + ) index_details = IndexDetails(dbvs.index) if index_details.is_delta_sync_index(): from databricks.sdk import WorkspaceClient - source_table = index_details.index_spec.get('source_table', "") + source_table = index_details.index_spec.get("source_table", "") w = WorkspaceClient() source_table_comment = w.tables.get(full_name=source_table).comment if source_table_comment: return ( - default_tool_description + - f" The queried index uses the source table {source_table} with the description: " + - source_table_comment + default_tool_description + + f" The queried index uses the source table {source_table} with the description: " + + source_table_comment ) else: - return default_tool_description + f" The queried index uses the source table {source_table}" + return ( + default_tool_description + + f" The queried index uses the source table {source_table}" + ) return default_tool_description self.name = self.tool_name or self.index_name @@ -73,11 +92,7 @@ def get_tool_description(): return self - def _run(self, query: str) -> str: return self._vector_store.similarity_search( - query, - k = self.num_results, - filter = self.filters, - query_type = self.query_type + query, k=self.num_results, filter=self.filters, query_type=self.query_type ) diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index ef9f7e2..47477ee 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -775,4 +775,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe f"The specified embedding model's dimension '{actual_dimension}' does " f"not match with the index configuration '{index_embedding_dimension}'." 
) - diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index a22bd28..7fc422e 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -16,13 +16,13 @@ def init_vector_search_tool( - index_name: str, - columns: Optional[List[str]] = None, - tool_name: Optional[str] = None, - tool_description: Optional[str] = None, - embedding: Optional[Embeddings] = None, - text_column: Optional[str] = None, - endpoint: Optional[str] = None + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, + endpoint: Optional[str] = None, ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { "index_name": index_name, @@ -42,22 +42,23 @@ def init_vector_search_tool( ) return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) def test_init(index_name: str) -> None: vector_search_tool = init_vector_search_tool(index_name) assert isinstance(vector_search_tool, BaseTool) + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: from langchain_core.messages import AIMessage vector_search_tool = init_vector_search_tool(index_name) llm_with_tools = llm.bind_tools([vector_search_tool]) - response = llm_with_tools.invoke( - "Which city is hotter today and which is bigger: LA or NY?" - ) + response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") assert isinstance(response, AIMessage) + @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("columns", [None, ["id", "text"]]) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) @@ -66,13 +67,13 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: @pytest.mark.parametrize("text_column", [None, "text"]) @pytest.mark.parametrize("endpoint", [None, "test_endpoint"]) def test_vector_search_retriever_tool_combinations( - index_name: str, - columns: Optional[List[str]], - tool_name: Optional[str], - tool_description: Optional[str], - embedding: Optional[Any], - text_column: Optional[str], - endpoint: Optional[str] + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str], + embedding: Optional[Any], + text_column: Optional[str], + endpoint: Optional[str], ) -> None: if index_name == DELTA_SYNC_INDEX: embedding = None @@ -85,8 +86,8 @@ def test_vector_search_retriever_tool_combinations( tool_description=tool_description, embedding=embedding, text_column=text_column, - endpoint=endpoint + endpoint=endpoint, ) assert isinstance(vector_search_tool, BaseTool) result = vector_search_tool.invoke("Databricks Agent Framework") - assert result is not None \ No newline at end of file + assert result is not None diff --git a/integrations/langchain/tests/utils/chat_models.py b/integrations/langchain/tests/utils/chat_models.py index 7da3923..b3b03eb 100644 --- a/integrations/langchain/tests/utils/chat_models.py +++ b/integrations/langchain/tests/utils/chat_models.py @@ -16,7 +16,7 @@ "message": { "role": "assistant", "content": "To calculate the result of 36939 multiplied by 8922.4, " - "I get:\n\n36939 x 8922.4 = 
329,511,111.6", + "I get:\n\n36939 x 8922.4 = 329,511,111.6", }, "finish_reason": "stop", "logprobs": None, @@ -130,6 +130,4 @@ def mock_client() -> Generator: @pytest.fixture def llm() -> ChatDatabricks: - return ChatDatabricks( - endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks" - ) \ No newline at end of file + return ChatDatabricks(endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks") diff --git a/integrations/langchain/tests/utils/vector_search.py b/integrations/langchain/tests/utils/vector_search.py index 489ec7b..a57c14c 100644 --- a/integrations/langchain/tests/utils/vector_search.py +++ b/integrations/langchain/tests/utils/vector_search.py @@ -21,8 +21,7 @@ def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION): def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]: """Return simple embeddings.""" return [ - [float(1.0)] * (self.dimension - 1) + [float(i)] - for i in range(len(embedding_texts)) + [float(1.0)] * (self.dimension - 1) + [float(i)] for i in range(len(embedding_texts)) ] def embed_query(self, text: str) -> List[float]: @@ -49,9 +48,7 @@ def embed_query(self, text: str) -> List[float]: "data_array": sorted( [ [str(uuid.uuid4()), s, e, 0.5] - for s, e in zip( - INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS) - ) + for s, e in zip(INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)) ], key=lambda x: x[2], # type: ignore reverse=True, @@ -119,12 +116,12 @@ def embed_query(self, text: str) -> List[float]: } ], "schema_json": f"{{" - f'"{"id"}": "int", ' - f'"feat1": "str", ' - f'"feat2": "float", ' - f'"text": "string", ' - f'"{"text_vector"}": "array"' - f"}}", + f'"{"id"}": "int", ' + f'"feat1": "str", ' + f'"feat2": "float", ' + f'"text": "string", ' + f'"{"text_vector"}": "array"' + f"}}", }, }, } @@ -133,8 +130,8 @@ def embed_query(self, text: str) -> List[float]: @pytest.fixture(autouse=True) def mock_vs_client() -> Generator: def _get_index( - endpoint_name: Optional[str] = None, - index_name: str = None, # type: ignore + endpoint_name: Optional[str] = None, + index_name: str = None, # type: ignore ) -> MagicMock: index = MagicMock(spec=VectorSearchIndex) index.describe.return_value = INDEX_DETAILS[index_name] @@ -144,11 +141,12 @@ def _get_index( mock_client = MagicMock() mock_client.get_index.side_effect = _get_index with mock.patch( - "databricks.vector_search.client.VectorSearchClient", - return_value=mock_client, + "databricks.vector_search.client.VectorSearchClient", + return_value=mock_client, ): yield + @pytest.fixture(autouse=True) def mock_workspace_client() -> Generator: def _get_table_comment(full_name: str) -> MagicMock: @@ -159,7 +157,7 @@ def _get_table_comment(full_name: str) -> MagicMock: mock_client = MagicMock() mock_client.tables.get.side_effect = _get_table_comment with patch( - "databricks.sdk.WorkspaceClient", - return_value=mock_client, + "databricks.sdk.WorkspaceClient", + return_value=mock_client, ): - yield \ No newline at end of file + yield