From 1f43c172053c09de43477629800da2bf7a370614 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:11:33 -0800 Subject: [PATCH] Inline langchain-databricks components (#25) * Inline langchain-databricks components Signed-off-by: Prithvi Kannan * add unit tests Signed-off-by: Prithvi Kannan * format Signed-off-by: Prithvi Kannan * ruff Signed-off-by: Prithvi Kannan * fix Signed-off-by: Prithvi Kannan * deps Signed-off-by: Prithvi Kannan * ruff Signed-off-by: Prithvi Kannan * mlflow Signed-off-by: Prithvi Kannan --------- Signed-off-by: Prithvi Kannan --- integrations/langchain/pyproject.toml | 3 +- .../src/databricks_langchain/__init__.py | 12 +- .../src/databricks_langchain/chat_models.py | 817 +++++++++++++++++ .../src/databricks_langchain/embeddings.py | 88 ++ .../src/databricks_langchain/genie.py | 4 +- .../src/databricks_langchain/utils.py | 97 ++ .../src/databricks_langchain/vectorstores.py | 836 ++++++++++++++++++ .../langchain/tests/test_chat_models.py | 448 ++++++++++ .../langchain/tests/test_embeddings.py | 67 ++ .../langchain/tests/test_vectorstores.py | 617 +++++++++++++ tests/databricks_ai_bridge/test_genie.py | 114 ++- 11 files changed, 3078 insertions(+), 25 deletions(-) create mode 100644 integrations/langchain/src/databricks_langchain/chat_models.py create mode 100644 integrations/langchain/src/databricks_langchain/embeddings.py create mode 100644 integrations/langchain/src/databricks_langchain/utils.py create mode 100644 integrations/langchain/src/databricks_langchain/vectorstores.py create mode 100644 integrations/langchain/tests/test_chat_models.py create mode 100644 integrations/langchain/tests/test_embeddings.py create mode 100644 integrations/langchain/tests/test_vectorstores.py diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 28d4422..1e19cb9 100644 --- a/integrations/langchain/pyproject.toml +++ 
b/integrations/langchain/pyproject.toml @@ -11,7 +11,8 @@ requires-python = ">=3.9" dependencies = [ "langchain>=0.2.0", "langchain-community>=0.2.0", - "langchain-databricks>=0.1.1", + "mlflow", + "databricks-vectorsearch>=0.40", "databricks-ai-bridge", ] diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 4b7a9da..b457f47 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -1,11 +1,7 @@ -# Import modules from langchain-databricks -from langchain_databricks import ( - ChatDatabricks, - DatabricksEmbeddings, - DatabricksVectorSearch, -) - -from .genie import GenieAgent +from databricks_langchain.chat_models import ChatDatabricks +from databricks_langchain.embeddings import DatabricksEmbeddings +from databricks_langchain.genie import GenieAgent +from databricks_langchain.vectorstores import DatabricksVectorSearch # Expose all integrations to users under databricks-langchain __all__ = [ diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py new file mode 100644 index 0000000..1dbe33c --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/chat_models.py @@ -0,0 +1,817 @@ +"""Databricks chat models.""" + +import json +import logging +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Type, + Union, +) + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.base import LanguageModelInput +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + HumanMessage, + HumanMessageChunk, + SystemMessage, + 
SystemMessageChunk, + ToolMessage, + ToolMessageChunk, +) +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_core.utils.pydantic import is_basemodel_subclass +from mlflow.deployments import BaseDeploymentClient # type: ignore +from pydantic import BaseModel, Field + +from databricks_langchain.utils import get_deployment_client + +logger = logging.getLogger(__name__) + + +class ChatDatabricks(BaseChatModel): + """Databricks chat model integration. + + Setup: + Install ``databricks-langchain``. + + .. code-block:: bash + + pip install -U databricks-langchain + + If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables: + + .. code-block:: bash + + export DATABRICKS_HOST="https://your-databricks-workspace" + export DATABRICKS_TOKEN="your-personal-access-token" + + Key init args — completion params: + endpoint: str + Name of Databricks Model Serving endpoint to query. + target_uri: str + The target URI to use. Defaults to ``databricks``. + temperature: float + Sampling temperature. Higher values make the model more creative. + n: int + The number of completion choices to generate. + stop: Optional[List[str]] + List of strings to stop generation at. + max_tokens: Optional[int] + Max number of tokens to generate. 
+ extra_params: Optional[Dict[str, Any]] + Any extra parameters to pass to the endpoint. + + Instantiate: + .. code-block:: python + + from databricks_langchain import ChatDatabricks + + llm = ChatDatabricks( + endpoint="databricks-meta-llama-3-1-405b-instruct", + temperature=0, + max_tokens=500, + ) + + Invoke: + .. code-block:: python + + messages = [ + ("system", "You are a helpful translator. Translate the user sentence to French."), + ("human", "I love programming."), + ] + llm.invoke(messages) + + .. code-block:: python + + AIMessage( + content="J'adore la programmation.", + response_metadata={"prompt_tokens": 32, "completion_tokens": 9, "total_tokens": 41}, + id="run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0", + ) + + Stream: + .. code-block:: python + + for chunk in llm.stream(messages): + print(chunk) + + .. code-block:: python + + content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e' + + .. code-block:: python + + stream = llm.stream(messages) + full = next(stream) + for chunk in stream: + full += chunk + full + + .. code-block:: python + + AIMessageChunk( + content="J'adore la programmation.", + response_metadata={"finish_reason": "stop"}, + id="run-4cef851f-6223-424f-ad26-4a54e5852aa5", + ) + + To get token usage returned when streaming, pass the ``stream_usage`` kwarg: + + .. code-block:: python + + stream = llm.stream(messages, stream_usage=True) + next(stream).usage_metadata + + .. 
code-block:: python + + {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33} + + Alternatively, setting ``stream_usage`` when instantiating the model can be + useful when incorporating ``ChatDatabricks`` into LCEL chains-- or when using + methods like ``.with_structured_output``, which generate chains under the + hood. + + .. code-block:: python + + llm = ChatDatabricks( + endpoint="databricks-meta-llama-3-1-405b-instruct", stream_usage=True + ) + structured_llm = llm.with_structured_output(...) + + Async: + .. code-block:: python + + await llm.ainvoke(messages) + + # stream: + # async for chunk in llm.astream(messages) + + # batch: + # await llm.abatch([messages]) + + .. code-block:: python + + AIMessage( + content="J'adore la programmation.", + response_metadata={"prompt_tokens": 32, "completion_tokens": 9, "total_tokens": 41}, + id="run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0", + ) + + Tool calling: + .. code-block:: python + + from pydantic import BaseModel, Field + + + class GetWeather(BaseModel): + '''Get the current weather in a given location''' + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + + class GetPopulation(BaseModel): + '''Get the current population in a given location''' + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + + + llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) + ai_msg = llm_with_tools.invoke( + "Which city is hotter today and which is bigger: LA or NY?" + ) + ai_msg.tool_calls + + .. code-block:: python + + [ + { + "name": "GetWeather", + "args": {"location": "Los Angeles, CA"}, + "id": "call_ea0a6004-8e64-4ae8-a192-a40e295bfa24", + "type": "tool_call", + } + ] + + To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information. 
+ + """ # noqa: E501 + + endpoint: str + """Name of Databricks Model Serving endpoint to query.""" + target_uri: str = "databricks" + """The target URI to use. Defaults to ``databricks``.""" + temperature: float = 0.0 + """Sampling temperature. Higher values make the model more creative.""" + n: int = 1 + """The number of completion choices to generate.""" + stop: Optional[List[str]] = None + """List of strings to stop generation at.""" + max_tokens: Optional[int] = None + """The maximum number of tokens to generate.""" + extra_params: Optional[Dict[str, Any]] = None + """Any extra parameters to pass to the endpoint.""" + stream_usage: bool = False + """Whether to include usage metadata in streaming output. If True, additional + message chunks will be generated during the stream including usage metadata. + """ + client: Optional[BaseDeploymentClient] = Field(default=None, exclude=True) #: :meta private: + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.client = get_deployment_client(self.target_uri) + self.extra_params = self.extra_params or {} + + @property + def _default_params(self) -> Dict[str, Any]: + params: Dict[str, Any] = { + "target_uri": self.target_uri, + "endpoint": self.endpoint, + "temperature": self.temperature, + "n": self.n, + "stop": self.stop, + "max_tokens": self.max_tokens, + "extra_params": self.extra_params, + } + return params + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + data = self._prepare_inputs(messages, stop, **kwargs) + resp = self.client.predict(endpoint=self.endpoint, inputs=data) # type: ignore + return self._convert_response_to_chat_result(resp) + + def _prepare_inputs( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + data: Dict[str, Any] = { + "messages": [_convert_message_to_dict(msg) for msg in 
messages], + "temperature": self.temperature, + "n": self.n, + **self.extra_params, # type: ignore + **kwargs, + } + if stop := self.stop or stop: + data["stop"] = stop + if self.max_tokens is not None: + data["max_tokens"] = self.max_tokens + + return data + + def _convert_response_to_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [ + ChatGeneration( + message=_convert_dict_to_message(choice["message"]), + generation_info=choice.get("usage", {}), + ) + for choice in response["choices"] + ] + usage = response.get("usage", {}) + return ChatResult(generations=generations, llm_output=usage) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + stream_usage: Optional[bool] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + if stream_usage is None: + stream_usage = self.stream_usage + data = self._prepare_inputs(messages, stop, **kwargs) + first_chunk_role = None + for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data): # type: ignore + if chunk["choices"]: + choice = chunk["choices"][0] + + chunk_delta = choice["delta"] + if first_chunk_role is None: + first_chunk_role = chunk_delta.get("role") + + if stream_usage and (usage := chunk.get("usage")): + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + usage = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + } + else: + usage = None + + chunk_message = _convert_dict_to_message_chunk( + chunk_delta, first_chunk_role, usage=usage + ) + + generation_info = {} + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if logprobs := choice.get("logprobs"): + generation_info["logprobs"] = logprobs + + chunk = ChatGenerationChunk( + message=chunk_message, generation_info=generation_info or None + ) + 
+ if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs) + + yield chunk + else: + # Handle the case where choices are empty if needed + continue + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Assumes model is compatible with OpenAI tool-calling API. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + tool_choice: Which tool to require the model to call. + Options are: + name of the tool (str): calls corresponding tool; + "auto": automatically selects a tool (including no tool); + "none": model does not generate any tool calls and instead must + generate a standard assistant message; + "required": the model picks the most relevant tool in tools and + must generate a tool call; + + or a dict of the form: + {"type": "function", "function": {"name": <>}}. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice: + if isinstance(tool_choice, str): + # tool_choice is a tool/function name + if tool_choice not in ("auto", "none", "required", "any"): + tool_choice = { + "type": "function", + "function": {"name": tool_choice}, + } + # 'any' is not natively supported by OpenAI API, + # but supported by other models in Langchain. 
+ # Ref: https://github.com/langchain-ai/langchain/blob/202d7f6c4a2ca8c7e5949d935bcf0ba9b0c23fb0/libs/partners/openai/langchain_openai/chat_models/base.py#L1098C1-L1101C45 + if tool_choice == "any": + tool_choice = "required" + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + else: + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. " + f"Received: {tool_choice}" + ) + kwargs["tool_choice"] = tool_choice + return super().bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: Optional[Union[Dict, Type]] = None, + *, + method: Literal["function_calling", "json_mode"] = "function_calling", + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Assumes model is compatible with OpenAI tool-calling API. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. If + `method` is "function_calling" and `schema` is a dict, then the dict + must match the OpenAI function-calling spec or be a valid JSON schema + with top level 'title' and 'description' keys specified. + method: The method for steering model generation, either "function_calling" + or "json_mode". If "function_calling" then the schema will be converted + to an OpenAI function and the returned model will make use of the + function-calling API. 
If "json_mode" then OpenAI's JSON mode will be + used. Note that if using "json_mode" then you must include instructions + for formatting the output into the desired schema into the model call. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input and returns as output: + + If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs + an instance of ``schema`` (i.e., a Pydantic object). + + Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + If ``include_raw`` is True, then Runnable outputs a dict with keys: + - ``"raw"``: BaseMessage + - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - ``"parsing_error"``: Optional[BaseException] + + Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False): + .. code-block:: python + + from databricks_langchain import ChatDatabricks + from pydantic import BaseModel + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct") + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The weight is the same, but the volume or density of the objects may differ.' + # ) + + Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True): + .. code-block:: python + + from databricks_langchain import ChatDatabricks + from pydantic import BaseModel + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct") + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + + Example: Function-calling, dict schema (method="function_calling", include_raw=False): + .. 
code-block:: python + + from databricks_langchain import ChatDatabricks + from langchain_core.utils.function_calling import convert_to_openai_tool + from pydantic import BaseModel + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + dict_schema = convert_to_openai_tool(AnswerWithJustification) + llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct") + structured_llm = llm.with_structured_output(dict_schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' + # } + + Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True): + .. code-block:: + + from databricks_langchain import ChatDatabricks + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + answer: str + justification: str + + llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct") + structured_llm = llm.with_structured_output( + AnswerWithJustification, + method="json_mode", + include_raw=True + ) + + structured_llm.invoke( + "Answer the following question. " + "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" + "What's heavier a pound of bricks or a pound of feathers?" + ) + # -> { + # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), + # 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. 
The difference lies in the volume and density of the materials, not the weight.'), + # 'parsing_error': None + # } + + Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True): + .. code-block:: + + structured_llm = llm.with_structured_output(method="json_mode", include_raw=True) + + structured_llm.invoke( + "Answer the following question. " + "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" + "What's heavier a pound of bricks or a pound of feathers?" + ) + # -> { + # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), + # 'parsed': { + # 'answer': 'They are both the same weight.', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.' + # }, + # 'parsing_error': None + # } + + + """ # noqa: E501 + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) + if method == "function_calling": + if schema is None: + raise ValueError( + "schema must be specified when method is 'function_calling'. " "Received None." 
+ ) + tool_name = convert_to_openai_tool(schema)["function"]["name"] + llm = self.bind_tools([schema], tool_choice=tool_name) + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + output_parser = JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True) + elif method == "json_mode": + llm = self.bind(response_format={"type": "json_object"}) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected one of 'function_calling' or " + f"'json_mode'. Received: '{method}'" + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser + + @property + def _identifying_params(self) -> Dict[str, Any]: + return self._default_params + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **self._default_params, + **super()._get_invocation_params(stop=stop, **kwargs), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "chat-databricks" + + +### Conversion function to convert Pydantic models to dictionaries and vice versa. ### + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + message_dict = {"content": message.content} + + # NB: We don't propagate 'name' field from input message to the endpoint because + # FMAPI doesn't support it. 
We should update the endpoints to be compatible with + # OpenAI and then we can uncomment the following code. + # if (name := message.name or message.additional_kwargs.get("name")) is not None: + # message_dict["name"] = name + + if isinstance(message, ChatMessage): + return {"role": message.role, **message_dict} + elif isinstance(message, HumanMessage): + return {"role": "user", **message_dict} + elif isinstance(message, AIMessage): + if tool_calls := _get_tool_calls_from_ai_message(message): + message_dict["tool_calls"] = tool_calls # type: ignore[assignment] + # If tool calls present, content null value should be None not empty string. + message_dict["content"] = message_dict["content"] or None # type: ignore[assignment] + return {"role": "assistant", **message_dict} + elif isinstance(message, SystemMessage): + return {"role": "system", **message_dict} + elif isinstance(message, ToolMessage): + return { + "role": "tool", + "tool_call_id": message.tool_call_id, + **message_dict, + } + elif isinstance(message, FunctionMessage) or "function_call" in message.additional_kwargs: + raise ValueError( + "Function messages are not supported by Databricks. Please" + " create a feature request at https://github.com/mlflow/mlflow/issues." + ) + else: + raise ValueError(f"Got unknown message type: {type(message)}") + + +def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]: + tool_calls = [ + { + "type": "function", + "id": tc["id"], + "function": { + "name": tc["name"], + "arguments": json.dumps(tc["args"]), + }, + } + for tc in message.tool_calls + ] + + invalid_tool_calls = [ + { + "type": "function", + "id": tc["id"], + "function": { + "name": tc["name"], + "arguments": tc["args"], + }, + } + for tc in message.invalid_tool_calls + ] + + if tool_calls or invalid_tool_calls: + return tool_calls + invalid_tool_calls + + # Get tool calls from additional kwargs if present. 
+ return [ + { + k: v + for k, v in tool_call.items() # type: ignore[union-attr] + if k in {"id", "type", "function"} + } + for tool_call in message.additional_kwargs.get("tool_calls", []) + ] + + +def _convert_dict_to_message(_dict: Dict) -> BaseMessage: + role = _dict["role"] + content = _dict.get("content") + content = content if content is not None else "" + + if role == "user": + return HumanMessage(content=content) + elif role == "system": + return SystemMessage(content=content) + elif role == "assistant": + additional_kwargs: Dict = {} + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e))) + return AIMessage( + content=content, + additional_kwargs=additional_kwargs, + id=_dict.get("id"), + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) + else: + return ChatMessage(content=content, role=role) + + +def _convert_dict_to_message_chunk( + _dict: Mapping[str, Any], + default_role: str, + usage: Optional[Dict[str, Any]] = None, +) -> BaseMessageChunk: + role = _dict.get("role", default_role) + content = _dict.get("content") + content = content if content is not None else "" + + if role == "user": + return HumanMessageChunk(content=content) + elif role == "system": + return SystemMessageChunk(content=content) + elif role == "tool": + return ToolMessageChunk( + content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id") + ) + elif role == "assistant": + additional_kwargs: Dict = {} + tool_call_chunks = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + tool_call_chunk( + name=tc["function"].get("name"), + args=tc["function"].get("arguments"), + 
id=tc.get("id"), + index=tc["index"], + ) + for tc in raw_tool_calls + ] + except KeyError: + pass + usage_metadata = UsageMetadata(**usage) if usage else None # type: ignore + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + id=_dict.get("id"), + tool_call_chunks=tool_call_chunks, + usage_metadata=usage_metadata, + ) + else: + return ChatMessageChunk(content=content, role=role) diff --git a/integrations/langchain/src/databricks_langchain/embeddings.py b/integrations/langchain/src/databricks_langchain/embeddings.py new file mode 100644 index 0000000..421e45c --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/embeddings.py @@ -0,0 +1,88 @@ +from typing import Any, Dict, Iterator, List + +from langchain_core.embeddings import Embeddings +from pydantic import BaseModel, PrivateAttr + +from databricks_langchain.utils import get_deployment_client + + +class DatabricksEmbeddings(Embeddings, BaseModel): + """Databricks embedding model integration. + + Setup: + Install ``databricks-langchain``. + + .. code-block:: bash + + pip install -U databricks-langchain + + If you are outside Databricks, set the Databricks workspace + hostname and personal access token to environment variables: + + .. code-block:: bash + + export DATABRICKS_HOST="https://your-databricks-workspace" + export DATABRICKS_TOKEN="your-personal-access-token" + + Key init args — completion params: + endpoint: str + Name of Databricks Model Serving endpoint to query. + target_uri: str + The target URI to use. Defaults to ``databricks``. + query_params: Dict[str, Any] + The parameters to use for queries. + documents_params: Dict[str, Any] + The parameters to use for documents. + + Instantiate: + .. code-block:: python + from databricks_langchain import DatabricksEmbeddings + + embed = DatabricksEmbeddings( + endpoint="databricks-bge-large-en", + ) + + Embed single text: + .. 
class DatabricksEmbeddings(Embeddings, BaseModel):
    """Databricks embedding model integration.

    Setup:
        Install ``databricks-langchain``.

        .. code-block:: bash

            pip install -U databricks-langchain

        If you are outside Databricks, set the Databricks workspace
        host and personal access token as environment variables:

        .. code-block:: bash

            export DATABRICKS_HOST="https://your-databricks-workspace"
            export DATABRICKS_TOKEN="your-personal-access-token"

    Key init args — completion params:
        endpoint: str
            Name of Databricks Model Serving endpoint to query.
        target_uri: str
            The target URI to use. Defaults to ``databricks``.
        query_params: Dict[str, Any]
            The parameters to use for queries.
        documents_params: Dict[str, Any]
            The parameters to use for documents.

    Instantiate:
        .. code-block:: python

            from databricks_langchain import DatabricksEmbeddings

            embed = DatabricksEmbeddings(
                endpoint="databricks-bge-large-en",
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            embed.embed_query(input_text)

        .. code-block:: python

            [0.01605224609375, -0.0298309326171875, ...]

    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str = "databricks"
    """The target URI to use."""
    query_params: Dict[str, Any] = {}
    """The parameters to use for queries."""
    documents_params: Dict[str, Any] = {}
    """The parameters to use for documents."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        """Validate the fields, then create the deployment client for ``target_uri``."""
        super().__init__(**kwargs)
        self._client = get_deployment_client(self.target_uri)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents, applying ``documents_params`` to each request."""
        return self._embed(texts, params=self.documents_params)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string, applying ``query_params`` to the request."""
        return self._embed([text], params=self.query_params)[0]

    def _embed(self, texts: List[str], params: Dict[str, Any]) -> List[List[float]]:
        """Send ``texts`` to the endpoint in batches and collect the embeddings.

        Args:
            texts: Texts to embed.
            params: Extra request fields merged into every batch payload.

        Returns:
            One embedding per input text, in input order.
        """
        embeddings: List[List[float]] = []
        # Requests are issued in batches of at most 20 texts.
        for batch in _chunk(texts, 20):
            resp = self._client.predict(
                endpoint=self.endpoint,
                inputs={"input": batch, **params},  # type: ignore[arg-type]
            )
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of ``texts`` holding at most ``size`` items each."""
    for i in range(0, len(texts), size):
        yield texts[i : i + size]
