diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 3aff07faecd4c..605352e232139 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -495,20 +495,25 @@ def convert_to_openai_tool( def tool_example_to_messages( - input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None + input: str, + tool_calls: list[BaseModel], + tool_outputs: Optional[list[str]] = None, + ai_response: Optional[str] = None, ) -> list[BaseMessage]: """Convert an example into a list of messages that can be fed into an LLM. This code is an adapter that converts a single example to a list of messages that can be fed into a chat model. - The list of messages per example corresponds to: + The list of messages per example by default corresponds to: 1) HumanMessage: contains the content from which content should be extracted. 2) AIMessage: contains the extracted information from the model 3) ToolMessage: contains confirmation to the model that the model requested a tool correctly. + If `ai_response` is specified, there will be a final AIMessage with that response. + The ToolMessage is required because some chat models are hyper-optimized for agents rather than for an extraction use case. @@ -519,6 +524,7 @@ def tool_example_to_messages( tool_outputs: Optional[List[str]], a list of tool call outputs. Does not need to be provided. If not provided, a placeholder value will be inserted. Defaults to None. + ai_response: Optional[str], if provided, content for a final AIMessage. Returns: A list of messages @@ -584,6 +590,9 @@ class Person(BaseModel): ) for output, tool_call_dict in zip(tool_outputs, openai_tool_calls): messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore + + if ai_response: + messages.append(AIMessage(content=ai_response)) return messages diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 4eaa3da2b19ab..ba4c50187f139 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -679,6 +679,24 @@ def test_tool_outputs() -> None: ] assert messages[2].content == "Output1" + # Test final AI response + messages = tool_example_to_messages( + input="This is an example", + tool_calls=[ + FakeCall(data="ToolCall1"), + ], + tool_outputs=["Output1"], + ai_response="The output is Output1", + ) + assert len(messages) == 4 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert isinstance(messages[2], ToolMessage) + assert isinstance(messages[3], AIMessage) + response = messages[3] + assert response.content == "The output is Output1" + assert not response.tool_calls + @pytest.mark.parametrize("use_extension_typed_dict", [True, False]) @pytest.mark.parametrize("use_extension_annotated", [True, False])