Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorSearchRetrieverTool for Llamaindex Integration #42

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
71 changes: 71 additions & 0 deletions integrations/llamaindex/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
[project]
name = "databricks-llamaindex"
version = "0.1.1.dev0"
# Fixed redundant wording ("Support ... support") in the package description.
description = "Support for Databricks AI in LlamaIndex"
authors = [
  { name="Leon Bi", email="[email protected]" },
]
readme = "README.md"
license = { text="Apache-2.0" }
requires-python = ">=3.9"
dependencies = [
  "databricks-vectorsearch>=0.40",
  "databricks-ai-bridge>=0.1.0",
  "llama-index>=0.11.0",
]

[project.optional-dependencies]
# Development-only tooling (tests, lint, type stubs).
dev = [
  "pytest",
  "typing_extensions",
  "databricks-sdk>=0.34.0",
  "ruff==0.6.4",
]

# Extras needed only for integration tests.
integration = [
  "pytest-timeout>=2.3.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = [
  "src/databricks_llamaindex/*"
]

[tool.hatch.build.targets.wheel]
packages = ["src/databricks_llamaindex"]

[tool.ruff]
line-length = 100
target-version = "py39"

[tool.ruff.lint]
select = [
  # isort
  "I",
  # bugbear rules
  "B",
  # remove unused imports
  "F401",
  # bare except statements
  "E722",
  # print statements
  "T201",
  "T203",
  # misuse of typing.TYPE_CHECKING
  "TCH004",
  # import rules
  "TID251",
  # undefined-local-with-import-star
  "F403",
]

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 88

[tool.ruff.lint.pydocstyle]
convention = "google"
Empty file.
6 changes: 6 additions & 0 deletions integrations/llamaindex/src/databricks_llamaindex/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from databricks_llamaindex.vector_search_retriever_tool import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-llamaindex.
__all__ = [
    "VectorSearchRetrieverTool",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any, Dict, List, Optional, Tuple

from databricks_ai_bridge.utils.vector_search import (
IndexDetails,
parse_vector_search_response,
validate_and_get_return_columns,
validate_and_get_text_column,
)
from databricks_ai_bridge.vector_search_retriever_tool import (
VectorSearchRetrieverToolInput,
VectorSearchRetrieverToolMixin,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.tools import FunctionTool
from llama_index.core.tools.types import ToolMetadata
from pydantic import Field, PrivateAttr


class VectorSearchRetrieverTool(FunctionTool, VectorSearchRetrieverToolMixin):
    """A LlamaIndex ``FunctionTool`` that queries a Databricks Vector Search index.

    Combines ``FunctionTool`` (so the tool can be bound to LlamaIndex agents)
    with ``VectorSearchRetrieverToolMixin`` (shared Databricks retriever
    configuration such as ``index_name``, ``columns``, ``filters``,
    ``num_results`` and ``query_type``).
    """

    # Name of the column holding the document text. Required for
    # direct-access indexes or delta-sync indexes with self-managed
    # embeddings; validated in __init__ via validate_and_get_text_column.
    text_column: Optional[str] = Field(
        None,
        description="The name of the text column to use for the embeddings. "
        "Required for direct-access index or delta-sync index with "
        "self-managed embeddings.",
    )
    # Embedding model used to embed the query for self-managed-embedding
    # indexes; must be omitted for Databricks-managed embeddings.
    embedding: Optional[BaseEmbedding] = Field(
        None, description="Embedding model for self-managed embeddings."
    )
    return_direct: bool = Field(
        default=False,
        description="Whether the tool should return the output directly",
    )

    # Pydantic private attributes: live Vector Search index handle and its
    # parsed metadata. Set in __init__, not part of the public model schema.
    _index = PrivateAttr()
    _index_details = PrivateAttr()

    def __init__(self, **data):
        """Validate configuration, connect to the index, and build the tool.

        Initialization order matters: the pydantic mixin is initialized
        first so that fields like ``index_name`` and ``columns`` are
        available, and ``FunctionTool.__init__`` is called last with the
        fully-built search function and metadata.
        """
        # First initialize the VectorSearchRetrieverToolMixin
        VectorSearchRetrieverToolMixin.__init__(self, **data)

        # Initialize private attributes. Imported lazily so the module can
        # be imported without the databricks-vectorsearch client installed.
        from databricks.vector_search.client import VectorSearchClient

        self._index = VectorSearchClient().get_index(index_name=self.index_name)
        self._index_details = IndexDetails(self._index)

        # Validate columns: text_column is checked against the index type,
        # and the return columns are normalized (empty list if None).
        self.text_column = validate_and_get_text_column(self.text_column, self._index_details)
        self.columns = validate_and_get_return_columns(
            self.columns or [], self.text_column, self._index_details
        )

        # Define the similarity search function that the FunctionTool wraps.
        # It closes over self, so it always sees the validated configuration.
        def similarity_search(query: str) -> List[Dict[str, Any]]:
            def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[float]]]:
                # Databricks-managed embeddings: the service embeds the query
                # text itself, so a client-side embedding model is an error.
                if self._index_details.is_databricks_managed_embeddings():
                    if self.embedding:
                        raise ValueError(
                            f"The index '{self._index_details.name}' uses Databricks-managed embeddings. "
                            "Do not pass the `embedding` parameter when executing retriever calls."
                        )
                    return query, None

                # Self-managed embeddings: a client-side embedding model is
                # required to turn the query into a vector.
                if not self.embedding:
                    raise ValueError(
                        "The embedding model name is required for non-Databricks-managed "
                        "embeddings Vector Search indexes in order to generate embeddings for retrieval queries."
                    )

                # HYBRID queries send both the raw text and the vector;
                # otherwise only the vector is sent.
                text = query if self.query_type and self.query_type.upper() == "HYBRID" else None
                vector = self.embedding.get_text_embedding(text=query)
                # Guard against an embedding model whose output dimension
                # does not match the index's embedding column.
                if (
                    index_embedding_dimension := self._index_details.embedding_vector_column.get(
                        "embedding_dimension"
                    )
                ) and len(vector) != index_embedding_dimension:
                    raise ValueError(
                        f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}"
                    )
                return text, vector

            query_text, query_vector = get_query_text_vector(query)
            search_resp = self._index.similarity_search(
                columns=self.columns,
                query_text=query_text,
                query_vector=query_vector,
                filters=self.filters,
                num_results=self.num_results,
                query_type=self.query_type,
            )
            # Convert the raw Vector Search response into a list of dicts.
            return parse_vector_search_response(
                search_resp, self._index_details, self.text_column, document_class=dict
            )

        # Create tool metadata; name/description fall back to index-derived
        # defaults when the caller does not supply them.
        metadata = ToolMetadata(
            name=self.tool_name or self.index_name,
            description=self.tool_description
            or self._get_default_tool_description(self._index_details),
            fn_schema=VectorSearchRetrieverToolInput,
            return_direct=self.return_direct,
        )

        # Initialize FunctionTool with the similarity search function and metadata
        FunctionTool.__init__(self, fn=similarity_search, metadata=metadata)
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Any, Dict, List, Optional