def get_deployment_client(target_uri: str) -> Any:
    """Return an MLflow deployment client for a Databricks target.

    Args:
        target_uri: Either the literal ``"databricks"`` or a URI whose scheme
            is ``databricks``.

    Raises:
        ValueError: If ``target_uri`` is not a Databricks URI.
        ImportError: If ``mlflow`` cannot be imported.
    """
    if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
        raise ValueError("Invalid target URI. The target URI must be a valid databricks URI.")

    try:
        from mlflow.deployments import get_deploy_client  # type: ignore[import-untyped]

        return get_deploy_client(target_uri)
    except ImportError as e:
        raise ImportError(
            "Failed to create the client. "
            "Please run `pip install mlflow` to install "
            "required dependencies."
        ) from e


# Utility function for Maximal Marginal Relevance (MMR) reranking.
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance.

    Args:
        query_embedding: Query embedding.
        embedding_list: List of embeddings to select from.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        k: Number of Documents to return. Defaults to 4.

    Returns:
        List of indices of embeddings selected by maximal marginal relevance,
        in selection order (most similar to the query first).
    """
    limit = min(k, len(embedding_list))
    if limit <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    # Mirror of `idxs` for O(1) membership tests inside the candidate loop.
    chosen = {most_similar}
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < limit:
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in chosen:
                continue
            # Penalise a candidate by its closest match among the picks so far.
            redundant_score = max(similarity_to_selected[i])
            equation_score = lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        chosen.add(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices.

    Returns:
        An array of shape ``(len(X), len(Y))`` whose ``[i, j]`` entry is the
        cosine similarity of ``X[i]`` and ``Y[j]``. Entries that would be NaN
        or infinite (e.g. a zero-norm row) are set to 0.0. If either input is
        empty, an empty 1-D array is returned.

    Raises:
        ValueError: If the number of columns in X and Y are not the same.
    """
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            "Number of columns in X and Y must be the same. X has shape "
            f"{X.shape} and Y has shape {Y.shape}."
        )

    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    # Ignore divide by zero errors run time warnings as those are handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
