From cfd554bffce6a2398bd54100f5dc85bf604402b8 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Fri, 1 Nov 2024 16:43:57 -0700 Subject: [PATCH 1/5] wip --- libs/core/langchain_core/messages/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 0ea03e40d00d4..14120caec5996 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -317,9 +317,14 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: except KeyError: msg_type = msg_kwargs.pop("type") # None msg content is not allowed - msg_content = msg_kwargs.pop("content") or "" + content_or_tool_calls = ( + "tool_calls" in msg_kwargs or "content" in msg_kwargs + ) + if not content_or_tool_calls: + raise KeyError("Must have one of content or tool calls") + msg_content = msg_kwargs.pop("content", "") or "" except KeyError as e: - msg = f"Message dict must contain 'role' and 'content' keys, got {message}" + msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}" msg = create_message( message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE ) @@ -957,6 +962,10 @@ def convert_to_openai_messages( oai_msg["name"] = message.name if isinstance(message, AIMessage) and message.tool_calls: oai_msg["tool_calls"] = _convert_to_openai_tool_calls(message.tool_calls) + if isinstance(message, AIMessage) and message.invalid_tool_calls: + oai_msg["tool_calls"] = oai_msg.get( + "tool_calls", [] + ) + _convert_to_openai_tool_calls(message.invalid_tool_calls) if message.additional_kwargs.get("refusal"): oai_msg["refusal"] = message.additional_kwargs["refusal"] if isinstance(message, ToolMessage): @@ -1393,14 +1402,18 @@ def _get_message_openai_role(message: BaseMessage) -> str: raise ValueError(msg) -def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]: +def _convert_to_openai_tool_calls( + tool_calls: list[ToolCall], invalid=False +) -> list[dict]: return [ { "type": "function", "id": tool_call["id"], "function": { "name": tool_call["name"], - "arguments": json.dumps(tool_call["args"]), + "arguments": tool_call["args"] + if invalid + else json.dumps(tool_call["args"]), }, } for tool_call in tool_calls From 1160090ce33ec8b374e5ab35f030720fba340e57 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Fri, 1 Nov 2024 16:45:56 -0700 Subject: [PATCH 2/5] fix --- libs/core/langchain_core/messages/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 14120caec5996..95f217174de0b 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -965,7 +965,7 @@ def convert_to_openai_messages( if isinstance(message, AIMessage) and message.invalid_tool_calls: oai_msg["tool_calls"] = oai_msg.get( "tool_calls", [] - ) + _convert_to_openai_tool_calls(message.invalid_tool_calls) + ) + _convert_to_openai_tool_calls(message.invalid_tool_calls, invalid=True) if message.additional_kwargs.get("refusal"): oai_msg["refusal"] = message.additional_kwargs["refusal"] if isinstance(message, ToolMessage): From 85a1215217a5f6b3187f7e2912f5b581b6fe286c Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Tue, 5 Nov 2024 07:30:54 -0800 Subject: [PATCH 3/5] fmt --- libs/core/langchain_core/messages/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 95f217174de0b..35e14815bbd58 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -321,10 +321,11 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: "tool_calls" in msg_kwargs or "content" in msg_kwargs ) if not content_or_tool_calls: - raise KeyError("Must have one of content or tool calls") + msg = "Must have one of content or tool calls" + raise KeyError(msg) msg_content = msg_kwargs.pop("content", "") or "" except KeyError as e: - msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}" + msg = f"Message dict must contain 'role' and one of 'content' or 'tool_calls' keys, got {message}" # noqa: E501 msg = create_message( message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE ) From df415417a19db06a3f7743fea0de8deda3e58fa8 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Tue, 5 Nov 2024 07:36:36 -0800 Subject: [PATCH 4/5] fmt --- libs/core/langchain_core/messages/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 35e14815bbd58..575d951953621 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -36,7 +36,12 @@ from langchain_core.messages.human import HumanMessage, HumanMessageChunk from langchain_core.messages.modifier import RemoveMessage from langchain_core.messages.system import SystemMessage, SystemMessageChunk -from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk +from langchain_core.messages.tool import ( + InvalidToolCall, + ToolCall, + ToolMessage, + ToolMessageChunk, +) if TYPE_CHECKING: from langchain_text_splitters import TextSplitter @@ -1404,7 +1409,7 @@ def _get_message_openai_role(message: BaseMessage) -> str: def _convert_to_openai_tool_calls( - tool_calls: list[ToolCall], invalid=False + tool_calls: list[Union[ToolCall, InvalidToolCall]], invalid: bool = False ) -> list[dict]: return [ { From 994bde53e36ac08a8cb7842f90d18d8761a06d8f Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Tue, 5 Nov 2024 07:39:42 -0800 Subject: [PATCH 5/5] fmt --- libs/core/langchain_core/messages/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 575d951953621..dfd57068cc7e5 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -1409,7 +1409,7 @@ def _get_message_openai_role(message: BaseMessage) -> str: def _convert_to_openai_tool_calls( - tool_calls: list[Union[ToolCall, InvalidToolCall]], invalid: bool = False + tool_calls: Union[list[ToolCall], list[InvalidToolCall]], invalid: bool = False ) -> list[dict]: return [ {