import pytest
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DEFAULT_VECTOR_DIMENSION,
DELTA_SYNC_INDEX,
mock_vs_client,
mock_workspace_client,
)
from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolInput
from llama_index.core.agent import ReActAgent
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
from pydantic import Field

from databricks_llamaindex import VectorSearchRetrieverTool


class FakeEmbeddings(BaseEmbedding):
    """Deterministic embedding stub for unit tests.

    Emits a fixed vector of ``dimension`` entries: ones everywhere except a
    trailing zero. No model is loaded and no network calls are made.
    """

    dimension: int = Field(default=DEFAULT_VECTOR_DIMENSION)

    def get_text_embedding(self, text: str) -> List[float]:
        # Same fixed vector regardless of input text.
        ones = [1.0] * (self.dimension - 1)
        return ones + [0.0]

    # Abstract-method stubs required by BaseEmbedding; unused in these tests.
    def _aget_query_embedding(self):
        return None

    def _get_query_embedding(self):
        return None

    def _get_text_embedding(self):
        return None


# Shared fake model instance used across the test module.
EMBEDDING_MODEL = FakeEmbeddings()


def init_vector_search_tool(
    index_name: str,
    columns: Optional[List[str]] = None,
    tool_name: Optional[str] = None,
    tool_description: Optional[str] = None,
    embedding: Optional[BaseEmbedding] = None,
    text_column: Optional[str] = None,
) -> VectorSearchRetrieverTool:
    """Construct a VectorSearchRetrieverTool for tests.

    For any index other than the delta-sync index, the fake embedding model
    and a text column are always supplied (overriding whatever was passed),
    since self-managed-embedding indexes require both.
    """
    if index_name != DELTA_SYNC_INDEX:
        embedding = EMBEDDING_MODEL
        text_column = "text"
    return VectorSearchRetrieverTool(  # type: ignore[arg-type]
        index_name=index_name,
        columns=columns,
        tool_name=tool_name,
        tool_description=tool_description,
        embedding=embedding,
        text_column=text_column,
    )


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_init(index_name: str) -> None:
    """The tool constructs cleanly and is a LlamaIndex FunctionTool."""
    tool = init_vector_search_tool(index_name)
    assert isinstance(tool, FunctionTool)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("columns", [None, ["id", "text"]])
@pytest.mark.parametrize("tool_name", [None, "test_tool"])
@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"])
@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL])
@pytest.mark.parametrize("text_column", [None, "text"])
def test_vector_search_retriever_tool_combinations(
    index_name: str,
    columns: Optional[List[str]],
    tool_name: Optional[str],
    tool_description: Optional[str],
    embedding: Optional[Any],
    text_column: Optional[str],
) -> None:
    """Construction succeeds for every combination of optional parameters."""
    # The tool rejects a client-side embedding model for indexes with
    # Databricks-managed embeddings, so clear both overrides here.
    if index_name == DELTA_SYNC_INDEX:
        embedding, text_column = None, None

    tool = init_vector_search_tool(
        index_name=index_name,
        columns=columns,
        tool_name=tool_name,
        tool_description=tool_description,
        embedding=embedding,
        text_column=text_column,
    )
    assert isinstance(tool, FunctionTool)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_vector_search_retriever_tool_description_generation(index_name: str) -> None:
    """Default name, description, and schema are derived when none are given."""
    tool = init_vector_search_tool(index_name)
    metadata = tool.metadata
    assert metadata.name != ""
    assert metadata.description != ""
    assert metadata.fn_schema == VectorSearchRetrieverToolInput
    # With no explicit tool_name, the index name is used.
    assert metadata.name == index_name
    expected_snippet = "A vector search-based retrieval tool for querying indexed embeddings."
    assert expected_snippet in metadata.description


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_vector_search_retriever_tool_bind_agent(index_name: str) -> None:
    """The tool can be bound into a ReAct agent without error."""
    tool = init_vector_search_tool(index_name)
    agent = ReActAgent.from_tools([tool], llm=OpenAI(), verbose=True)
    assert agent is not None
Loading