Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core, anthropic[patch]: support streaming tool calls when function has no arguments #23915

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def init_tool_calls(cls, values: dict) -> dict:
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
Expand Down
6 changes: 6 additions & 0 deletions libs/core/tests/unit_tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def test_message_chunks() -> None:
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk

ai_msg_chunk = AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
)
assert ai_msg_chunk.tool_calls == [ToolCall(name="tool1", args={}, id="1")]

# Test token usage
left = AIMessageChunk(
content="",
Expand Down
6 changes: 6 additions & 0 deletions libs/partners/groq/tests/integration_tests/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def chat_model_params(self) -> dict:
def test_structured_output(self, model: BaseChatModel) -> None:
super().test_structured_output(model)

@pytest.mark.xfail(
reason=("May pass arguments: {'properties': {}, 'type': 'object'}")
)
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)


class TestGroqLlama(BaseTestGroq):
@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
Expand All @@ -18,3 +19,7 @@ def chat_model_class(self) -> Type[BaseChatModel]:
@property
def chat_model_params(self) -> dict:
return {"model": "mistralai/Mistral-7B-Instruct-v0.1"}

@pytest.mark.xfail(reason=("May not call a tool."))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
Expand All @@ -28,7 +29,13 @@ def magic_function(input: int) -> int:
return input + 2


def _validate_tool_call_message(message: AIMessage) -> None:
@tool
def magic_function_no_args() -> int:
"""Calculates a magic function."""
return 5


def _validate_tool_call_message(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
Expand All @@ -37,6 +44,15 @@ def _validate_tool_call_message(message: AIMessage) -> None:
assert tool_call["id"] is not None


def _validate_tool_call_message_no_args(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call["name"] == "magic_function_no_args"
assert tool_call["args"] == {}
assert tool_call["id"] is not None


class ChatModelIntegrationTests(ChatModelTests):
def test_invoke(self, model: BaseChatModel) -> None:
result = model.invoke("Hello")
Expand Down Expand Up @@ -131,7 +147,6 @@ def test_tool_calling(self, model: BaseChatModel) -> None:
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
assert isinstance(result, AIMessage)
_validate_tool_call_message(result)

# Test stream
Expand All @@ -141,6 +156,21 @@ def test_tool_calling(self, model: BaseChatModel) -> None:
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)

def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

model_with_tools = model.bind_tools([magic_function_no_args])
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
_validate_tool_call_message_no_args(result)

full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)

def test_structured_output(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
Expand Down
Loading