together, standard-tests: specify tool_choice in standard tests (#25548)
Here we allow standard tests to specify a value for `tool_choice` via a
`tool_choice_value` property, which defaults to None.

Chat models [available in
Together](https://docs.together.ai/docs/chat-models) have issues passing
standard tool calling tests:
- llama 3.1 models currently [appear to rely on user-side
parsing](https://docs.together.ai/docs/llama-3-function-calling) in
Together;
- the Mixtral-8x7B and Mistral-7B models (currently tested) consistently
fail to call tools in some tests.

Specifying tool_choice also lets us remove an existing `xfail` and use a
smaller model in Groq tests.
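
For illustration, a minimal sketch of what the new hook amounts to at the
model level: the suite simply forwards the configured value to `bind_tools`.
(The model name and tool below are examples, not part of this change.)

```python
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


model = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
# tool_choice="any" asks the provider to force a tool call instead of
# leaving the decision to the model; this is what the standard tests now
# pass through when `tool_choice_value` is set.
model_with_tools = model.bind_tools([magic_function], tool_choice="any")
result = model_with_tools.invoke("What is the value of magic_function(3)? Use the tool.")
assert isinstance(result, AIMessage)
assert len(result.tool_calls) == 1
```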
ccurme authored Aug 19, 2024
1 parent 015ab91 commit c5bf114
Showing 6 changed files with 83 additions and 10 deletions.
37 changes: 37 additions & 0 deletions libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -14,6 +14,7 @@
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool

from langchain_groq import ChatGroq
from tests.unit_tests.fake.callbacks import (
@@ -393,6 +394,42 @@ class Joke(BaseModel):
    assert len(result.punchline) != 0


def test_tool_calling_no_arguments() -> None:
    # Note: this is a variant of a test in langchain_standard_tests
    # that as of 2024-08-19 fails with "Failed to call a function. Please
    # adjust your prompt." when `tool_choice="any"` is specified, but
    # passes when `tool_choice` is not specified.
    model = ChatGroq(model="llama-3.1-70b-versatile", temperature=0)  # type: ignore[call-arg]

    @tool
    def magic_function_no_args() -> int:
        """Calculates a magic function."""
        return 5

    model_with_tools = model.bind_tools([magic_function_no_args])
    query = "What is the value of magic_function()? Use the tool."
    result = model_with_tools.invoke(query)
    assert isinstance(result, AIMessage)
    assert len(result.tool_calls) == 1
    tool_call = result.tool_calls[0]
    assert tool_call["name"] == "magic_function_no_args"
    assert tool_call["args"] == {}
    assert tool_call["id"] is not None
    assert tool_call["type"] == "tool_call"

    # Test streaming
    full: Optional[BaseMessageChunk] = None
    for chunk in model_with_tools.stream(query):
        full = chunk if full is None else full + chunk  # type: ignore
    assert isinstance(full, AIMessage)
    assert len(full.tool_calls) == 1
    tool_call = full.tool_calls[0]
    assert tool_call["name"] == "magic_function_no_args"
    assert tool_call["args"] == {}
    assert tool_call["id"] is not None
    assert tool_call["type"] == "tool_call"


# Groq does not currently support N > 1
# @pytest.mark.scheduled
# def test_chat_multiple_completions() -> None:
15 changes: 13 additions & 2 deletions libs/partners/groq/tests/integration_tests/test_standard.py
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

-from typing import Type
+from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
@@ -28,11 +28,22 @@ class TestGroqLlama(BaseTestGroq):
    @property
    def chat_model_params(self) -> dict:
        return {
-            "model": "llama-3.1-70b-versatile",
+            "model": "llama-3.1-8b-instant",
            "temperature": 0,
            "rate_limiter": rate_limiter,
        }

    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice when used in tests."""
        return "any"

    @pytest.mark.xfail(
        reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
    )
    def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
        super().test_tool_calling_with_no_arguments(model)

    @pytest.mark.xfail(
        reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
    )
libs/partners/mistralai/tests/integration_tests/test_standard.py
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

-from typing import Type
+from typing import Optional, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
@@ -18,3 +18,8 @@ def chat_model_class(self) -> Type[BaseChatModel]:
    @property
    def chat_model_params(self) -> dict:
        return {"model": "mistral-large-latest", "temperature": 0}

    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice when used in tests."""
        return "any"
libs/partners/together/tests/integration_tests/test_standard.py
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

-from typing import Type
+from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
@@ -28,9 +28,10 @@ def chat_model_params(self) -> dict:
"rate_limiter": rate_limiter,
}

@pytest.mark.xfail(reason=("May not call a tool."))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "tool_name"

@pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py
@@ -170,7 +170,11 @@ def test_stop_sequence(self, model: BaseChatModel) -> None:
    def test_tool_calling(self, model: BaseChatModel) -> None:
        if not self.has_tool_calling:
            pytest.skip("Test requires tool calling.")
-        model_with_tools = model.bind_tools([magic_function])
+        if self.tool_choice_value == "tool_name":
+            tool_choice: Optional[str] = "magic_function"
+        else:
+            tool_choice = self.tool_choice_value
+        model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)

        # Test invoke
        query = "What is the value of magic_function(3)? Use the tool."
@@ -188,7 +192,13 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
        if not self.has_tool_calling:
            pytest.skip("Test requires tool calling.")

-        model_with_tools = model.bind_tools([magic_function_no_args])
+        if self.tool_choice_value == "tool_name":
+            tool_choice: Optional[str] = "magic_function_no_args"
+        else:
+            tool_choice = self.tool_choice_value
+        model_with_tools = model.bind_tools(
+            [magic_function_no_args], tool_choice=tool_choice
+        )
        query = "What is the value of magic_function()? Use the tool."
        result = model_with_tools.invoke(query)
        _validate_tool_call_message_no_args(result)
@@ -212,7 +222,11 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
name="greeting_generator",
description="Generate a greeting in a particular style of speaking.",
)
model_with_tools = model.bind_tools([tool_])
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "greeting_generator"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
query = "Using the tool, generate a Pirate greeting."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
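
A usage note on the `"tool_name"` sentinel above: it is not a literal
tool-choice mode; it tells the suite to pass the name of the specific tool
being tested. A hedged sketch of the equivalent direct call (the
`ChatTogether` usage and model name here are assumptions, chosen to match
this commit's Together motivation):

```python
from langchain_core.tools import tool
from langchain_together import ChatTogether


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


model = ChatTogether(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0)
# Equivalent to the suite's behavior when tool_choice_value == "tool_name":
# the bound tool's own name is passed as tool_choice.
model_with_tools = model.bind_tools([magic_function], tool_choice="magic_function")
```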
libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py
@@ -96,6 +96,11 @@ def model(self) -> BaseChatModel:
    def has_tool_calling(self) -> bool:
        return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools

    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice when used in tests."""
        return None

    @property
    def has_structured_output(self) -> bool:
        return (