code-block:: bash + + export DATABRICKS_HOST="https://your-databricks-workspace" + export DATABRICKS_TOKEN="your-personal-access-token" + + Key init args — indexing params: + + index_name: The name of the index to use. Format: "catalog.schema.index". + endpoint: The name of the Databricks Vector Search endpoint. If not specified, + the endpoint name is automatically inferred based on the index name. + + .. note:: + + If you are using `databricks-vectorsearch` version < 0.35, the `endpoint` parameter + is required when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="", + index_name="", + ... + ) + + embedding: The embedding model. + Required for direct-access index or delta-sync index + with self-managed embeddings. + text_column: The name of the text column to use for the embeddings. + Required for direct-access index or delta-sync index + with self-managed embeddings. + Make sure the text column specified is in the index. + columns: The list of column names to get when doing the search. + Defaults to ``[primary_key, text_column]``. + + Instantiate: + + `DatabricksVectorSearch` supports two types of indexes: + + * **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes. + + * **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK. + + Also for delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes). + + If you are using a delta-sync index with Databricks-managed embeddings: + + .. 
code-block:: python + + from databricks_langchain.vectorstores import DatabricksVectorSearch + + vector_store = DatabricksVectorSearch(index_name="") + + If you are using a direct-access index or a delta-sync index with self-managed embeddings, + you also need to provide the embedding model and text column in your source table to + use for the embeddings: + + .. code-block:: python + + from langchain_openai import OpenAIEmbeddings + + vector_store = DatabricksVectorSearch( + index_name="", + embedding=OpenAIEmbeddings(), + text_column="document_content", + ) + + Add Documents: + .. code-block:: python + from langchain_core.documents import Document + + document_1 = Document(page_content="foo", metadata={"baz": "bar"}) + document_2 = Document(page_content="thud", metadata={"bar": "baz"}) + document_3 = Document(page_content="i will be deleted :(") + documents = [document_1, document_2, document_3] + ids = ["1", "2", "3"] + vector_store.add_documents(documents=documents, ids=ids) + + Delete Documents: + .. code-block:: python + vector_store.delete(ids=["3"]) + + .. note:: + + The `delete` method is only supported for direct-access index. + + Search: + .. code-block:: python + results = vector_store.similarity_search(query="thud",k=1) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * thud [{'id': '2'}] + + .. note:: + + By default, similarity search only returns the primary key and text column. + If you want to retrieve the custom metadata associated with the document, + pass the additional columns in the `columns` parameter when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="", + index_name="", + columns=["baz", "bar"], + ) + + vector_store.similarity_search(query="thud", k=1) + # Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}] + + Search with filter: + .. 
code-block:: python + results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"}) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + .. code-block:: python + *thud[{"id": "2"}] + + Search with score: + .. code-block:: python + results = vector_store.similarity_search_with_score(query="qux",k=1) + for doc, score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * [SIM=0.748804] foo [{'id': '1'}] + + Async: + .. code-block:: python + # add documents + await vector_store.aadd_documents(documents=documents, ids=ids) + # delete documents + await vector_store.adelete(ids=["3"]) + # search + results = vector_store.asimilarity_search(query="thud",k=1) + # search with score + results = await vector_store.asimilarity_search_with_score(query="qux",k=1) + for doc,score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * [SIM=0.748807] foo [{'id': '1'}] + + Use as Retriever: + .. code-block:: python + retriever = vector_store.as_retriever( + search_type="mmr", + search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5}, + ) + retriever.invoke("thud") + .. code-block:: python + [Document(metadata={"id": "2"}, page_content="thud")] + """ # noqa: E501 + + def __init__( + self, + index_name: str, + endpoint: Optional[str] = None, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, + columns: Optional[List[str]] = None, + ): + if not (isinstance(index_name, str) and _INDEX_NAME_PATTERN.match(index_name)): + raise ValueError( + "The `index_name` parameter must be a string in the format " + f"'catalog.schema.index'. Received: {index_name}" + ) + + try: + from databricks.vector_search.client import ( # type: ignore[import] + VectorSearchClient, + ) + except ImportError as e: + raise ImportError( + "Could not import databricks-vectorsearch python package. 
" + "Please install it with `pip install databricks-vectorsearch`." + ) from e + + try: + self.index = VectorSearchClient().get_index( + endpoint_name=endpoint, index_name=index_name + ) + except Exception as e: + if endpoint is None and "Wrong vector search endpoint" in str(e): + raise ValueError( + "The `endpoint` parameter is required for instantiating " + "DatabricksVectorSearch with the `databricks-vectorsearch` " + "version earlier than 0.35. Please provide the endpoint " + "name or upgrade to version 0.35 or later." + ) from e + else: + raise + + self._index_details = IndexDetails(self.index) + + _validate_embedding(embedding, self._index_details) + self._embeddings = embedding + self._text_column = _validate_and_get_text_column(text_column, self._index_details) + self._columns = _validate_and_get_return_columns( + columns or [], self._text_column, self._index_details + ) + self._primary_key = self._index_details.primary_key + + @property + def embeddings(self) -> Optional[Embeddings]: + """Access the query embedding object if available.""" + return self._embeddings + + @classmethod + def from_texts( + cls: Type[VST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VST: + raise NotImplementedError( + "`from_texts` is not supported. " + "Use `add_texts` to add to existing direct-access index." + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict]] = None, + ids: Optional[List[Any]] = None, + **kwargs: Any, + ) -> List[str]: + """Add texts to the index. + + .. note:: + + This method is only supported for a direct-access index. + + Args: + texts: List of texts to add. + metadatas: List of metadata for each text. Defaults to None. + ids: List of ids for each text. Defaults to None. + If not provided, a random uuid will be generated for each text. + + Returns: + List of ids from adding the texts into the index. 
+ """ + if self._index_details.is_delta_sync_index(): + raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts") + + # Wrap to list if input texts is a single string + if isinstance(texts, str): + texts = [texts] + texts = list(texts) + vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr] + ids = ids or [str(uuid.uuid4()) for _ in texts] + metadatas = metadatas or [{} for _ in texts] + + updates = [ + { + self._primary_key: id_, + self._text_column: text, + self._index_details.embedding_vector_column["name"]: vector, + **metadata, + } + for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas) + ] + + upsert_resp = self.index.upsert(updates) + if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"): + failed_ids = upsert_resp.get("result", dict()).get("failed_primary_keys", []) + if upsert_resp.get("status") == "FAILURE": + logger.error("Failed to add texts to the index.") + else: + logger.warning("Some texts failed to be added to the index.") + return [id_ for id_ in ids if id_ not in failed_ids] + + return ids + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.add_texts, **kwargs), texts, metadatas + ) + + def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]: + """Delete documents from the index. + + .. note:: + + This method is only supported for a direct-access index. + + Args: + ids: List of ids of documents to delete. + + Returns: + True if successful. 
+ """ + if self._index_details.is_delta_sync_index(): + raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete") + + if ids is None: + raise ValueError("ids must be provided.") + self.index.delete(ids) + return True + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding. + """ + docs_with_score = self.similarity_search_with_score( + query=query, + k=k, + filter=filter, + query_type=query_type, + **kwargs, + ) + return [doc for doc, _ in docs_with_score] + + async def asimilarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search, query, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query, along with scores. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding and score for each. 
+ """ + if self._index_details.is_databricks_managed_embeddings(): + query_text = query + query_vector = None + else: + # The value for `query_text` needs to be specified only for hybrid search. + if query_type is not None and query_type.upper() == "HYBRID": + query_text = query + else: + query_text = None + query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + + search_resp = self.index.similarity_search( + columns=self._columns, + query_text=query_text, + query_vector=query_vector, + filters=filter, + num_results=k, + query_type=query_type, + ) + return self._parse_search_response(search_resp) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + Databricks Vector search uses a normalized score 1/(1+d) where d + is the L2 distance. Hence, we simply return the identity function. + """ + return lambda score: score + + async def asimilarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[Document, float]]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search_with_score, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + query: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding. 
+ """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError(_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector") + + docs_with_score = self.similarity_search_by_vector_with_score( + embedding=embedding, + k=k, + filter=filter, + query_type=query_type, + query=query, + **kwargs, + ) + return [doc for doc, _ in docs_with_score] + + async def asimilarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_by_vector_with_score( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + query: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to embedding vector, along with scores. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding and score for each. 
+ """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score" + ) + + if query_type is not None and query_type.upper() == "HYBRID": + if query is None: + raise ValueError("A value for `query` must be specified for hybrid search.") + query_text = query + else: + if query is not None: + raise ValueError( + ("Cannot specify both `embedding` and " '`query` unless `query_type="HYBRID"') + ) + query_text = None + + search_resp = self.index.similarity_search( + columns=self._columns, + query_vector=embedding, + query_text=query_text, + filters=filter, + num_results=k, + query_type=query_type, + ) + return self._parse_search_response(search_resp) + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + Returns: + List of Documents selected by maximal marginal relevance. 
+ """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError(_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search") + + query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + docs = self.max_marginal_relevance_search_by_vector( + query_vector, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + query_type=query_type, + ) + return docs + + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial( + self.max_marginal_relevance_search, + query, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + **kwargs, + ) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. 
Supported values are "ANN" and "HYBRID". + Returns: + List of Documents selected by maximal marginal relevance. + """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector" + ) + + embedding_column = self._index_details.embedding_vector_column["name"] + search_resp = self.index.similarity_search( + columns=list(set(self._columns + [embedding_column])), + query_text=None, + query_vector=embedding, + filters=filter, + num_results=fetch_k, + query_type=query_type, + ) + + embeddings_result_index = ( + search_resp.get("manifest").get("columns").index({"name": embedding_column}) + ) + embeddings = [ + doc[embeddings_result_index] for doc in search_resp.get("result").get("data_array") + ] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embeddings, + k=k, + lambda_mult=lambda_mult, + ) + + ignore_cols: List = [embedding_column] if embedding_column not in self._columns else [] + candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols) + selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError + + def _parse_search_response( + self, search_resp: Dict, ignore_cols: Optional[List[str]] = None + ) -> List[Tuple[Document, float]]: + """Parse the search response into a list of Documents with score.""" + if ignore_cols is None: + ignore_cols = [] + + columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] + docs_with_score = [] + for result in search_resp.get("result", dict()).get("data_array", []): + doc_id = result[columns.index(self._primary_key)] + text_content = result[columns.index(self._text_column)] + 
ignore_cols = [self._primary_key, self._text_column] + ignore_cols + metadata = { + col: value + for col, value in zip(columns[:-1], result[:-1]) + if col not in ignore_cols + } + metadata[self._primary_key] = doc_id + score = result[-1] + doc = Document(page_content=text_content, metadata=metadata) + docs_with_score.append((doc, score)) + return docs_with_score + + +def _validate_and_get_text_column(text_column: Optional[str], index_details: IndexDetails) -> str: + if index_details.is_databricks_managed_embeddings(): + index_source_column: str = index_details.embedding_source_column["name"] + # check if input text column matches the source column of the index + if text_column is not None: + raise ValueError( + f"The index '{index_details.name}' has the source column configured as " + f"'{index_source_column}'. Do not pass the `text_column` parameter." + ) + return index_source_column + else: + if text_column is None: + raise ValueError("The `text_column` parameter is required for this index.") + return text_column + + +def _validate_and_get_return_columns( + columns: List[str], text_column: str, index_details: IndexDetails +) -> List[str]: + """ + Get a list of columns to retrieve from the index. + + If the index is direct-access index, validate the given columns against the schema. 
+ """ + # add primary key column and source column if not in columns + if index_details.primary_key not in columns: + columns.append(index_details.primary_key) + if text_column and text_column not in columns: + columns.append(text_column) + + # Validate specified columns are in the index + if index_details.is_direct_access_index() and (index_schema := index_details.schema): + if missing_columns := [c for c in columns if c not in index_schema]: + raise ValueError( + "Some columns specified in `columns` are not " + f"in the index schema: {missing_columns}" + ) + return columns + + +def _validate_embedding(embedding: Optional[Embeddings], index_details: IndexDetails) -> None: + if index_details.is_databricks_managed_embeddings(): + if embedding is not None: + raise ValueError( + f"The index '{index_details.name}' uses Databricks-managed embeddings. " + "Do not pass the `embedding` parameter when initializing vector store." + ) + else: + if not embedding: + raise ValueError( + "The `embedding` parameter is required for a direct-access index " + "or delta-sync index with self-managed embedding." + ) + _validate_embedding_dimension(embedding, index_details) + + +def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDetails) -> None: + """validate if the embedding dimension matches with the index's configuration.""" + if index_embedding_dimension := index_details.embedding_vector_column.get( + "embedding_dimension" + ): + # Infer the embedding dimension from the embedding function.""" + actual_dimension = len(embeddings.embed_query("test")) + if actual_dimension != index_embedding_dimension: + raise ValueError( + f"The specified embedding model's dimension '{actual_dimension}' does " + f"not match with the index configuration '{index_embedding_dimension}'." 
class IndexDetails:
    """Read-only accessors over the description of a vector search index."""

    def __init__(self, index: Any):
        # `describe()` is called once; every accessor reads this snapshot.
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        """The index name as reported by the description."""
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        """Parsed ``schema_json`` of a direct-access index, otherwise ``None``."""
        if not self.is_direct_access_index():
            return None
        raw_schema = self.index_spec.get("schema_json")
        return None if raw_schema is None else json.loads(raw_schema)

    @property
    def primary_key(self) -> str:
        """Name of the index's primary-key column."""
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        """The spec section matching the index type (empty dict if absent)."""
        if self.is_delta_sync_index():
            spec_key = "delta_sync_index_spec"
        else:
            spec_key = "direct_access_index_spec"
        return self._index_details.get(spec_key, {})

    @property
    def embedding_vector_column(self) -> Dict:
        """First entry of ``embedding_vector_columns``, or an empty dict."""
        vector_columns = self.index_spec.get("embedding_vector_columns")
        return vector_columns[0] if vector_columns else {}

    @property
    def embedding_source_column(self) -> Dict:
        """First entry of ``embedding_source_columns``, or an empty dict."""
        source_columns = self.index_spec.get("embedding_source_columns")
        return source_columns[0] if source_columns else {}

    def is_delta_sync_index(self) -> bool:
        """Whether ``index_type`` is ``DELTA_SYNC``."""
        return self._index_details["index_type"] == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        """Whether ``index_type`` is ``DIRECT_ACCESS``."""
        return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        """Whether this is a delta-sync index that names an embedding source column."""
        if not self.is_delta_sync_index():
            return False
        return self.embedding_source_column.get("name") is not None
