together, standard-tests: specify tool_choice in standard tests #25548

Merged: 9 commits, Aug 19, 2024
37 changes: 37 additions & 0 deletions libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -14,6 +14,7 @@
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool

from langchain_groq import ChatGroq
from tests.unit_tests.fake.callbacks import (
@@ -393,6 +394,42 @@ class Joke(BaseModel):
assert len(result.punchline) != 0


def test_tool_calling_no_arguments() -> None:
# Note: this is a variant of a test in langchain_standard_tests
# that as of 2024-08-19 fails with "Failed to call a function. Please
# adjust your prompt." when `tool_choice="any"` is specified, but
# passes when `tool_choice` is not specified.
model = ChatGroq(model="llama-3.1-70b-versatile", temperature=0) # type: ignore[call-arg]

@tool
def magic_function_no_args() -> int:
"""Calculates a magic function."""
return 5

model_with_tools = model.bind_tools([magic_function_no_args])
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call["name"] == "magic_function_no_args"
assert tool_call["args"] == {}
assert tool_call["id"] is not None
assert tool_call["type"] == "tool_call"

# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
assert len(full.tool_calls) == 1
tool_call = full.tool_calls[0]
assert tool_call["name"] == "magic_function_no_args"
assert tool_call["args"] == {}
assert tool_call["id"] is not None
assert tool_call["type"] == "tool_call"


# Groq does not currently support N > 1
# @pytest.mark.scheduled
# def test_chat_multiple_completions() -> None:
15 changes: 13 additions & 2 deletions libs/partners/groq/tests/integration_tests/test_standard.py
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
@@ -28,11 +28,22 @@ class TestGroqLlama(BaseTestGroq):
@property
def chat_model_params(self) -> dict:
return {
"model": "llama-3.1-70b-versatile",
"model": "llama-3.1-8b-instant",
Collaborator: Out of curiosity, why go smaller instead of bigger?

Collaborator (author): We can do 70b (I changed it from 8b to 70b this morning). My thought is that if they support the same features, we should prefer smaller models for tests, in the spirit of testing functionality rather than benchmarking.

"temperature": 0,
"rate_limiter": rate_limiter,
}

@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"

@pytest.mark.xfail(
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)

@pytest.mark.xfail(
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
)
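For context on what tool_choice_value = "any" means downstream: the standard tests now forward this value to bind_tools, asking the provider to require a tool call rather than a free-form answer. A minimal sketch of the equivalent direct call, with a model name, tool, and prompt that are illustrative rather than taken from this diff:

from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


model = ChatGroq(model="llama-3.1-8b-instant", temperature=0)  # illustrative model
# tool_choice="any" asks the provider to force some tool call in the response.
model_with_tools = model.bind_tools([magic_function], tool_choice="any")
result = model_with_tools.invoke("What is the value of magic_function(3)? Use the tool.")
assert isinstance(result, AIMessage)
assert result.tool_calls and result.tool_calls[0]["name"] == "magic_function"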
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Optional, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
@@ -18,3 +18,8 @@ def chat_model_class(self) -> Type[BaseChatModel]:
@property
def chat_model_params(self) -> dict:
return {"model": "mistral-large-latest", "temperature": 0}

@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
@@ -28,9 +28,10 @@ def chat_model_params(self) -> dict:
"rate_limiter": rate_limiter,
}

@pytest.mark.xfail(reason=("May not call a tool."))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "tool_name"

@pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
@@ -170,7 +170,11 @@ def test_stop_sequence(self, model: BaseChatModel) -> None:
def test_tool_calling(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function])
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([magic_function], tool_choice=tool_choice)

# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
@@ -188,7 +192,13 @@ def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

model_with_tools = model.bind_tools([magic_function_no_args])
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "magic_function_no_args"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools(
[magic_function_no_args], tool_choice=tool_choice
)
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
_validate_tool_call_message_no_args(result)
@@ -212,7 +222,11 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
name="greeting_generator",
description="Generate a greeting in a particular style of speaking.",
)
model_with_tools = model.bind_tools([tool_])
if self.tool_choice_value == "tool_name":
tool_choice: Optional[str] = "greeting_generator"
else:
tool_choice = self.tool_choice_value
model_with_tools = model.bind_tools([tool_], tool_choice=tool_choice)
query = "Using the tool, generate a Pirate greeting."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
@@ -96,6 +96,11 @@ def model(self) -> BaseChatModel:
def has_tool_calling(self) -> bool:
return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools

@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return None

@property
def has_structured_output(self) -> bool:
return (
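A partner package opts into forced tool choice by overriding this property in its standard-test subclass, as the partner diffs above do. A minimal sketch of that pattern, assuming the standard-tests base class is ChatModelIntegrationTests and using a hypothetical provider class and model name that are not part of this PR:

from typing import Optional, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from my_partner_package import ChatMyProvider  # hypothetical integration package


class TestMyProviderStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatMyProvider

    @property
    def chat_model_params(self) -> dict:
        # Placeholder model name; real integrations pin a specific model.
        return {"model": "my-provider-model", "temperature": 0}

    @property
    def tool_choice_value(self) -> Optional[str]:
        # "any" forces the model to call some bound tool; "tool_name" makes the
        # standard tests pass the specific tool's name; None (the default) binds
        # tools without a tool_choice.
        return "any"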