Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorSearchRetrieverTool Openai integration #39

Merged
merged 15 commits into from
Jan 14, 2025
Prev Previous commit
Next Next commit
Working e2e delta sync index happy case
  • Loading branch information
leonbi100 committed Dec 31, 2024
commit 005a5333a7f2b61d4d495cbad6df9db0469e5f2b
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# test
.pytest_cache/
mlruns/

# Byte-compiled files
__pycache__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from openai import OpenAI
from openai import pydantic_function_tool
from openai.types.chat import ChatCompletionToolParam
from openai.types.chat import ChatCompletionToolParam, ChatCompletionMessageToolCall
from openai.types.chat import ChatCompletion

from pydantic import BaseModel, Field, PrivateAttr, model_validator, create_model

from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput, DEFAULT_TOOL_DESCRIPTION
from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin, VectorSearchRetrieverToolInput
from databricks_ai_bridge.utils.vector_search import IndexDetails, parse_vector_search_response, validate_and_get_text_column, validate_and_get_return_columns
from databricks.vector_search.client import VectorSearchClient, VectorSearchIndex
import json
Expand All @@ -33,7 +33,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):

final_response = openai.chat.completions.create(
model="gpt-4o",
messages=initial_messages + retriever_call_message + remaining_tool_call_messages
messages=initial_messages + retriever_call_message + remaining_tool_call_messages,
tools=tools,
)
final_response.choices[0].message.content
Expand All @@ -48,7 +48,7 @@ class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin):
)

tool: ChatCompletionToolParam = Field(
..., description="The tool input used in the OpenAI chat completion SDK"
None, description="The tool input used in the OpenAI chat completion SDK"
)
_index: VectorSearchIndex = PrivateAttr()
_index_details: IndexDetails = PrivateAttr()
Expand All @@ -58,11 +58,16 @@ def _validate_tool_inputs(self):
self._index = VectorSearchClient().get_index(index_name=self.index_name)
self._index_details = IndexDetails(self._index)
self.text_column = validate_and_get_text_column(self.text_column, self._index_details)
self.columns = validate_and_get_return_columns(self.columns, self.text_column, self._index_details)
self.columns = validate_and_get_return_columns(self.columns or [], self.text_column, self._index_details)

# OpenAI tool names must match the pattern '^[a-zA-Z0-9_-]+$'."
# The '.' from the index name are not allowed
def rewrite_index_name(index_name: str):
return index_name.split(".")[-1]

self.tool = pydantic_function_tool(
VectorSearchRetrieverToolInput,
name=self.tool_name or self.index_name,
name=self.tool_name or rewrite_index_name(self.index_name),
description=self.tool_description or self._get_default_tool_description(self._index_details)
)
return self
Expand All @@ -89,7 +94,9 @@ def execute_retriever_calls(self,
A list of messages containing the assistant message and the retriever call results
that correspond to the self.tool VectorSearchRetrieverToolInput.
"""
def get_query_text_vector(query):

def get_query_text_vector(tool_call: ChatCompletionMessageToolCall) -> Tuple[Optional[str], Optional[List[float]]]:
query = json.loads(tool_call.function.arguments)["query"]
if self._index_details.is_databricks_managed_embeddings():
if open_ai_client or embedding_model_name:
raise ValueError(
Expand All @@ -107,32 +114,29 @@ def get_query_text_vector(query):
input=query,
model=embedding_model_name
)['data'][0]['embedding']
if (index_embedding_dimension := index_details.embedding_vector_column.get("embedding_dimension")) and \
if (index_embedding_dimension := self._index_details.embedding_vector_column.get("embedding_dimension")) and \
len(vector) != index_embedding_dimension:
raise ValueError(
f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}"
)
return text, vector

def do_arguments_match(llm_call_arguments: Dict[str, Any]):
# TODO: Remove print statement
print("llm_call_arguments_json: ", llm_call_arguments)
print("retriever_tool_params: ", self.tool.function.parameters)
retriever_tool_params: Dict[str, Any] = self.tool.function.parameters
return set(llm_call_arguments.keys()) == set(retriever_tool_params.keys())
def is_tool_call_for_index(tool_call: ChatCompletionMessageToolCall) -> bool:
tool_call_arguments: Set[str] = set(json.loads(tool_call.function.arguments).keys())
vs_index_arguments: Set[str] = set(self.tool["function"]["parameters"]["properties"].keys())
return tool_call.function.name == self.tool["function"]["name"] and \
tool_call_arguments == vs_index_arguments

message = response.choices[choice_index].message
llm_tool_calls = message.tool_calls
function_calls = []
if llm_tool_calls:
for llm_tool_call in llm_tool_calls:
# Only process tool calls that correspond to the self.tool VectorSearchRetrieverToolInput
llm_call_arguments_json: Dict[str, Any] = json.loads(llm_tool_call.function.arguments)
if llm_tool_call.function.name != self.tool.function.name \
or not do_arguments_match(llm_call_arguments_json):
if not is_tool_call_for_index(llm_tool_call):
continue

query_text, query_vector = get_query_text_vector(llm_call_arguments_json["query"])
query_text, query_vector = get_query_text_vector(llm_tool_call)
search_resp = self._index.similarity_search(
columns=self.columns,
query_text=query_text,
Expand All @@ -141,15 +145,12 @@ def do_arguments_match(llm_call_arguments: Dict[str, Any]):
num_results=self.num_results,
query_type=self.query_type,
)
embedding_column = self._index_details.embedding_vector_column["name"]
# Don't show the embedding column in the response
ignore_cols: List = [embedding_column] if embedding_column not in self.columns else []
docs_with_score: List[Tuple[Dict, float]] = \
parse_vector_search_response(
search_resp,
self._index_details,
self.text_column,
ignore_cols=ignore_cols,
ignore_cols=[],
document_class=dict
)

Expand Down
2 changes: 1 addition & 1 deletion src/databricks_ai_bridge/utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def validate_and_get_text_column(text_column: Optional[str], index_details: Inde
raise ValueError("The `text_column` parameter is required for this index.")
return text_column

def _validate_and_get_return_columns(
def validate_and_get_return_columns(
columns: List[str], text_column: str, index_details: IndexDetails
) -> List[str]:
"""
Expand Down
Empty file removed tests/utils/vector_search.py
Empty file.