# Canned non-streaming payload (object type "chat.completion") with a single
# assistant choice and a usage section; tests compare model output against it.
_MOCK_CHAT_RESPONSE = {
    "id": "chatcmpl_id",
    "object": "chat.completion",
    "created": 1721875529,
    "model": "meta-llama-3.1-70b-instruct-072424",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "To calculate the result of 36939 multiplied by 8922.4, "
                "I get:\n\n36939 x 8922.4 = 329,511,111.6",
            },
            "finish_reason": "stop",
            "logprobs": None,
        }
    ],
    "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
}
@pytest.fixture(autouse=True)
def mock_client() -> Generator:
    """Patch ``mlflow.deployments.get_deploy_client`` for every test in this module.

    The stand-in client answers ``predict`` with ``_MOCK_CHAT_RESPONSE`` and
    ``predict_stream`` with ``_MOCK_STREAM_RESPONSE``.
    """
    fake_deploy_client = mock.MagicMock(
        **{
            "predict.return_value": _MOCK_CHAT_RESPONSE,
            "predict_stream.return_value": _MOCK_STREAM_RESPONSE,
        }
    )
    patcher = mock.patch(
        "mlflow.deployments.get_deploy_client",
        return_value=fake_deploy_client,
    )
    with patcher:
        yield
