Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorSearchRetrieverTool OpenAI integration #39

Merged
merged 15 commits into from
Jan 14, 2025
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# test
.pytest_cache/
mlruns/

# Byte-compiled files
__pycache__
Expand Down
Empty file added integrations/openai/README.md
Empty file.
Empty file added integrations/openai/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions integrations/openai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
[project]
name = "databricks-openai"
version = "0.1.0"
description = "Support for Databricks AI support with OpenAI"
authors = [
{ name="Leon Bi", email="[email protected]" },
]
readme = "README.md"
license = { text="Apache-2.0" }
requires-python = ">=3.9"
dependencies = [
"databricks-vectorsearch>=0.40",
"databricks-ai-bridge>=0.1.0",
"openai>=1.46.1",
]

[project.optional-dependencies]
dev = [
"pytest",
"typing_extensions",
"databricks-sdk>=0.34.0",
"ruff==0.6.4",
]

integration = [
"pytest-timeout>=2.3.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = [
"src/databricks_openai/*"
]

[tool.hatch.build.targets.wheel]
packages = ["src/databricks_openai"]

[tool.ruff]
line-length = 100
target-version = "py39"

[tool.ruff.lint]
select = [
# isort
"I",
# bugbear rules
"B",
# remove unused imports
"F401",
# bare except statements
"E722",
# print statements
"T201",
"T203",
# misuse of typing.TYPE_CHECKING
"TCH004",
# import rules
"TID251",
# undefined-local-with-import-star
"F403",
]

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 88

