From 5d808d5e1019087e7d31afdadd9e9af82133e9ea Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Mon, 3 Mar 2025 22:58:01 +0200 Subject: [PATCH 1/6] Added ability to reflect on tool use --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 21 ++++-- pydantic_ai_slim/pydantic_ai/agent.py | 71 +++++++++++++++----- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 6447879ac..3205d18e1 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -98,6 +98,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]): usage_limits: _usage.UsageLimits max_result_retries: int end_strategy: EndStrategy + reflect_on_tool_call: bool result_schema: _result.ResultSchema[ResultDataT] | None result_tools: list[ToolDefinition] @@ -304,16 +305,26 @@ async def _stream( assert self._result is not None # this should be set by the previous line async def _make_request( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> CallToolsNode[DepsT, NodeRunEndT]: if self._result is not None: return self._result model_settings, model_request_parameters = await self._prepare_request(ctx) - model_response, request_usage = await ctx.deps.model.request( - ctx.state.message_history, model_settings, model_request_parameters - ) - ctx.state.usage.incr(_usage.Usage(), requests=1) + if not ctx.deps.reflect_on_tool_call and isinstance(self.request.parts[0], _messages.ToolReturnPart): + model_response = models.ModelResponse( + parts=[_messages.TextPart(content='End tool call')], + model_name=ctx.deps.model.model_name, + ) + request_usage = _usage.Usage() + else: + model_response, request_usage = await ctx.deps.model.request( + ctx.state.message_history, + model_settings, + model_request_parameters, + ) + ctx.state.usage.incr(_usage.Usage(), requests=1) return self._finish_handling(ctx, model_response, request_usage) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d85672d27..329c50206 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -124,6 +124,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]): """ _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) + _reflect_on_tool_call: bool = dataclasses.field(repr=False) _result_tool_name: str = dataclasses.field(repr=False) _result_tool_description: str | None = dataclasses.field(repr=False) _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False) @@ -155,6 +156,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', + reflect_on_tool_call: bool = True, ): """Create an agent. @@ -184,6 +186,9 @@ def __init__( [override the model][pydantic_ai.Agent.override] for testing. end_strategy: Strategy for handling tool calls that are requested alongside a final result. See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. + reflect_on_tool_call: Whether to generate a final response with an LLM call after tool calls. + When set to `False`, the agent will return tool call results directly without generating + a final reflection response. Defaults to `True`. """ if model is None or defer_model_check: self.model = model @@ -212,6 +217,7 @@ def __init__( self._default_retries = retries self._max_result_retries = result_retries if result_retries is not None else retries + self._reflect_on_tool_call = reflect_on_tool_call for tool in tools: if isinstance(tool, Tool): self._register_tool(tool) @@ -319,7 +325,7 @@ async def iter( result_type: type[RunResultDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | None = None, - deps: AgentDepsT = None, + deps: AgentDepsT | None = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, @@ -433,20 +439,16 @@ async def main(): model_name=model_used.model_name if model_used else 'no-model', agent_name=self.name or 'agent', ) - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - result_schema=result_schema, - result_tools=self._result_schema.tool_defs() if self._result_schema else [], - result_validators=result_validators, - function_tools=self._function_tools, - run_span=run_span, + graph_deps = self._build_graph_deps( + deps, + user_prompt, + new_message_index, + model_used, + model_settings, + usage_limits, + result_schema, + result_validators, + run_span, ) start_node = _agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, @@ -1121,17 +1123,52 @@ def _get_model(self, model: models.Model | models.KnownModelName | None) -> mode return model_ - def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T: + def _get_deps(self: Agent[T, ResultDataT], deps: T | None) -> T: """Get deps for a run. If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call. We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. + + Args: + deps: Dependencies to use for this run, can be None. + + Returns: + The dependencies to use for the run. """ if some_deps := self._override_deps: return some_deps.value else: - return deps + return deps # type: ignore + + def _build_graph_deps( + self, + deps: AgentDepsT, + user_prompt: str | Sequence[_messages.UserContent], + new_message_index: int, + model_used: models.Model, + model_settings: ModelSettings | None, + usage_limits: _usage.UsageLimits, + result_schema: _result.ResultSchema[RunResultDataT] | None, + result_validators: list[_result.ResultValidator[AgentDepsT, RunResultDataT]], + run_span: logfire_api.LogfireSpan, + ) -> _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]: + return _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( + user_deps=deps, + prompt=user_prompt, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + result_schema=result_schema, + result_tools=self._result_schema.tool_defs() if self._result_schema else [], + result_validators=result_validators, + function_tools=self._function_tools, + reflect_on_tool_call=self._reflect_on_tool_call, + run_span=run_span, + ) def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. From 8bb0bebeb3a4b495a70e931081967db238194ed0 Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Tue, 4 Mar 2025 10:27:44 +0200 Subject: [PATCH 2/6] Draft tests --- tests/test_build_graph_deps.py | 70 ++++++ tests/test_reflect_on_tool_call.py | 357 +++++++++++++++++++++++++++++ 2 files changed, 427 insertions(+) create mode 100644 tests/test_build_graph_deps.py create mode 100644 tests/test_reflect_on_tool_call.py diff --git a/tests/test_build_graph_deps.py b/tests/test_build_graph_deps.py new file mode 100644 index 000000000..a79026410 --- /dev/null +++ b/tests/test_build_graph_deps.py @@ -0,0 +1,70 @@ +"""Tests for the _build_graph_deps method.""" + +import pytest +from typing import Any +from pydantic_ai.result import FinalResult + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai._agent_graph import GraphAgentDeps +from pydantic_ai.usage import UsageLimits + +pytestmark = pytest.mark.anyio + + +class TestBuildGraphDeps: + """Tests for the _build_graph_deps method that was added in the commit.""" + + async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: + """Test that reflect_on_tool_call is passed correctly to GraphAgentDeps.""" + # Create an agent with reflect_on_tool_call=True (the default) + agent_true = Agent(TestModel()) + + # Create mock parameters for _build_graph_deps + deps = None + user_prompt = 'test prompt' + new_message_index = 0 + model_used = TestModel() + model_settings = None + usage_limits = UsageLimits() + result_schema = None + result_validators: list[Any] = [] + run_span = None # type: ignore + + # pyright: reportPrivateUsage=false + # Call the method + graph_deps = agent_true._build_graph_deps( + deps, + user_prompt, + new_message_index, + model_used, + model_settings, + usage_limits, + result_schema, + result_validators, + run_span + ) + + # Verify that reflect_on_tool_call was correctly passed to GraphAgentDeps + assert isinstance(graph_deps, GraphAgentDeps) + assert graph_deps.reflect_on_tool_call is True + + # Create an agent with reflect_on_tool_call=False + agent_false = Agent(TestModel(), reflect_on_tool_call=False) + + # Call the method with the second agent + graph_deps = agent_false._build_graph_deps( + deps, + user_prompt, + new_message_index, + model_used, + model_settings, + usage_limits, + result_schema, + result_validators, + run_span + ) + + # Verify that reflect_on_tool_call=False was correctly passed to GraphAgentDeps + assert isinstance(graph_deps, GraphAgentDeps) + assert graph_deps.reflect_on_tool_call is False \ No newline at end of file diff --git a/tests/test_reflect_on_tool_call.py b/tests/test_reflect_on_tool_call.py new file mode 100644 index 000000000..400dd8cb5 --- /dev/null +++ b/tests/test_reflect_on_tool_call.py @@ -0,0 +1,357 @@ +"""Tests for reflect_on_tool_call functionality.""" + +import json +from typing import Any + +import pytest +from pydantic import BaseModel + +from pydantic_ai import Agent, capture_run_messages +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.test import TestModel + +pytestmark = pytest.mark.anyio + + +class TestReflectOnToolCall: + """Tests for the reflect_on_tool_call parameter.""" + + @pytest.mark.parametrize('reflect_on_tool_call', [True, False]) + def test_agent_constructor_reflects_parameter(self, reflect_on_tool_call: bool) -> None: + """Test that the reflect_on_tool_call parameter is stored in the agent.""" + agent = Agent(TestModel(), reflect_on_tool_call=reflect_on_tool_call) + + # pyright: reportPrivateUsage=false + assert agent._reflect_on_tool_call is reflect_on_tool_call + + def test_agent_default_to_reflect(self) -> None: + """Test that by default, reflect_on_tool_call is True.""" + agent = Agent(TestModel()) + + # pyright: reportPrivateUsage=false + assert agent._reflect_on_tool_call is True + + def test_with_reflection(self) -> None: + """Test normal behavior with reflection enabled (default behavior).""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return a tool call + return ModelResponse(parts=[ToolCallPart('get_data', {'param': 'test'})]) + elif len(messages) == 3: + # After tool return - return a summary + return ModelResponse(parts=[TextPart('Data retrieved: test_data')]) + else: + # This shouldn't happen in this test + return ModelResponse(parts=[TextPart('Unexpected message sequence')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'test_data: {param}' + + with capture_run_messages() as messages: + result = agent.run_sync('Get data') + + assert result.data == 'Data retrieved: test_data' + assert len(messages) == 4 # Initial request, first response, tool return, final summary + + # Verify the last message is a text summary from the model + assert isinstance(messages[-1], ModelResponse) + assert len(messages[-1].parts) == 1 + assert isinstance(messages[-1].parts[0], TextPart) + assert messages[-1].parts[0].content == 'Data retrieved: test_data' + + def test_without_reflection(self) -> None: + """Test behavior with reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return a tool call + return ModelResponse(parts=[ToolCallPart('get_data', {'param': 'test'})]) + else: + # This shouldn't be reached with reflection disabled + return ModelResponse(parts=[TextPart('This should not be called')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'test_data: {param}' + + with capture_run_messages() as messages: + result = agent.run_sync('Get data') + + # The result should be the text "End tool call" (based on the implementation) + assert result.data == 'End tool call' + + # There should be 4 messages: initial request, tool call, tool return + # But no final reflection message from the model + assert len(messages) == 4 + + # Check message sequence + assert isinstance(messages[0], ModelRequest) + assert isinstance(messages[1], ModelResponse) + assert isinstance(messages[2], ModelRequest) + assert isinstance(messages[3], ModelResponse) + + # Verify the last message is a tool return, not a model response + assert len(messages[2].parts) == 1 + assert isinstance(messages[2].parts[0], ToolReturnPart) + assert messages[2].parts[0].content == 'test_data: test' + + def test_multiple_tools_with_reflection(self) -> None: + """Test behavior with multiple tool calls and reflection enabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return multiple tool calls + return ModelResponse(parts=[ + ToolCallPart('tool_a', {'value': 'a'}), + ToolCallPart('tool_b', {'value': 'b'}) + ]) + elif len(messages) == 3: + # After tool returns - return a summary + return ModelResponse(parts=[TextPart('Tools executed: A and B')]) + else: + return ModelResponse(parts=[TextPart('Unexpected message sequence')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) + + @agent.tool_plain + def tool_a(value: str) -> str: + return f'Result A: {value}' + + @agent.tool_plain + def tool_b(value: str) -> str: + return f'Result B: {value}' + + with capture_run_messages() as messages: + result = agent.run_sync('Run tools') + + assert result.data == 'Tools executed: A and B' + assert len(messages) == 4 # Initial request, tool calls, tool returns, final summary + + def test_multiple_tools_without_reflection(self) -> None: + """Test behavior with multiple tool calls and reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return multiple tool calls + return ModelResponse(parts=[ + ToolCallPart('tool_a', {'value': 'a'}), + ToolCallPart('tool_b', {'value': 'b'}) + ]) + else: + # This shouldn't be reached with reflection disabled + return ModelResponse(parts=[TextPart('This should not be called')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) + + @agent.tool_plain + def tool_a(value: str) -> str: + return f'Result A: {value}' + + @agent.tool_plain + def tool_b(value: str) -> str: + return f'Result B: {value}' + + with capture_run_messages() as messages: + result = agent.run_sync('Run tools') + + # The result should be "End tool call" based on the implementation + assert result.data == 'End tool call' + + # There should be 4 messages: initial request, tool calls, tool returns, final summary + assert len(messages) == 4 + + # Verify the tool returns + assert isinstance(messages[2], ModelRequest) + assert len(messages[2].parts) == 2 + assert isinstance(messages[2].parts[0], ToolReturnPart) + assert isinstance(messages[2].parts[1], ToolReturnPart) + assert messages[2].parts[0].content == 'Result A: a' + assert messages[2].parts[1].content == 'Result B: b' + + def test_streaming_with_reflection(self) -> None: + """Test streaming behavior with reflection enabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return a tool call + return ModelResponse(parts=[ToolCallPart('get_data', {'param': 'test'})]) + elif len(messages) == 3: + # After tool return - return a summary + return ModelResponse(parts=[TextPart('Streaming data retrieved')]) + else: + return ModelResponse(parts=[TextPart('Unexpected message sequence')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'streaming_data: {param}' + + with capture_run_messages() as messages: + # Using run method instead of run_stream_sync + result = agent.run_sync('Stream data') + + # Verify the interaction completed + assert result.data == 'Streaming data retrieved' + + # Final message should be a text response + assert isinstance(messages[-1], ModelResponse) + assert len(messages[-1].parts) == 1 + assert isinstance(messages[-1].parts[0], TextPart) + + def test_streaming_without_reflection(self) -> None: + """Test streaming behavior with reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return a tool call + return ModelResponse(parts=[ToolCallPart('get_data', {'param': 'test'})]) + else: + # This shouldn't be reached with reflection disabled + return ModelResponse(parts=[TextPart('This should not be called')]) + + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'streaming_data: {param}' + + # Using run method instead of run_stream_sync + with capture_run_messages() as messages: + result = agent.run_sync('Stream data') + + # The result should be "End tool call" based on the implementation + assert result.data == 'End tool call' + + # There should be 4 messages total + assert len(messages) == 4 + + # The last message should be the "End tool call" response + assert isinstance(messages[-1], ModelResponse) + assert len(messages[-1].parts) == 1 + assert isinstance(messages[-1].parts[0], TextPart) + assert messages[-1].parts[0].content == 'End tool call' + + # The second-to-last message should be the tool return + assert isinstance(messages[-2], ModelRequest) + assert len(messages[-2].parts) == 1 + assert isinstance(messages[-2].parts[0], ToolReturnPart) + assert messages[-2].parts[0].content == 'streaming_data: test' + + def test_result_schema_with_reflection(self) -> None: + """Test structured result with reflection enabled.""" + class ResultData(BaseModel): + value: str + + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return a tool call + return ModelResponse(parts=[ToolCallPart('get_data', {'param': 'test'})]) + elif len(messages) == 3: + # After tool return - return a structured result + assert info.result_tools + return ModelResponse(parts=[ + ToolCallPart(info.result_tools[0].name, {'value': 'structured_value'}) + ]) + else: + return ModelResponse(parts=[TextPart('Unexpected message sequence')]) + + agent = Agent( + FunctionModel(model_func), + result_type=ResultData, + reflect_on_tool_call=True + ) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'data: {param}' + + with capture_run_messages() as messages: + result = agent.run_sync('Get structured data') + + assert isinstance(result.data, ResultData) + assert result.data.value == 'structured_value' + + # There should be 5 messages: + # 1. initial request + # 2. tool call + # 3. tool return + # 4. final result call + # 5. final result return + assert len(messages) == 5 + + # Verify message sequence + assert isinstance(messages[0], ModelRequest) # Initial request + assert isinstance(messages[1], ModelResponse) # Tool call + assert isinstance(messages[2], ModelRequest) # Tool return + assert isinstance(messages[3], ModelResponse) # Result call + assert isinstance(messages[4], ModelRequest) # Result return + + # Verify the result tool call + assert isinstance(messages[3].parts[0], ToolCallPart) + assert messages[3].parts[0].tool_name.startswith('final_result') + + def test_result_schema_without_reflection(self) -> None: + """Test structured result with reflection disabled.""" + class ResultData(BaseModel): + value: str + + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + # First call - return both a tool call and a structured result + assert info.result_tools + return ModelResponse(parts=[ + ToolCallPart('get_data', {'param': 'test'}), + ToolCallPart(info.result_tools[0].name, {'value': 'no_reflection_value'}) + ]) + else: + return ModelResponse(parts=[TextPart('This should not be called')]) + + agent = Agent( + FunctionModel(model_func), + result_type=ResultData, + reflect_on_tool_call=False + ) + + @agent.tool_plain + def get_data(param: str) -> str: + return f'data: {param}' + + with capture_run_messages() as messages: + result = agent.run_sync('Get structured data without reflection') + + assert isinstance(result.data, ResultData) + assert result.data.value == 'no_reflection_value' + + # Expected messages: + # 1. Initial request + # 2. Tool call + result call response + # 3. Tool return + result return request + assert len(messages) == 3 + + # Verify message sequence + assert isinstance(messages[0], ModelRequest) # Initial request + assert isinstance(messages[1], ModelResponse) # Tool call + result call + assert isinstance(messages[2], ModelRequest) # Tool return + result return + + # Verify there are two parts in the tool response + assert len(messages[1].parts) == 2 + assert isinstance(messages[1].parts[0], ToolCallPart) + assert isinstance(messages[1].parts[1], ToolCallPart) + assert messages[1].parts[0].tool_name == 'get_data' + assert messages[1].parts[1].tool_name.startswith('final_result') + + # Verify there are two parts in the tool return + assert len(messages[2].parts) == 2 + assert isinstance(messages[2].parts[0], ToolReturnPart) + assert isinstance(messages[2].parts[1], ToolReturnPart) + assert messages[2].parts[0].tool_name == 'get_data' + assert messages[2].parts[1].tool_name.startswith('final_result') \ No newline at end of file From 90123d4ccf504bcde7764b05aa1a25da83a33b0c Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Tue, 4 Mar 2025 13:59:52 +0200 Subject: [PATCH 3/6] Lint --- tests/test_build_graph_deps.py | 24 ++--- tests/test_reflect_on_tool_call.py | 168 ++++++++++++++--------------- 2 files changed, 94 insertions(+), 98 deletions(-) diff --git a/tests/test_build_graph_deps.py b/tests/test_build_graph_deps.py index a79026410..5cc4ef5b8 100644 --- a/tests/test_build_graph_deps.py +++ b/tests/test_build_graph_deps.py @@ -1,12 +1,12 @@ """Tests for the _build_graph_deps method.""" -import pytest from typing import Any -from pydantic_ai.result import FinalResult + +import pytest from pydantic_ai import Agent -from pydantic_ai.models.test import TestModel from pydantic_ai._agent_graph import GraphAgentDeps +from pydantic_ai.models.test import TestModel from pydantic_ai.usage import UsageLimits pytestmark = pytest.mark.anyio @@ -19,7 +19,7 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: """Test that reflect_on_tool_call is passed correctly to GraphAgentDeps.""" # Create an agent with reflect_on_tool_call=True (the default) agent_true = Agent(TestModel()) - + # Create mock parameters for _build_graph_deps deps = None user_prompt = 'test prompt' @@ -30,7 +30,7 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: result_schema = None result_validators: list[Any] = [] run_span = None # type: ignore - + # pyright: reportPrivateUsage=false # Call the method graph_deps = agent_true._build_graph_deps( @@ -42,16 +42,16 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: usage_limits, result_schema, result_validators, - run_span + run_span, ) - + # Verify that reflect_on_tool_call was correctly passed to GraphAgentDeps assert isinstance(graph_deps, GraphAgentDeps) assert graph_deps.reflect_on_tool_call is True - + # Create an agent with reflect_on_tool_call=False agent_false = Agent(TestModel(), reflect_on_tool_call=False) - + # Call the method with the second agent graph_deps = agent_false._build_graph_deps( deps, @@ -62,9 +62,9 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: usage_limits, result_schema, result_validators, - run_span + run_span, ) - + # Verify that reflect_on_tool_call=False was correctly passed to GraphAgentDeps assert isinstance(graph_deps, GraphAgentDeps) - assert graph_deps.reflect_on_tool_call is False \ No newline at end of file + assert graph_deps.reflect_on_tool_call is False diff --git a/tests/test_reflect_on_tool_call.py b/tests/test_reflect_on_tool_call.py index 400dd8cb5..a517b470a 100644 --- a/tests/test_reflect_on_tool_call.py +++ b/tests/test_reflect_on_tool_call.py @@ -1,7 +1,5 @@ """Tests for reflect_on_tool_call functionality.""" -import json -from typing import Any import pytest from pydantic import BaseModel @@ -28,19 +26,20 @@ class TestReflectOnToolCall: def test_agent_constructor_reflects_parameter(self, reflect_on_tool_call: bool) -> None: """Test that the reflect_on_tool_call parameter is stored in the agent.""" agent = Agent(TestModel(), reflect_on_tool_call=reflect_on_tool_call) - + # pyright: reportPrivateUsage=false assert agent._reflect_on_tool_call is reflect_on_tool_call def test_agent_default_to_reflect(self) -> None: """Test that by default, reflect_on_tool_call is True.""" agent = Agent(TestModel()) - + # pyright: reportPrivateUsage=false assert agent._reflect_on_tool_call is True def test_with_reflection(self) -> None: """Test normal behavior with reflection enabled (default behavior).""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return a tool call @@ -51,19 +50,19 @@ def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: else: # This shouldn't happen in this test return ModelResponse(parts=[TextPart('Unexpected message sequence')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) - + @agent.tool_plain def get_data(param: str) -> str: return f'test_data: {param}' - + with capture_run_messages() as messages: result = agent.run_sync('Get data') - + assert result.data == 'Data retrieved: test_data' assert len(messages) == 4 # Initial request, first response, tool return, final summary - + # Verify the last message is a text summary from the model assert isinstance(messages[-1], ModelResponse) assert len(messages[-1].parts) == 1 @@ -72,6 +71,7 @@ def get_data(param: str) -> str: def test_without_reflection(self) -> None: """Test behavior with reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return a tool call @@ -79,29 +79,29 @@ def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: else: # This shouldn't be reached with reflection disabled return ModelResponse(parts=[TextPart('This should not be called')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) - + @agent.tool_plain def get_data(param: str) -> str: return f'test_data: {param}' - + with capture_run_messages() as messages: result = agent.run_sync('Get data') - + # The result should be the text "End tool call" (based on the implementation) assert result.data == 'End tool call' - + # There should be 4 messages: initial request, tool call, tool return # But no final reflection message from the model assert len(messages) == 4 - + # Check message sequence assert isinstance(messages[0], ModelRequest) assert isinstance(messages[1], ModelResponse) - assert isinstance(messages[2], ModelRequest) + assert isinstance(messages[2], ModelRequest) assert isinstance(messages[3], ModelResponse) - + # Verify the last message is a tool return, not a model response assert len(messages[2].parts) == 1 assert isinstance(messages[2].parts[0], ToolReturnPart) @@ -109,67 +109,67 @@ def get_data(param: str) -> str: def test_multiple_tools_with_reflection(self) -> None: """Test behavior with multiple tool calls and reflection enabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return multiple tool calls - return ModelResponse(parts=[ - ToolCallPart('tool_a', {'value': 'a'}), - ToolCallPart('tool_b', {'value': 'b'}) - ]) + return ModelResponse( + parts=[ToolCallPart('tool_a', {'value': 'a'}), ToolCallPart('tool_b', {'value': 'b'})] + ) elif len(messages) == 3: # After tool returns - return a summary return ModelResponse(parts=[TextPart('Tools executed: A and B')]) else: return ModelResponse(parts=[TextPart('Unexpected message sequence')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) - + @agent.tool_plain def tool_a(value: str) -> str: return f'Result A: {value}' - + @agent.tool_plain def tool_b(value: str) -> str: return f'Result B: {value}' - + with capture_run_messages() as messages: result = agent.run_sync('Run tools') - + assert result.data == 'Tools executed: A and B' assert len(messages) == 4 # Initial request, tool calls, tool returns, final summary def test_multiple_tools_without_reflection(self) -> None: """Test behavior with multiple tool calls and reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return multiple tool calls - return ModelResponse(parts=[ - ToolCallPart('tool_a', {'value': 'a'}), - ToolCallPart('tool_b', {'value': 'b'}) - ]) + return ModelResponse( + parts=[ToolCallPart('tool_a', {'value': 'a'}), ToolCallPart('tool_b', {'value': 'b'})] + ) else: # This shouldn't be reached with reflection disabled return ModelResponse(parts=[TextPart('This should not be called')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) - + @agent.tool_plain def tool_a(value: str) -> str: return f'Result A: {value}' - + @agent.tool_plain def tool_b(value: str) -> str: return f'Result B: {value}' - + with capture_run_messages() as messages: result = agent.run_sync('Run tools') - + # The result should be "End tool call" based on the implementation assert result.data == 'End tool call' - + # There should be 4 messages: initial request, tool calls, tool returns, final summary assert len(messages) == 4 - + # Verify the tool returns assert isinstance(messages[2], ModelRequest) assert len(messages[2].parts) == 2 @@ -180,6 +180,7 @@ def tool_b(value: str) -> str: def test_streaming_with_reflection(self) -> None: """Test streaming behavior with reflection enabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return a tool call @@ -189,20 +190,20 @@ def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart('Streaming data retrieved')]) else: return ModelResponse(parts=[TextPart('Unexpected message sequence')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=True) - + @agent.tool_plain def get_data(param: str) -> str: return f'streaming_data: {param}' - + with capture_run_messages() as messages: # Using run method instead of run_stream_sync result = agent.run_sync('Stream data') - + # Verify the interaction completed assert result.data == 'Streaming data retrieved' - + # Final message should be a text response assert isinstance(messages[-1], ModelResponse) assert len(messages[-1].parts) == 1 @@ -210,6 +211,7 @@ def get_data(param: str) -> str: def test_streaming_without_reflection(self) -> None: """Test streaming behavior with reflection disabled.""" + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return a tool call @@ -217,29 +219,29 @@ def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: else: # This shouldn't be reached with reflection disabled return ModelResponse(parts=[TextPart('This should not be called')]) - + agent = Agent(FunctionModel(model_func), reflect_on_tool_call=False) - + @agent.tool_plain def get_data(param: str) -> str: return f'streaming_data: {param}' - + # Using run method instead of run_stream_sync with capture_run_messages() as messages: result = agent.run_sync('Stream data') - + # The result should be "End tool call" based on the implementation assert result.data == 'End tool call' - + # There should be 4 messages total assert len(messages) == 4 - + # The last message should be the "End tool call" response assert isinstance(messages[-1], ModelResponse) assert len(messages[-1].parts) == 1 assert isinstance(messages[-1].parts[0], TextPart) assert messages[-1].parts[0].content == 'End tool call' - + # The second-to-last message should be the tool return assert isinstance(messages[-2], ModelRequest) assert len(messages[-2].parts) == 1 @@ -248,9 +250,10 @@ def get_data(param: str) -> str: def test_result_schema_with_reflection(self) -> None: """Test structured result with reflection enabled.""" + class ResultData(BaseModel): value: str - + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return a tool call @@ -258,100 +261,93 @@ def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: elif len(messages) == 3: # After tool return - return a structured result assert info.result_tools - return ModelResponse(parts=[ - ToolCallPart(info.result_tools[0].name, {'value': 'structured_value'}) - ]) + return ModelResponse(parts=[ToolCallPart(info.result_tools[0].name, {'value': 'structured_value'})]) else: return ModelResponse(parts=[TextPart('Unexpected message sequence')]) - - agent = Agent( - FunctionModel(model_func), - result_type=ResultData, - reflect_on_tool_call=True - ) - + + agent = Agent(FunctionModel(model_func), result_type=ResultData, reflect_on_tool_call=True) + @agent.tool_plain def get_data(param: str) -> str: return f'data: {param}' - + with capture_run_messages() as messages: result = agent.run_sync('Get structured data') - + assert isinstance(result.data, ResultData) assert result.data.value == 'structured_value' - - # There should be 5 messages: - # 1. initial request + + # There should be 5 messages: + # 1. initial request # 2. tool call # 3. tool return # 4. final result call # 5. final result return assert len(messages) == 5 - + # Verify message sequence assert isinstance(messages[0], ModelRequest) # Initial request assert isinstance(messages[1], ModelResponse) # Tool call assert isinstance(messages[2], ModelRequest) # Tool return assert isinstance(messages[3], ModelResponse) # Result call assert isinstance(messages[4], ModelRequest) # Result return - + # Verify the result tool call assert isinstance(messages[3].parts[0], ToolCallPart) assert messages[3].parts[0].tool_name.startswith('final_result') def test_result_schema_without_reflection(self) -> None: """Test structured result with reflection disabled.""" + class ResultData(BaseModel): value: str - + def model_func(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: # First call - return both a tool call and a structured result assert info.result_tools - return ModelResponse(parts=[ - ToolCallPart('get_data', {'param': 'test'}), - ToolCallPart(info.result_tools[0].name, {'value': 'no_reflection_value'}) - ]) + return ModelResponse( + parts=[ + ToolCallPart('get_data', {'param': 'test'}), + ToolCallPart(info.result_tools[0].name, {'value': 'no_reflection_value'}), + ] + ) else: return ModelResponse(parts=[TextPart('This should not be called')]) - - agent = Agent( - FunctionModel(model_func), - result_type=ResultData, - reflect_on_tool_call=False - ) - + + agent = Agent(FunctionModel(model_func), result_type=ResultData, reflect_on_tool_call=False) + @agent.tool_plain def get_data(param: str) -> str: return f'data: {param}' - + with capture_run_messages() as messages: result = agent.run_sync('Get structured data without reflection') - + assert isinstance(result.data, ResultData) assert result.data.value == 'no_reflection_value' - + # Expected messages: # 1. Initial request # 2. Tool call + result call response # 3. Tool return + result return request assert len(messages) == 3 - + # Verify message sequence assert isinstance(messages[0], ModelRequest) # Initial request assert isinstance(messages[1], ModelResponse) # Tool call + result call assert isinstance(messages[2], ModelRequest) # Tool return + result return - + # Verify there are two parts in the tool response assert len(messages[1].parts) == 2 assert isinstance(messages[1].parts[0], ToolCallPart) assert isinstance(messages[1].parts[1], ToolCallPart) assert messages[1].parts[0].tool_name == 'get_data' assert messages[1].parts[1].tool_name.startswith('final_result') - + # Verify there are two parts in the tool return assert len(messages[2].parts) == 2 assert isinstance(messages[2].parts[0], ToolReturnPart) assert isinstance(messages[2].parts[1], ToolReturnPart) assert messages[2].parts[0].tool_name == 'get_data' - assert messages[2].parts[1].tool_name.startswith('final_result') \ No newline at end of file + assert messages[2].parts[1].tool_name.startswith('final_result') From 4f4152d2c33b6c714890e890b6291c5a670e0c7a Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Tue, 4 Mar 2025 18:31:26 +0200 Subject: [PATCH 4/6] Merge main --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++-- tests/test_build_graph_deps.py | 14 +++++++++----- tests/test_reflect_on_tool_call.py | 7 ++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 987dfabd6..c56f46cd2 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -8,7 +8,7 @@ from types import FrameType from typing import Any, Callable, Generic, cast, final, overload -from opentelemetry.trace import NoOpTracer, use_span, Span +from opentelemetry.trace import NoOpTracer, Span, Tracer, use_span from typing_extensions import TypeGuard, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext @@ -1154,7 +1154,7 @@ def _build_graph_deps( result_schema: _result.ResultSchema[RunResultDataT] | None, result_validators: list[_result.ResultValidator[AgentDepsT, RunResultDataT]], run_span: Span, - tracer, + tracer: Tracer, ) -> _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]: return _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( user_deps=deps, diff --git a/tests/test_build_graph_deps.py b/tests/test_build_graph_deps.py index 5cc4ef5b8..ead4aec51 100644 --- a/tests/test_build_graph_deps.py +++ b/tests/test_build_graph_deps.py @@ -1,8 +1,9 @@ """Tests for the _build_graph_deps method.""" -from typing import Any +from typing import Any, cast import pytest +from opentelemetry.trace import NoOpTracer, Span from pydantic_ai import Agent from pydantic_ai._agent_graph import GraphAgentDeps @@ -29,11 +30,12 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: usage_limits = UsageLimits() result_schema = None result_validators: list[Any] = [] - run_span = None # type: ignore + # Create a mock Span instance instead of using the class + run_span = cast(Span, object()) + tracer = NoOpTracer() - # pyright: reportPrivateUsage=false # Call the method - graph_deps = agent_true._build_graph_deps( + graph_deps = agent_true._build_graph_deps( # pyright: ignore[reportPrivateUsage] deps, user_prompt, new_message_index, @@ -43,6 +45,7 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: result_schema, result_validators, run_span, + tracer, ) # Verify that reflect_on_tool_call was correctly passed to GraphAgentDeps @@ -53,7 +56,7 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: agent_false = Agent(TestModel(), reflect_on_tool_call=False) # Call the method with the second agent - graph_deps = agent_false._build_graph_deps( + graph_deps = agent_false._build_graph_deps( # pyright: ignore[reportPrivateUsage] deps, user_prompt, new_message_index, @@ -63,6 +66,7 @@ async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None: result_schema, result_validators, run_span, + tracer, ) # Verify that reflect_on_tool_call=False was correctly passed to GraphAgentDeps diff --git a/tests/test_reflect_on_tool_call.py b/tests/test_reflect_on_tool_call.py index a517b470a..4ff877a60 100644 --- a/tests/test_reflect_on_tool_call.py +++ b/tests/test_reflect_on_tool_call.py @@ -1,6 +1,5 @@ """Tests for reflect_on_tool_call functionality.""" - import pytest from pydantic import BaseModel @@ -27,15 +26,13 @@ def test_agent_constructor_reflects_parameter(self, reflect_on_tool_call: bool) """Test that the reflect_on_tool_call parameter is stored in the agent.""" agent = Agent(TestModel(), reflect_on_tool_call=reflect_on_tool_call) - # pyright: reportPrivateUsage=false - assert agent._reflect_on_tool_call is reflect_on_tool_call + assert agent._reflect_on_tool_call is reflect_on_tool_call # pyright: ignore[reportPrivateUsage] def test_agent_default_to_reflect(self) -> None: """Test that by default, reflect_on_tool_call is True.""" agent = Agent(TestModel()) - # pyright: reportPrivateUsage=false - assert agent._reflect_on_tool_call is True + assert agent._reflect_on_tool_call is True # pyright: ignore[reportPrivateUsage] def test_with_reflection(self) -> None: """Test normal behavior with reflection enabled (default behavior).""" From 689f8091c687ad576af29d1fd0319e120fa474a6 Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Tue, 4 Mar 2025 18:34:02 +0200 Subject: [PATCH 5/6] Optimize runtime --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index eedf8b995..b57030b18 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -302,7 +302,6 @@ async def _make_request( if self._result is not None: return self._result - model_settings, model_request_parameters = await self._prepare_request(ctx) if not ctx.deps.reflect_on_tool_call and isinstance(self.request.parts[0], _messages.ToolReturnPart): model_response = models.ModelResponse( parts=[_messages.TextPart(content='End tool call')], @@ -310,6 +309,7 @@ async def _make_request( ) request_usage = _usage.Usage() else: + model_settings, model_request_parameters = await self._prepare_request(ctx) model_response, request_usage = await ctx.deps.model.request( ctx.state.message_history, model_settings, From 3362da85d9eeb04a382d8ea7496c79bbbfaead39 Mon Sep 17 00:00:00 2001 From: WorldInnovationsDepartment Date: Tue, 4 Mar 2025 18:43:07 +0200 Subject: [PATCH 6/6] fix tests --- tests/test_reflect_on_tool_call.py | 64 ++++++++++++++++-------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/tests/test_reflect_on_tool_call.py b/tests/test_reflect_on_tool_call.py index 4ff877a60..dcb8c16dc 100644 --- a/tests/test_reflect_on_tool_call.py +++ b/tests/test_reflect_on_tool_call.py @@ -89,20 +89,18 @@ def get_data(param: str) -> str: # The result should be the text "End tool call" (based on the implementation) assert result.data == 'End tool call' - # There should be 4 messages: initial request, tool call, tool return - # But no final reflection message from the model - assert len(messages) == 4 + # There should be 3 messages: initial request, tool call, final response + assert len(messages) == 3 # Check message sequence - assert isinstance(messages[0], ModelRequest) - assert isinstance(messages[1], ModelResponse) - assert isinstance(messages[2], ModelRequest) - assert isinstance(messages[3], ModelResponse) + assert isinstance(messages[0], ModelRequest) # Initial request + assert isinstance(messages[1], ModelResponse) # Tool call + assert isinstance(messages[2], ModelResponse) # End tool call response - # Verify the last message is a tool return, not a model response + # Verify the last message is the "End tool call" response assert len(messages[2].parts) == 1 - assert isinstance(messages[2].parts[0], ToolReturnPart) - assert messages[2].parts[0].content == 'test_data: test' + assert isinstance(messages[2].parts[0], TextPart) + assert messages[2].parts[0].content == 'End tool call' def test_multiple_tools_with_reflection(self) -> None: """Test behavior with multiple tool calls and reflection enabled.""" @@ -164,16 +162,25 @@ def tool_b(value: str) -> str: # The result should be "End tool call" based on the implementation assert result.data == 'End tool call' - # There should be 4 messages: initial request, tool calls, tool returns, final summary - assert len(messages) == 4 + # There should be 3 messages: initial request, tool calls, final response + assert len(messages) == 3 - # Verify the tool returns - assert isinstance(messages[2], ModelRequest) - assert len(messages[2].parts) == 2 - assert isinstance(messages[2].parts[0], ToolReturnPart) - assert isinstance(messages[2].parts[1], ToolReturnPart) - assert messages[2].parts[0].content == 'Result A: a' - assert messages[2].parts[1].content == 'Result B: b' + # Verify the message sequence + assert isinstance(messages[0], ModelRequest) # Initial request + assert isinstance(messages[1], ModelResponse) # Tool calls + assert isinstance(messages[2], ModelResponse) # End tool call response + + # Verify the tool calls message has multiple parts + assert len(messages[1].parts) == 2 + assert isinstance(messages[1].parts[0], ToolCallPart) + assert isinstance(messages[1].parts[1], ToolCallPart) + assert messages[1].parts[0].tool_name == 'tool_a' + assert messages[1].parts[1].tool_name == 'tool_b' + + # Verify the final response + assert len(messages[2].parts) == 1 + assert isinstance(messages[2].parts[0], TextPart) + assert messages[2].parts[0].content == 'End tool call' def test_streaming_with_reflection(self) -> None: """Test streaming behavior with reflection enabled.""" @@ -230,8 +237,13 @@ def get_data(param: str) -> str: # The result should be "End tool call" based on the implementation assert result.data == 'End tool call' - # There should be 4 messages total - assert len(messages) == 4 + # There should be 3 messages total + assert len(messages) == 3 + + # Check message sequence + assert isinstance(messages[0], ModelRequest) # Initial request + assert isinstance(messages[1], ModelResponse) # Tool call + assert isinstance(messages[2], ModelResponse) # End tool call # The last message should be the "End tool call" response assert isinstance(messages[-1], ModelResponse) @@ -239,12 +251,6 @@ def get_data(param: str) -> str: assert isinstance(messages[-1].parts[0], TextPart) assert messages[-1].parts[0].content == 'End tool call' - # The second-to-last message should be the tool return - assert isinstance(messages[-2], ModelRequest) - assert len(messages[-2].parts) == 1 - assert isinstance(messages[-2].parts[0], ToolReturnPart) - assert messages[-2].parts[0].content == 'streaming_data: test' - def test_result_schema_with_reflection(self) -> None: """Test structured result with reflection enabled.""" @@ -333,7 +339,7 @@ def get_data(param: str) -> str: # Verify message sequence assert isinstance(messages[0], ModelRequest) # Initial request assert isinstance(messages[1], ModelResponse) # Tool call + result call - assert isinstance(messages[2], ModelRequest) # Tool return + result return + assert isinstance(messages[2], ModelRequest) # Tool returns # Verify there are two parts in the tool response assert len(messages[1].parts) == 2 @@ -342,7 +348,7 @@ def get_data(param: str) -> str: assert messages[1].parts[0].tool_name == 'get_data' assert messages[1].parts[1].tool_name.startswith('final_result') - # Verify there are two parts in the tool return + # Verify tool returns assert len(messages[2].parts) == 2 assert isinstance(messages[2].parts[0], ToolReturnPart) assert isinstance(messages[2].parts[1], ToolReturnPart)