def test_chat_model_stream(llm: ChatDatabricks) -> None:
    """Streaming yields chunks whose content matches the mocked deltas, in order."""
    chunks = list(
        llm.stream(
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "36939 * 8922.4"},
            ]
        )
    )
    # zip() stops at the shorter input, so without this check a stream that
    # yields too few chunks (or none at all) would pass without comparing anything.
    assert len(chunks) >= len(_MOCK_STREAM_RESPONSE)
    for chunk, expected in zip(chunks, _MOCK_STREAM_RESPONSE):
        assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]
def test_chat_model_bind_tolls_with_invalid_choices(llm: ChatDatabricks) -> None:
    """``bind_tools`` raises ``ValueError`` for ``tool_choice`` values it cannot use."""
    # NOTE(review): "tolls" in the function name looks like a typo for "tools";
    # the name is kept unchanged here so existing test node IDs stay valid.
    # A tool_choice that is neither a string nor a dict
    with pytest.raises(ValueError, match="Unrecognized tool_choice type"):
        llm.bind_tools([GetWeather], tool_choice=123)

    # Non-existing tool
    with pytest.raises(ValueError, match="Tool choice"):
        llm.bind_tools(
            [GetWeather],
            tool_choice={"type": "function", "function": {"name": "NonExistingTool"}},
        )
