diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index be8814bc3e7ad..2e5a9620b2287 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -14,6 +14,7 @@ ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool from langchain_groq import ChatGroq from tests.unit_tests.fake.callbacks import ( @@ -393,6 +394,42 @@ class Joke(BaseModel): assert len(result.punchline) != 0 +def test_tool_calling_no_arguments() -> None: + # Note: this is a variant of a test in langchain_standard_tests + # that as of 2024-08-19 fails with "Failed to call a function. Please + # adjust your prompt." when `tool_choice="any"` is specified, but + # passes when `tool_choice` is not specified. + model = ChatGroq(model="llama-3.1-70b-versatile", temperature=0) # type: ignore[call-arg] + + @tool + def magic_function_no_args() -> int: + """Calculates a magic function.""" + return 5 + + model_with_tools = model.bind_tools([magic_function_no_args]) + query = "What is the value of magic_function()? Use the tool." + result = model_with_tools.invoke(query) + assert isinstance(result, AIMessage) + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call["name"] == "magic_function_no_args" + assert tool_call["args"] == {} + assert tool_call["id"] is not None + assert tool_call["type"] == "tool_call" + + # Test streaming + full: Optional[BaseMessageChunk] = None + for chunk in model_with_tools.stream(query): + full = chunk if full is None else full + chunk # type: ignore + assert isinstance(full, AIMessage) + assert len(full.tool_calls) == 1 + tool_call = full.tool_calls[0] + assert tool_call["name"] == "magic_function_no_args" + assert tool_call["args"] == {} + assert tool_call["id"] is not None + assert tool_call["type"] == "tool_call" + + # Groq does not currently support N > 1 # @pytest.mark.scheduled # def test_chat_multiple_completions() -> None: diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index 6feab74f60677..38fe554c5c779 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type import pytest from langchain_core.language_models import BaseChatModel @@ -28,11 +28,22 @@ class TestGroqLlama(BaseTestGroq): @property def chat_model_params(self) -> dict: return { - "model": "llama-3.1-70b-versatile", + "model": "llama-3.1-8b-instant", "temperature": 0, "rate_limiter": rate_limiter, } + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "any" + + @pytest.mark.xfail( + reason=("Fails with 'Failed to call a function. Please adjust your prompt.'") + ) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) + @pytest.mark.xfail( reason=("Fails with 'Failed to call a function. Please adjust your prompt.'") ) diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index 965cd03c4b178..cea6399ee4cd8 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found] @@ -18,3 +18,8 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return {"model": "mistral-large-latest", "temperature": 0} + + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "any" diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index 2250873f4b659..18c167f8a91dc 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Optional, Type import pytest from langchain_core.language_models import BaseChatModel @@ -28,9 +28,10 @@ def chat_model_params(self) -> dict: "rate_limiter": rate_limiter, } - @pytest.mark.xfail(reason=("May not call a tool.")) - def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: - super().test_tool_calling_with_no_arguments(model) + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return "tool_name" @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index bcb47a4c151a7..32d922a3e4a36 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -170,7 +170,11 @@ def test_stop_sequence(self, model: BaseChatModel) -> None: def test_tool_calling(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function]) + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "magic_function" + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice) # Test invoke query = "What is the value of magic_function(3)? Use the tool." @@ -188,7 +192,13 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model_with_tools = model.bind_tools([magic_function_no_args]) + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "magic_function_no_args" + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools( + [magic_function_no_args], tool_choice=tool_choice + ) query = "What is the value of magic_function()? Use the tool." result = model_with_tools.invoke(query) _validate_tool_call_message_no_args(result) @@ -212,7 +222,11 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: name="greeting_generator", description="Generate a greeting in a particular style of speaking.", ) - model_with_tools = model.bind_tools([tool_]) + if self.tool_choice_value == "tool_name": + tool_choice: Optional[str] = "greeting_generator" + else: + tool_choice = self.tool_choice_value + model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice) query = "Using the tool, generate a Pirate greeting." result = model_with_tools.invoke(query) assert isinstance(result, AIMessage) diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index ed73771dbdae0..6597b16177be4 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -96,6 +96,11 @@ def model(self) -> BaseChatModel: def has_tool_calling(self) -> bool: return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice when used in tests.""" + return None + @property def has_structured_output(self) -> bool: return (