[tool.ruff.lint.pydocstyle]
convention = "google"
Empty file.
6 changes: 6 additions & 0 deletions integrations/openai/src/databricks_openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-openai
__all__ = [
"VectorSearchRetrieverTool",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import json
from typing import Any, Dict, List, Optional, Tuple

from databricks.vector_search.client import VectorSearchIndex
from databricks_ai_bridge.utils.vector_search import (
IndexDetails,
parse_vector_search_response,
validate_and_get_return_columns,
validate_and_get_text_column,
)
from databricks_ai_bridge.vector_search_retriever_tool import (
VectorSearchRetrieverToolInput,
VectorSearchRetrieverToolMixin,
)
from pydantic import Field, PrivateAttr, model_validator

from openai import OpenAI, pydantic_function_tool
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall, ChatCompletionToolParam


class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
    """
    A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
    This class integrates with Databricks Vector Search and provides a convenient interface
    for tool calling using the OpenAI SDK.

    Example:
        dbvs_tool = VectorSearchRetrieverTool("index_name")
        tools = [dbvs_tool.tool, ...]
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages,
            tools=tools,
        )
        retriever_call_message = dbvs_tool.execute_calls(response)

        ### If needed, execute potential remaining tool calls here ###
        remaining_tool_call_messages = execute_remaining_tool_calls(response)

        final_response = openai.chat.completions.create(
            model="gpt-4o",
            messages=initial_messages + retriever_call_message + remaining_tool_call_messages,
            tools=tools,
        )
        final_response.choices[0].message.content
    """

    text_column: Optional[str] = Field(
        None,
        description="The name of the text column to use for the embeddings. "
        "Required for direct-access index or delta-sync index with "
        "self-managed embeddings.",
    )

    # Populated by _validate_tool_inputs below; callers pass this directly in the
    # `tools` list of an OpenAI chat completions request.
    tool: ChatCompletionToolParam = Field(
        None, description="The tool input used in the OpenAI chat completion SDK"
    )
    _index: VectorSearchIndex = PrivateAttr()
    _index_details: IndexDetails = PrivateAttr()

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        """Resolve the vector search index, validate columns, and build the OpenAI tool spec.

        Runs automatically after pydantic field validation. Populates
        ``self._index``, ``self._index_details``, ``self.text_column``,
        ``self.columns`` and ``self.tool``.
        """
        from databricks.vector_search.client import (
            VectorSearchClient,  # import here so we can mock in tests
        )

        self._index = VectorSearchClient().get_index(index_name=self.index_name)
        self._index_details = IndexDetails(self._index)
        self.text_column = validate_and_get_text_column(self.text_column, self._index_details)
        self.columns = validate_and_get_return_columns(
            self.columns or [], self.text_column, self._index_details
        )

        # OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'."
        # The '.' from the index name are not allowed
        def rewrite_index_name(index_name: str):
            return index_name.replace(".", "_")

        self.tool = pydantic_function_tool(
            VectorSearchRetrieverToolInput,
            name=self.tool_name or rewrite_index_name(self.index_name),
            description=self.tool_description
            or self._get_default_tool_description(self._index_details),
        )
        return self

    def execute_calls(
        self,
        response: ChatCompletion,
        choice_index: int = 0,
        # FIX: defaults of None previously carried non-Optional annotations
        # (`str`/`OpenAI`); annotate as Optional to match the actual contract.
        embedding_model_name: Optional[str] = None,
        openai_client: Optional[OpenAI] = None,
    ) -> List[Dict[str, Any]]:
        """
        Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the
        self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages.

        Args:
            response: The chat completion response object returned by the OpenAI API.
            choice_index: The index of the choice to process. Defaults to 0. Note that multiple
                choices are not supported yet.
            embedding_model_name: The name of the embedding model to use for embedding the query text.
                Required for direct access indexes or delta-sync indexes with self-managed embeddings.
            openai_client: The OpenAI client object used to generate embeddings for retrieval queries. If not provided,
                the default OpenAI client in the current environment will be used.

        Returns:
            A list of messages containing the assistant message and the retriever call results
            that correspond to the self.tool VectorSearchRetrieverToolInput.

        Raises:
            ValueError: If ``response`` is not a ChatCompletion, if embedding configuration
                conflicts with the index type, or if the generated embedding dimension does
                not match the index's expected dimension.
        """

        def get_query_text_vector(
            tool_call: ChatCompletionMessageToolCall,
        ) -> Tuple[Optional[str], Optional[List[float]]]:
            # Decide whether to send the raw query text (Databricks-managed
            # embeddings), a locally computed embedding vector, or both (HYBRID).
            query = json.loads(tool_call.function.arguments)["query"]
            if self._index_details.is_databricks_managed_embeddings():
                if embedding_model_name:
                    raise ValueError(
                        f"The index '{self._index_details.name}' uses Databricks-managed embeddings. "
                        "Do not pass the `embedding_model_name` parameter when executing retriever calls."
                    )
                return query, None

            # For non-Databricks-managed embeddings, compute the query vector
            # via the OpenAI embeddings API. (OpenAI is already imported at
            # module level; no local re-import needed.)
            oai_client = openai_client or OpenAI()
            if not oai_client.api_key:
                raise ValueError(
                    "OpenAI API key is required to generate embeddings for retrieval queries."
                )
            if not embedding_model_name:
                raise ValueError(
                    "The embedding model name is required for non-Databricks-managed "
                    "embeddings Vector Search indexes in order to generate embeddings for retrieval queries."
                )

            # HYBRID queries also pass the raw text so the index can combine
            # keyword and vector similarity.
            text = query if self.query_type and self.query_type.upper() == "HYBRID" else None
            vector = (
                oai_client.embeddings.create(input=query, model=embedding_model_name)
                .data[0]
                .embedding
            )
            # Guard against a model/index dimension mismatch before querying.
            if (
                index_embedding_dimension := self._index_details.embedding_vector_column.get(
                    "embedding_dimension"
                )
            ) and len(vector) != index_embedding_dimension:
                raise ValueError(
                    f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}"
                )
            return text, vector

        def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool:
            # A call targets this tool only if both the function name and the
            # exact argument set match the generated tool spec.
            # FIX: `Set[str]` was used here but `Set` is never imported from
            # typing (ruff F821); use builtin generics (PEP 585, Python 3.9+).
            tool_call_arguments: set[str] = set(json.loads(tool_call.function.arguments).keys())
            vs_index_arguments: set[str] = set(
                self.tool["function"]["parameters"]["properties"].keys()
            )
            return (
                tool_call.function.name == self.tool["function"]["name"]
                and tool_call_arguments == vs_index_arguments
            )

        # FIX: use isinstance instead of an exact type() comparison so valid
        # ChatCompletion subclasses are accepted.
        if not isinstance(response, ChatCompletion):
            raise ValueError("response must be an instance of ChatCompletion")
        message = response.choices[choice_index].message
        llm_tool_calls = message.tool_calls
        function_calls = []
        if llm_tool_calls:
            for llm_tool_call in llm_tool_calls:
                # Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput
                if not is_tool_call_for_index(llm_tool_call):
                    continue

                query_text, query_vector = get_query_text_vector(llm_tool_call)
                search_resp = self._index.similarity_search(
                    columns=self.columns,
                    query_text=query_text,
                    query_vector=query_vector,
                    filters=self.filters,
                    num_results=self.num_results,
                    query_type=self.query_type,
                )
                docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response(
                    search_resp=search_resp,
                    index_details=self._index_details,
                    text_column=self.text_column,
                    document_class=dict,
                )

                # Tool result messages are attached per call so the follow-up
                # completion can see the retrieved documents.
                function_call_result_message = {
                    "role": "tool",
                    "content": json.dumps({"content": docs_with_score}),
                    "tool_call_id": llm_tool_call.id,
                }
                function_calls.append(function_call_result_message)
        # Return the assistant message first so callers can append this list
        # directly onto their message history.
        assistant_message = message.to_dict()
        return [assistant_message, *function_calls]
Loading
Loading