@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None])
@pytest.mark.parametrize("method", ["function_calling", "json_mode"])
def test_chat_model_with_structured_output(llm, schema, method: str):
    """``with_structured_output`` binds the right kwargs for each method.

    Function calling must force the ``AnswerWithJustification`` tool, JSON mode
    must request a JSON object, and ``include_raw=True`` must start the chain
    with a ``RunnableMap``.
    """
    uses_function_calling = method == "function_calling"
    if uses_function_calling and schema is None:
        pytest.skip("Cannot use function_calling without schema")

    bound_kwargs = llm.with_structured_output(schema, method=method).first.kwargs
    if uses_function_calling:
        forced_tool = bound_kwargs["tool_choice"]["function"]["name"]
        assert forced_tool == "AnswerWithJustification"
    else:
        assert bound_kwargs["response_format"] == {"type": "json_object"}

    chain_with_raw = llm.with_structured_output(schema, include_raw=True, method=method)
    assert isinstance(chain_with_raw.first, RunnableMap)
def test_convert_message_with_tool_calls() -> None:
    """An assistant dict with tool calls round-trips through both converters.

    The ``id`` is kept on the resulting ``AIMessage`` but is expected to be
    absent from the dict produced when converting back.
    """
    call_id = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
    function_spec = {
        "name": "main__test__python_exec",
        "arguments": '{"code": "result = 36939 * 8922.4"}',
    }
    tool_calls = [{"id": call_id, "type": "function", "function": function_spec}]
    message_with_tools = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        "id": call_id,
    }

    converted = _convert_dict_to_message(message_with_tools)

    parsed_call = {
        "name": function_spec["name"],
        "args": json.loads(function_spec["arguments"]),
        "id": call_id,
        "type": "tool_call",
    }
    assert converted == AIMessage(
        content="",
        additional_kwargs={"tool_calls": tool_calls},
        id=call_id,
        tool_calls=[parsed_call],
    )

    # convert back
    round_tripped = _convert_message_to_dict(converted)
    # id is not propagated
    without_id = {key: value for key, value in message_with_tools.items() if key != "id"}
    assert round_tripped == without_id
