Skip to content

Commit

Permalink
core[patch]: fix ToolCall "type" when streaming (#24218)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Jul 13, 2024
1 parent 2b7d1cd commit 65321bf
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 95 deletions.
3 changes: 2 additions & 1 deletion libs/community/langchain_community/chat_models/deepinfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ToolMessage,
)
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
Expand Down Expand Up @@ -96,7 +97,7 @@ def _parse_tool_calling(tool_call: dict) -> ToolCall:
name = tool_call["function"].get("name", "")
args = json.loads(tool_call["function"]["arguments"])
id = tool_call.get("id")
return ToolCall(name=name, args=args, id=id)
return create_tool_call(name=name, args=args, id=id)


def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/edenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
Expand All @@ -63,7 +65,7 @@ def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationCh
message = generated_result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
Expand Down Expand Up @@ -189,15 +191,15 @@ def _extract_tool_calls_from_edenai_response(
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
ToolCall(
create_tool_call(
name=raw_tool_call["name"],
args=json.loads(raw_tool_call["arguments"]),
id=raw_tool_call["id"],
)
)
except json.JSONDecodeError as exc:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=raw_tool_call.get("name"),
args=raw_tool_call.get("arguments"),
id=raw_tool_call.get("id"),
Expand Down
72 changes: 55 additions & 17 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.messages.tool import (
invalid_tool_call as create_invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)
from langchain_core.utils.json import parse_partial_json


class UsageMetadata(TypedDict):
Expand Down Expand Up @@ -106,24 +113,55 @@ def lc_attributes(self) -> Dict:

@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
check_additional_kwargs = not any(
values.get(k)
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
if check_additional_kwargs and (
raw_tool_calls := values.get("additional_kwargs", {}).get("tool_calls")
):
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
raw_tool_calls
)
values["tool_calls"] = parsed_tool_calls
values["invalid_tool_calls"] = parsed_invalid_tool_calls
except Exception:
pass

# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: List = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
)
values["tool_calls"] = updated
if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = []
for tc in invalid_tool_calls:
updated.append(
create_invalid_tool_call(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated

if tool_call_chunks := values.get("tool_call_chunks"):
updated = []
for tc in tool_call_chunks:
updated.append(
create_tool_call_chunk(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated

return values

def pretty_repr(self, html: bool = False) -> str:
Expand Down Expand Up @@ -216,7 +254,7 @@ def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
Expand All @@ -228,7 +266,7 @@ def init_tool_calls(cls, values: dict) -> dict:
tool_call_chunks = values.get("tool_call_chunks", [])
tool_call_chunks.extend(
[
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in values["invalid_tool_calls"]
Expand All @@ -244,7 +282,7 @@ def init_tool_calls(cls, values: dict) -> dict:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
create_tool_call(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
Expand All @@ -254,7 +292,7 @@ def init_tool_calls(cls, values: dict) -> dict:
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
Expand Down Expand Up @@ -297,7 +335,7 @@ def add_ai_message_chunks(
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
):
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
Expand Down
20 changes: 10 additions & 10 deletions libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,25 @@ def default_tool_parser(
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
for raw_tool_call in raw_tool_calls:
if "function" not in raw_tool_call:
continue
else:
function_name = tool_call["function"]["name"]
function_name = raw_tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
function_args = json.loads(raw_tool_call["function"]["arguments"])
parsed = tool_call(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
id=raw_tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
invalid_tool_call(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
error=None,
)
)
Expand All @@ -272,7 +272,7 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
parsed = tool_call_chunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
Expand Down
8 changes: 4 additions & 4 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,12 @@ def merge_message_runs(
HumanMessage("wait your favorite food", id="bar",),
AIMessage(
"my favorite colo",
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123")],
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123", type="tool_call")],
id="baz",
),
AIMessage(
[{"type": "text", "text": "my favorite dish is lasagna"}],
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456")],
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456", type="tool_call")],
id="blur",
),
]
Expand All @@ -474,8 +474,8 @@ def merge_message_runs(
{"type": "text", "text": "my favorite dish is lasagna"}
],
tool_calls=[
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123"),
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456")
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123", "type": "tool_call"}),
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456", "type": "tool_call"})
]
id="baz"
),
Expand Down
57 changes: 41 additions & 16 deletions libs/core/tests/unit_tests/messages/test_ai.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk


def test_serdes_message() -> None:
msg = AIMessage(
content=[{"text": "blah", "type": "text"}],
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
tool_calls=[create_tool_call(name="foo", args={"bar": 1}, id="baz")],
invalid_tool_calls=[
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
create_invalid_tool_call(name="foobad", args="blah", id="booz", error="bad")
],
)
expected = {
Expand All @@ -23,9 +20,17 @@ def test_serdes_message() -> None:
"kwargs": {
"type": "ai",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"tool_calls": [
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
],
"invalid_tool_calls": [
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "bad",
"type": "invalid_tool_call",
}
],
},
}
Expand All @@ -38,8 +43,13 @@ def test_serdes_message_chunk() -> None:
chunk = AIMessageChunk(
content=[{"text": "blah", "type": "text"}],
tool_call_chunks=[
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
create_tool_call_chunk(name="foo", args='{"bar": 1}', id="baz", index=0),
create_tool_call_chunk(
name="foobad",
args="blah",
id="booz",
index=1,
),
],
)
expected = {
Expand All @@ -49,18 +59,33 @@ def test_serdes_message_chunk() -> None:
"kwargs": {
"type": "AIMessageChunk",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"tool_calls": [
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
],
"invalid_tool_calls": [
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": None,
"type": "invalid_tool_call",
}
],
"tool_call_chunks": [
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
{
"name": "foo",
"args": '{"bar": 1}',
"id": "baz",
"index": 0,
"type": "tool_call_chunk",
},
{
"name": "foobad",
"args": "blah",
"id": "booz",
"index": 1,
"type": "tool_call_chunk",
},
],
},
}
Expand Down
12 changes: 8 additions & 4 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@ def test_merge_message_runs_content() -> None:
{"text": "bar", "type": "text"},
{"image_url": "...", "type": "image_url"},
],
tool_calls=[ToolCall(name="foo_tool", args={"x": 1}, id="tool1")],
tool_calls=[
ToolCall(name="foo_tool", args={"x": 1}, id="tool1", type="tool_call")
],
id="2",
),
AIMessage(
"baz",
tool_calls=[ToolCall(name="foo_tool", args={"x": 5}, id="tool2")],
tool_calls=[
ToolCall(name="foo_tool", args={"x": 5}, id="tool2", type="tool_call")
],
id="3",
),
]
Expand All @@ -54,8 +58,8 @@ def test_merge_message_runs_content() -> None:
"baz",
],
tool_calls=[
ToolCall(name="foo_tool", args={"x": 1}, id="tool1"),
ToolCall(name="foo_tool", args={"x": 5}, id="tool2"),
ToolCall(name="foo_tool", args={"x": 1}, id="tool1", type="tool_call"),
ToolCall(name="foo_tool", args={"x": 5}, id="tool2", type="tool_call"),
],
id="1",
),
Expand Down
Loading

0 comments on commit 65321bf

Please sign in to comment.