Skip to content

Commit

Permalink
Refactor based on Pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
leonbi100 committed Dec 19, 2024
1 parent 2105f55 commit 5b5361a
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 170 deletions.
61 changes: 61 additions & 0 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from urllib.parse import urlparse

import numpy as np
from enum import Enum
import json

from typing import (
Dict,
Optional
)

def get_deployment_client(target_uri: str) -> Any:
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
Expand Down Expand Up @@ -95,3 +101,58 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity

class IndexType(str, Enum):
    """Index kinds recognised by this module.

    Members are also ``str`` instances, so each one compares equal to the
    raw string it wraps.
    """

    DIRECT_ACCESS = "DIRECT_ACCESS"
    DELTA_SYNC = "DELTA_SYNC"

class IndexDetails:
    """Read-only accessors over the description of an index.

    The description is obtained once, at construction time, by calling
    ``index.describe()``. Every property and predicate below reads from that
    stored mapping; nothing is re-fetched afterwards.
    """

    def __init__(self, index: Any):
        # Any object exposing ``describe()`` is accepted; the result is
        # indexed like a dict by the accessors below.
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        """The ``"name"`` entry of the description (KeyError if absent)."""
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        """Decoded ``schema_json`` of a direct-access index, otherwise ``None``."""
        if not self.is_direct_access_index():
            return None
        raw_schema = self.index_spec.get("schema_json")
        return None if raw_schema is None else json.loads(raw_schema)

    @property
    def primary_key(self) -> str:
        """The ``"primary_key"`` entry of the description (KeyError if absent)."""
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        """The spec sub-mapping matching the index type, or ``{}`` if missing.

        Anything that is not a delta-sync index is looked up under the
        direct-access key.
        """
        if self.is_delta_sync_index():
            spec_key = "delta_sync_index_spec"
        else:
            spec_key = "direct_access_index_spec"
        return self._index_details.get(spec_key, {})

    @property
    def embedding_vector_column(self) -> Dict:
        """First entry of ``embedding_vector_columns``, or ``{}`` when none."""
        vector_columns = self.index_spec.get("embedding_vector_columns")
        return vector_columns[0] if vector_columns else {}

    @property
    def embedding_source_column(self) -> Dict:
        """First entry of ``embedding_source_columns``, or ``{}`` when none."""
        source_columns = self.index_spec.get("embedding_source_columns")
        return source_columns[0] if source_columns else {}

    def is_delta_sync_index(self) -> bool:
        """Whether the described ``index_type`` is the delta-sync value."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        """Whether the described ``index_type`` is the direct-access value."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        """True for a delta-sync index whose first source column has a name."""
        if not self.is_delta_sync_index():
            return False
        return self.embedding_source_column.get("name") is not None
138 changes: 70 additions & 68 deletions integrations/langchain/src/databricks_langchain/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,81 @@
from databricks_langchain import DatabricksVectorSearch
from typing import (
List,
Optional
)
from langchain_core.embeddings import Embeddings
from langchain.tools.retriever import create_retriever_tool
from typing import Any, Dict, List, Optional

from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Field, model_validator, PrivateAttr

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from databricks_langchain import DatabricksVectorSearch
from databricks_langchain.utils import IndexDetails
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)

class VectorSearchRetrieverToolInput(BaseModel):
    """Input model with a single ``query`` string field."""

    query: str = Field(
        description="query used to search the index",
    )

class VectorSearchRetrieverTool():
class VectorSearchRetrieverTool(BaseTool):
"""
A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
This class integrates with a Databricks Vector Search and provides a convenient interface
for building a retriever tool for agents.
Parameters:
index_name (str):
The name of the index to use. Format: “catalog.schema.index”. endpoint:
num_results (int):
The number of results to return. Defaults to 10.
columns (Optional[List[str]]):
The list of column names to get when doing the search. Defaults to [primary_key, text_column].
filters (Optional[Dict[str, Any]]):
The filters to apply to the search. Defaults to None.
query_type (str):
The type of query to run. Defaults to "ANN".
tool_name (str):
The name of the retrieval tool to be created. This will be passed to the language model,
so should be unique and somewhat descriptive.
tool_description (str):
A description of the tool's functionality. This will be passed to the language model,
so should be descriptive.
"""
# Tool name and description: string fields declared with a description only (no default value here).
name: str = Field(description="The name of the tool")
description: str = Field(description="The description of the tool")
# The input model class assigned as this tool's argument schema.
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput
def __init__(
    self,
    index_name: str,
    num_results: int = 10,
    *,
    columns: Optional[List[str]] = None,
    filters: Optional[Dict[str, Any]] = None,
    query_type: str = "ANN",
    tool_name: Optional[str] = None,
    # TODO: By default the UC metadata for description, how do I get this info? Call using client?
    tool_description: Optional[str],
):
    """Store the search settings and create the vector store for ``index_name``.

    Args:
        index_name: Name of the index; also used as the tool name when
            ``tool_name`` is empty or ``None``.
        num_results: Stored as ``self.num_results``.
        columns: Stored as ``self.columns``.
        filters: Stored as ``self.filters``.
        query_type: Stored as ``self.query_type``.
        tool_name: Optional explicit tool name.
        tool_description: Stored as ``self.description``. Keyword-only and
            without a default, so callers must pass it (``None`` allowed).
    """
    # A falsy tool_name (None or "") falls back to the index name.
    self.name = tool_name if tool_name else index_name
    self.num_results = num_results
    self.columns = columns
    self.filters = filters
    self.query_type = query_type
    self.description = tool_description
    # Only the index name is handed to the vector store here.
    self.vector_store = DatabricksVectorSearch(index_name=index_name)

def _run(
    self,
    query: str
) -> str:
    """Run a similarity search for ``query`` and return what the store returns.

    Previously the search result was computed and then dropped, so the
    method always returned ``None``.

    Args:
        query: Text to search the index with.

    Returns:
        The value produced by ``self.vector_store.similarity_search``.
        NOTE(review): the annotation says ``str``, but the result is returned
        unconverted -- confirm the intended return type.
    """
    # Keyword arguments mirror the k/filter/query_type call form used by the
    # other `_run` in this file, so each setting reaches the right parameter.
    # NOTE(review): self.columns is not passed per query; presumably column
    # selection belongs on the vector store itself -- confirm.
    return self.vector_store.similarity_search(
        query,
        k=self.num_results,
        filter=self.filters,
        query_type=self.query_type,
    )
# Caller-supplied settings. `index_name` is the only required field (`...` marks it as required).
index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.")
num_results: int = Field(10, description="The number of results to return.")
columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.")
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
query_type: str = Field("ANN", description="The type of query to run.")
# Optional overrides; validate_tool_inputs() copies them into `name` / `description` when they are truthy.
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
tool_description: Optional[str] = Field(None, description="A description of the tool.")
# TODO: Confirm whether these two fields can be added to the API to support direct-access indexes
# and delta-sync indexes with self-managed embeddings.
text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.")
embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.")
# TODO: Confirm whether this endpoint field can be added.
endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.")

# The BaseTool class requires 'name' and 'description' fields, which are populated in validate_tool_inputs().
name: str = Field(default="", description="The name of the tool")
description: str = Field(default="", description="The description of the tool")

# Private attribute holding the vector store; assigned in validate_tool_inputs().
_vector_store: DatabricksVectorSearch = PrivateAttr()

@model_validator(mode='after')
def validate_tool_inputs(self):
    """Create the vector store and fill in the tool's name and description.

    Registered as an "after" model validator. It builds a
    ``DatabricksVectorSearch`` from the configured fields, stores it in
    ``_vector_store``, then sets ``name`` (``tool_name`` or the index name)
    and ``description`` (``tool_description`` or a generated one).

    Returns:
        ``self``, as handed back by the original validator.
    """
    # Construct the vector store using provided params
    dbvs = DatabricksVectorSearch(
        index_name=self.index_name,
        endpoint=self.endpoint,
        embedding=self.embedding,
        text_column=self.text_column,
        columns=self.columns,
    )
    self._vector_store = dbvs

    def get_tool_description():
        """Build a default description, enriched with source-table info when available."""
        default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings."
        index_details = IndexDetails(dbvs.index)
        if not index_details.is_delta_sync_index():
            return default_tool_description

        source_table = index_details.index_spec.get('source_table', "")
        if not source_table:
            # No source table recorded in the spec: there is nothing to look
            # up, so skip the workspace call instead of querying an empty name.
            return default_tool_description

        # Imported lazily so the SDK is only needed when a lookup happens.
        from databricks.sdk import WorkspaceClient

        w = WorkspaceClient()
        source_table_comment = w.tables.get(full_name=source_table).comment
        if source_table_comment:
            return (
                default_tool_description +
                f" The queried index uses the source table {source_table} with the description: " +
                source_table_comment
            )
        return default_tool_description + f" The queried index uses the source table {source_table}"

    self.name = self.tool_name or self.index_name
    # The generated description is only computed when no explicit one was given.
    self.description = self.tool_description or get_tool_description()

    return self


def _run(self, query: str) -> str:
    """Search the stored vector store for ``query`` using the configured settings.

    Returns whatever ``similarity_search`` returns, unmodified.
    NOTE(review): the annotation says ``str`` but no conversion happens
    here -- confirm the intended return type.
    """
    search_options = {
        "k": self.num_results,
        "filter": self.filters,
        "query_type": self.query_type,
    }
    return self._vector_store.similarity_search(query, **search_options)
60 changes: 1 addition & 59 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import json
import logging
import re
import uuid
from enum import Enum
from functools import partial
from typing import (
Any,
Expand All @@ -23,16 +21,11 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from databricks_langchain.utils import maximal_marginal_relevance
from databricks_langchain.utils import maximal_marginal_relevance, IndexDetails

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class IndexType(str, Enum):
    """The two index kinds this module distinguishes.

    Because the enum also derives from ``str``, a member compares equal to
    its plain string value.
    """

    DIRECT_ACCESS = "DIRECT_ACCESS"
    DELTA_SYNC = "DELTA_SYNC"


# Message templates with a single `%s` placeholder; the code that fills them in is outside this view.
_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = "`%s` is not supported for index with Databricks-managed embeddings."
# Anchored pattern: exactly three dot-separated segments, each made of letters, digits or underscores.
_INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$")
Expand Down Expand Up @@ -783,54 +776,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe
f"not match with the index configuration '{index_embedding_dimension}'."
)


class IndexDetails:
    """Convenience wrapper around the mapping returned by ``index.describe()``.

    ``describe()`` is called exactly once, in the constructor; all accessors
    read from the stored result.
    """

    def __init__(self, index: Any):
        # Keep the description so later property reads need no further calls.
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        """Value stored under ``"name"`` (KeyError if absent)."""
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        """Parsed ``schema_json`` for direct-access indexes; ``None`` otherwise."""
        if not self.is_direct_access_index():
            return None
        raw_schema = self.index_spec.get("schema_json")
        return None if raw_schema is None else json.loads(raw_schema)

    @property
    def primary_key(self) -> str:
        """Value stored under ``"primary_key"`` (KeyError if absent)."""
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        """Spec mapping for the index type, defaulting to ``{}`` when missing.

        Any index that is not delta-sync is read from the direct-access key.
        """
        if self.is_delta_sync_index():
            spec_key = "delta_sync_index_spec"
        else:
            spec_key = "direct_access_index_spec"
        return self._index_details.get(spec_key, {})

    @property
    def embedding_vector_column(self) -> Dict:
        """First ``embedding_vector_columns`` entry, or ``{}`` if there is none."""
        vector_columns = self.index_spec.get("embedding_vector_columns")
        return vector_columns[0] if vector_columns else {}

    @property
    def embedding_source_column(self) -> Dict:
        """First ``embedding_source_columns`` entry, or ``{}`` if there is none."""
        source_columns = self.index_spec.get("embedding_source_columns")
        return source_columns[0] if source_columns else {}

    def is_delta_sync_index(self) -> bool:
        """Whether ``index_type`` equals the delta-sync enum value."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        """Whether ``index_type`` equals the direct-access enum value."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        """True only for a delta-sync index whose first source column is named."""
        if not self.is_delta_sync_index():
            return False
        return self.embedding_source_column.get("name") is not None
43 changes: 0 additions & 43 deletions integrations/langchain/tests/test_vector_search.py

This file was deleted.

Loading

0 comments on commit 5b5361a

Please sign in to comment.