_convert_dict_to_message_chunk(delta_with_tools, "role") + expected_output = AIMessageChunk( + content="", + additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]}, + id=None, + tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)], + ) + assert result == expected_output + + +def test_convert_tool_message_chunk() -> None: + delta = { + "role": "tool", + "content": "foo", + "tool_call_id": "tool_call_id", + "id": "some_id", + } + result = _convert_dict_to_message_chunk(delta, "default_role") + expected_output = ToolMessageChunk(content="foo", id="some_id", tool_call_id="tool_call_id") + assert result == expected_output + + # convert back + dict_result = _convert_message_to_dict(result) + delta.pop("id") # id is not propagated + assert dict_result == delta + + +def test_convert_message_to_dict_function() -> None: + with pytest.raises(ValueError, match="Function messages are not supported"): + _convert_message_to_dict(FunctionMessage(content="", name="name")) diff --git a/integrations/langchain/tests/test_embeddings.py b/integrations/langchain/tests/test_embeddings.py new file mode 100644 index 0000000..5b9abc3 --- /dev/null +++ b/integrations/langchain/tests/test_embeddings.py @@ -0,0 +1,67 @@ +"""Test Together AI embeddings.""" + +from typing import Any, Dict, Generator +from unittest import mock + +import pytest +from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped] + +from databricks_langchain import DatabricksEmbeddings + + +def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]: + return { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": list(range(1536)), + "index": 0, + } + for _ in inputs["input"] + ], + "model": "text-embedding-3-small", + "usage": {"prompt_tokens": 8, "total_tokens": 8}, + } + + +@pytest.fixture +def mock_client() -> Generator: + client = mock.MagicMock() + client.predict.side_effect = _mock_embeddings + with 
def test_embed_documents(
    mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
    """Embedding 30 documents takes two ``predict`` calls, each carrying ``documents_params``."""
    documents = ["foo"] * 30
    output = embeddings.embed_documents(documents)
    assert len(output) == 30
    assert len(output[0]) == 1536
    assert mock_client.predict.call_count == 2
    # Inspect the calls recorded on the fixture's client itself. Calling
    # ``mock_client()`` would return a different, never-used mock whose call
    # list is empty, which makes ``all(...)`` trivially true.
    assert all(
        call_arg[1]["inputs"]["fruit"] == "apple"
        for call_arg in mock_client.predict.call_args_list
    )
### Dummy similarity_search() Response ###
# One row per entry of INPUT_TEXTS: [random id, text, fake embedding, score 0.5],
# ordered by embedding in descending order.
EXAMPLE_SEARCH_RESPONSE = {
    "manifest": {
        # Matches the four entries listed under "columns".
        "column_count": 4,
        "columns": [
            {"name": "id"},
            {"name": "text"},
            {"name": "text_vector"},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(INPUT_TEXTS),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, 0.5]
                for s, e in zip(INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS))
            ],
            key=lambda x: x[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}
# ALL_INDEX_NAMES is a set; sorting it gives a stable parametrization order
# across processes (set order of strings depends on hash randomization).
@pytest.mark.parametrize("index_name", sorted(ALL_INDEX_NAMES))
def test_init(index_name: str) -> None:
    """The store can be built for every index type and exposes the described index."""
    vectorsearch = init_vector_search(index_name)
    assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]
# The set difference is sorted so the parametrization order is stable across
# processes (set order of strings depends on hash randomization).
@pytest.mark.parametrize("index_name", sorted(ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}))
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
    """A model whose vector size differs from the index's dimension is rejected."""
    with pytest.raises(ValueError, match="embedding model's dimension '1000' does not match"):
        DatabricksVectorSearch(
            index_name=index_name,
            text_column="text",
            embedding=FakeEmbeddings(1000),
        )
def is_valid_uuid(val: str) -> bool:
    """Tell whether ``str(val)`` is accepted by ``uuid.UUID``."""
    candidate = str(val)
    try:
        uuid.UUID(candidate)
    except ValueError:
        return False
    else:
        return True
def test_add_texts_with_metadata() -> None:
    """``add_texts`` merges each metadata dict into the row it upserts."""
    store = init_vector_search(DIRECT_ACCESS_INDEX)
    expected_vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
    metadatas = [{"feat1": str(n), "feat2": n + 1000} for n, _ in enumerate(INPUT_TEXTS)]

    added_ids = store.add_texts(INPUT_TEXTS, metadatas=metadatas)

    expected_rows = []
    for text, vector, id_, metadata in zip(INPUT_TEXTS, expected_vectors, added_ids, metadatas):
        row = {"id": id_, "text": text, "text_vector": vector}
        row.update(metadata)
        expected_rows.append(row)
    store.index.upsert.assert_called_once_with(expected_rows)

    assert len(added_ids) == len(INPUT_TEXTS)
    # No ids were passed in, so every returned id should be a generated UUID.
    assert all(is_valid_uuid(id_) for id_ in added_ids)
