diff --git a/libs/databricks/langchain_databricks/chat_models.py b/libs/databricks/langchain_databricks/chat_models.py index c7dcf74..1c6f609 100644 --- a/libs/databricks/langchain_databricks/chat_models.py +++ b/libs/databricks/langchain_databricks/chat_models.py @@ -367,11 +367,16 @@ def bind_tools( if tool_choice: if isinstance(tool_choice, str): # tool_choice is a tool/function name - if tool_choice not in ("auto", "none", "required"): + if tool_choice not in ("auto", "none", "required", "any"): tool_choice = { "type": "function", "function": {"name": tool_choice}, } + # 'any' is not natively supported by OpenAI API, + # but supported by other models in Langchain. + # Ref: https://github.com/langchain-ai/langchain/blob/202d7f6c4a2ca8c7e5949d935bcf0ba9b0c23fb0/libs/partners/openai/langchain_openai/chat_models/base.py#L1098C1-L1101C45 + if tool_choice == "any": + tool_choice = "required" elif isinstance(tool_choice, dict): tool_names = [ formatted_tool["function"]["name"] diff --git a/libs/databricks/tests/integration_tests/test_chat_models.py b/libs/databricks/tests/integration_tests/test_chat_models.py index f4e05dd..f70ab6e 100644 --- a/libs/databricks/tests/integration_tests/test_chat_models.py +++ b/libs/databricks/tests/integration_tests/test_chat_models.py @@ -162,7 +162,8 @@ async def test_chat_databricks_abatch(): assert all(isinstance(response, AIMessage) for response in responses) -def test_chat_databricks_tool_calls(): +@pytest.mark.parametrize("tool_choice", [None, "auto", "required", "any", "none"]) +def test_chat_databricks_tool_calls(tool_choice): from pydantic import BaseModel, Field chat = ChatDatabricks( @@ -178,10 +179,14 @@ class GetWeather(BaseModel): ..., description="The city and state, e.g. San Francisco, CA" ) - llm_with_tools = chat.bind_tools([GetWeather]) + llm_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice) question = "Which is the current weather in Los Angeles, CA?" 
response = llm_with_tools.invoke(question) + if tool_choice == "none": + assert response.tool_calls == [] + return + assert response.tool_calls == [ { "name": "GetWeather", diff --git a/libs/databricks/tests/unit_tests/test_chat_models.py b/libs/databricks/tests/unit_tests/test_chat_models.py index fad3776..4848f11 100644 --- a/libs/databricks/tests/unit_tests/test_chat_models.py +++ b/libs/databricks/tests/unit_tests/test_chat_models.py @@ -187,21 +187,19 @@ def test_chat_model_stream(llm: ChatDatabricks) -> None: assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index] -def test_chat_model_bind_tools(llm: ChatDatabricks) -> None: - class GetWeather(BaseModel): - """Get the current weather in a given location""" +class GetWeather(BaseModel): + """Get the current weather in a given location""" - location: str = Field( - ..., description="The city and state, e.g. San Francisco, CA" - ) + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") - class GetPopulation(BaseModel): - """Get the current population in a given location""" - location: str = Field( - ..., description="The city and state, e.g. San Francisco, CA" - ) +class GetPopulation(BaseModel): + """Get the current population in a given location""" + + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + +def test_chat_model_bind_tools(llm: ChatDatabricks) -> None: llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) response = llm_with_tools.invoke( "Which city is hotter today and which is bigger: LA or NY?" 
@@ -209,6 +207,40 @@ class GetPopulation(BaseModel): assert isinstance(response, AIMessage) +@pytest.mark.parametrize( + ("tool_choice", "expected_output"), + [ + ("auto", "auto"), + ("none", "none"), + ("required", "required"), + # "any" should be replaced with "required" + ("any", "required"), + ("GetWeather", {"type": "function", "function": {"name": "GetWeather"}}), + ( + {"type": "function", "function": {"name": "GetWeather"}}, + {"type": "function", "function": {"name": "GetWeather"}}, + ), + ], +) +def test_chat_model_bind_tools_with_choices( + llm: ChatDatabricks, tool_choice, expected_output +) -> None: + llm_with_tool = llm.bind_tools([GetWeather], tool_choice=tool_choice) + assert llm_with_tool.kwargs["tool_choice"] == expected_output + + +def test_chat_model_bind_tools_with_invalid_choices(llm: ChatDatabricks) -> None: + with pytest.raises(ValueError, match="Unrecognized tool_choice type"): + llm.bind_tools([GetWeather], tool_choice=123) + + # Non-existing tool + with pytest.raises(ValueError, match="Tool choice"): + llm.bind_tools( + [GetWeather], + tool_choice={"type": "function", "function": {"name": "NonExistingTool"}}, + ) + + ### Test data conversion functions ###