Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 'any' tool_choice handling in bind_tools #27

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion libs/databricks/langchain_databricks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,16 @@ def bind_tools(
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "required"):
if tool_choice not in ("auto", "none", "required", "any"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
# 'any' is not natively supported by OpenAI API,
# but supported by other models in Langchain.
# Ref: https://github.com/langchain-ai/langchain/blob/202d7f6c4a2ca8c7e5949d935bcf0ba9b0c23fb0/libs/partners/openai/langchain_openai/chat_models/base.py#L1098C1-L1101C45
if tool_choice == "any":
tool_choice = "required"
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
Expand Down
9 changes: 7 additions & 2 deletions libs/databricks/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ async def test_chat_databricks_abatch():
assert all(isinstance(response, AIMessage) for response in responses)


def test_chat_databricks_tool_calls():
@pytest.mark.parametrize("tool_choice", [None, "auto", "required", "any", "none"])
def test_chat_databricks_tool_calls(tool_choice):
from pydantic import BaseModel, Field

chat = ChatDatabricks(
Expand All @@ -178,10 +179,14 @@ class GetWeather(BaseModel):
..., description="The city and state, e.g. San Francisco, CA"
)

llm_with_tools = chat.bind_tools([GetWeather])
llm_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice)
question = "Which is the current weather in Los Angeles, CA?"
response = llm_with_tools.invoke(question)

if tool_choice == "none":
assert response.tool_calls == []
return

assert response.tool_calls == [
{
"name": "GetWeather",
Expand Down
54 changes: 43 additions & 11 deletions libs/databricks/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,28 +187,60 @@ def test_chat_model_stream(llm: ChatDatabricks) -> None:
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]


def test_chat_model_bind_tools(llm: ChatDatabricks) -> None:
class GetWeather(BaseModel):
"""Get the current weather in a given location"""
class GetWeather(BaseModel):
"""Get the current weather in a given location"""

location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

class GetPopulation(BaseModel):
"""Get the current population in a given location"""

location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
"""Get the current population in a given location"""

location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


def test_chat_model_bind_tools(llm: ChatDatabricks) -> None:
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
response = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
assert isinstance(response, AIMessage)


@pytest.mark.parametrize(
("tool_choice", "expected_output"),
[
("auto", "auto"),
("none", "none"),
("required", "required"),
# "any" should be replaced with "required"
("any", "required"),
("GetWeather", {"type": "function", "function": {"name": "GetWeather"}}),
(
{"type": "function", "function": {"name": "GetWeather"}},
{"type": "function", "function": {"name": "GetWeather"}},
),
],
)
def test_chat_model_bind_tools_with_choices(
llm: ChatDatabricks, tool_choice, expected_output
) -> None:
llm_with_tool = llm.bind_tools([GetWeather], tool_choice=tool_choice)
assert llm_with_tool.kwargs["tool_choice"] == expected_output


def test_chat_model_bind_tolls_with_invalid_choices(llm: ChatDatabricks) -> None:
with pytest.raises(ValueError, match="Unrecognized tool_choice type"):
llm.bind_tools([GetWeather], tool_choice=123)

# Non-existing tool
with pytest.raises(ValueError, match="Tool choice"):
llm.bind_tools(
[GetWeather],
tool_choice={"type": "function", "function": {"name": "NonExistingTool"}},
)


### Test data conversion functions ###


Expand Down
Loading