# ALL_INDEX_NAMES is a set; sorting it gives a stable parametrization order
# across processes (set order of strings depends on hash randomization).
@pytest.mark.parametrize("index_name", sorted(ALL_INDEX_NAMES))
def test_similarity_search_hybrid(index_name: str) -> None:
    """HYBRID search always forwards the query text; the vector only when self-embedded."""
    vectorsearch = init_vector_search(index_name)
    query = "foo"
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search(
        query, k=limit, filter=filters, query_type="HYBRID"
    )

    # The managed-embeddings index is created without a model, so no query
    # vector is expected for it; the other indexes embed the query themselves.
    expected_vector = (
        None if index_name == DELTA_SYNC_INDEX else EMBEDDING_MODEL.embed_query(query)
    )
    vectorsearch.index.similarity_search.assert_called_once_with(
        columns=["id", "text"],
        query_text=query,
        query_vector=expected_vector,
        filters=filters,
        num_results=limit,
        query_type="HYBRID",
    )
    assert len(search_result) == len(INPUT_TEXTS)
    assert sorted(d.page_content for d in search_result) == sorted(INPUT_TEXTS)
    assert all("id" in d.metadata for d in search_result)
# The set difference is sorted so the parametrization order is stable across
# processes (set order of strings depends on hash randomization).
@pytest.mark.parametrize("index_name", sorted(ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}))
@pytest.mark.parametrize(
    "columns, expected_columns",
    [
        (None, {"id"}),
        (["id", "text", "text_vector"], {"text_vector", "id"}),
    ],
)
def test_mmr_search(
    index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    """MMR search with ``k=1`` returns the first input text with the expected metadata keys."""
    vectorsearch = init_vector_search(index_name, columns=columns)

    query = INPUT_TEXTS[0]
    filters = {"some filter": True}
    limit = 1

    search_result = vectorsearch.max_marginal_relevance_search(query, k=limit, filters=filters)
    assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
    assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
def test_standard_params() -> None:
    """Retriever tracing params name the embedding provider only when a model is set."""
    common_params = {
        "ls_retriever_name": "vectorstore",
        "ls_vector_store_provider": "DatabricksVectorSearch",
    }
    cases = [
        # Built with FakeEmbeddings by init_vector_search.
        (DIRECT_ACCESS_INDEX, {**common_params, "ls_embedding_provider": "FakeEmbeddings"}),
        # Built without an embedding model by init_vector_search.
        (DELTA_SYNC_INDEX, dict(common_params)),
    ]
    for index_name, expected_params in cases:
        retriever = init_vector_search(index_name).as_retriever()
        assert retriever._get_ls_params() == expected_params
# ALL_INDEX_NAMES is a set; sorting it gives a stable parametrization order
# across processes (set order of strings depends on hash randomization).
@pytest.mark.parametrize("index_name", sorted(ALL_INDEX_NAMES))
def test_similarity_search_empty_result(index_name: str) -> None:
    """A response with an empty ``data_array`` yields no documents."""
    vectorsearch = init_vector_search(index_name)
    vectorsearch.index.similarity_search.return_value = {
        "manifest": {
            "column_count": 3,
            "columns": [
                {"name": "id"},
                {"name": "text"},
                {"name": "score"},
            ],
        },
        "result": {
            "row_count": 0,
            "data_array": [],
        },
        "next_page_token": "",
    }

    search_result = vectorsearch.similarity_search("foo")
    assert len(search_result) == 0
a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index f7a2caf..0156a50 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -70,7 +70,10 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client): def test_poll_for_result_executing_query(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ - {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, + { + "status": "EXECUTING_QUERY", + "attachments": [{"query": {"query": "SELECT *"}}], + }, { "statement_response": { "status": {"state": "SUCCEEDED"}, @@ -116,7 +119,10 @@ def test_poll_for_result_max_iterations(genie, mock_workspace_client): patch("time.sleep", return_value=None), ): mock_workspace_client.genie._api.do.side_effect = [ - {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]}, + { + "status": "EXECUTING_QUERY", + "attachments": [{"query": {"query": "SELECT *"}}], + }, { "statement_response": { "status": {"state": "RUNNING"}, @@ -165,8 +171,20 @@ def test_parse_query_result_with_data(): }, "result": { "data_typed_array": [ - {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]}, - {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]}, + { + "values": [ + {"str": "1"}, + {"str": "Alice"}, + {"str": "2023-10-01T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "2"}, + {"str": "Bob"}, + {"str": "2023-10-02T00:00:00Z"}, + ] + }, ] }, } @@ -194,7 +212,13 @@ def test_parse_query_result_with_null_values(): }, "result": { "data_typed_array": [ - {"values": [{"str": "1"}, {"str": None}, {"str": "2023-10-01T00:00:00Z"}]}, + { + "values": [ + {"str": "1"}, + {"str": None}, + {"str": "2023-10-01T00:00:00Z"}, + ] + }, {"values": [{"str": "2"}, {"str": "Bob"}, {"str": None}]}, ] }, @@ -225,16 +249,76 @@ def test_parse_query_result_trims_large_data(): }, "result": { "data_typed_array": [ 
- {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]}, - {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]}, - {"values": [{"str": "3"}, {"str": "Charlie"}, {"str": "2023-10-03T00:00:00Z"}]}, - {"values": [{"str": "4"}, {"str": "David"}, {"str": "2023-10-04T00:00:00Z"}]}, - {"values": [{"str": "5"}, {"str": "Eve"}, {"str": "2023-10-05T00:00:00Z"}]}, - {"values": [{"str": "6"}, {"str": "Frank"}, {"str": "2023-10-06T00:00:00Z"}]}, - {"values": [{"str": "7"}, {"str": "Grace"}, {"str": "2023-10-07T00:00:00Z"}]}, - {"values": [{"str": "8"}, {"str": "Hank"}, {"str": "2023-10-08T00:00:00Z"}]}, - {"values": [{"str": "9"}, {"str": "Ivy"}, {"str": "2023-10-09T00:00:00Z"}]}, - {"values": [{"str": "10"}, {"str": "Jack"}, {"str": "2023-10-10T00:00:00Z"}]}, + { + "values": [ + {"str": "1"}, + {"str": "Alice"}, + {"str": "2023-10-01T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "2"}, + {"str": "Bob"}, + {"str": "2023-10-02T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "3"}, + {"str": "Charlie"}, + {"str": "2023-10-03T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "4"}, + {"str": "David"}, + {"str": "2023-10-04T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "5"}, + {"str": "Eve"}, + {"str": "2023-10-05T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "6"}, + {"str": "Frank"}, + {"str": "2023-10-06T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "7"}, + {"str": "Grace"}, + {"str": "2023-10-07T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "8"}, + {"str": "Hank"}, + {"str": "2023-10-08T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "9"}, + {"str": "Ivy"}, + {"str": "2023-10-09T00:00:00Z"}, + ] + }, + { + "values": [ + {"str": "10"}, + {"str": "Jack"}, + {"str": "2023-10-10T00:00:00Z"}, + ] + }, ] }, }