Skip to content

Disable reflection after tool use #1040

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

23 changes: 17 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
usage_limits: _usage.UsageLimits
max_result_retries: int
end_strategy: EndStrategy
reflect_on_tool_call: bool

result_schema: _result.ResultSchema[ResultDataT] | None
result_tools: list[ToolDefinition]
Expand Down Expand Up @@ -295,16 +296,26 @@ async def _stream(
assert self._result is not None # this should be set by the previous line

async def _make_request(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> CallToolsNode[DepsT, NodeRunEndT]:
    """Make a model request (or short-circuit it) and hand the response on for tool handling.

    When ``ctx.deps.reflect_on_tool_call`` is ``False`` and the request being sent is a
    tool return, the LLM round-trip is skipped entirely and a synthetic text response is
    produced instead, so the run ends without a final "reflection" model call.
    """
    # A result may already have been produced (e.g. by a prior streamed request); reuse it.
    if self._result is not None:
        return self._result

    if not ctx.deps.reflect_on_tool_call and isinstance(self.request.parts[0], _messages.ToolReturnPart):
        # Reflection disabled: fabricate a terminal text response instead of asking the
        # model to comment on the tool result.  No real request is made, so usage stays
        # empty and the request counter is NOT incremented.
        model_response = models.ModelResponse(
            parts=[_messages.TextPart(content='End tool call')],
            model_name=ctx.deps.model.model_name,
        )
        request_usage = _usage.Usage()
    else:
        model_settings, model_request_parameters = await self._prepare_request(ctx)
        model_response, request_usage = await ctx.deps.model.request(
            ctx.state.message_history,
            model_settings,
            model_request_parameters,
        )
        ctx.state.usage.incr(_usage.Usage(), requests=1)

    return self._finish_handling(ctx, model_response, request_usage)

Expand Down
78 changes: 58 additions & 20 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from types import FrameType
from typing import Any, Callable, Generic, cast, final, overload

from opentelemetry.trace import NoOpTracer, use_span
from opentelemetry.trace import NoOpTracer, Span, Tracer, use_span
from typing_extensions import TypeGuard, TypeVar, deprecated

from pydantic_graph import End, Graph, GraphRun, GraphRunContext
Expand Down Expand Up @@ -116,6 +116,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
"""Automatically instrument with OpenTelemetry. Will use Logfire if it's configured."""

_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
_reflect_on_tool_call: bool = dataclasses.field(repr=False)
_result_tool_name: str = dataclasses.field(repr=False)
_result_tool_description: str | None = dataclasses.field(repr=False)
_result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
defer_model_check: bool = False,
end_strategy: EndStrategy = 'early',
instrument: bool = False,
reflect_on_tool_call: bool = True,
):
"""Create an agent.

Expand Down Expand Up @@ -178,6 +180,9 @@ def __init__(
end_strategy: Strategy for handling tool calls that are requested alongside a final result.
See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
instrument: Automatically instrument with OpenTelemetry. Will use Logfire if it's configured.
reflect_on_tool_call: Whether to generate a final response with an LLM call after tool calls.
When set to `False`, the agent will return tool call results directly without generating
a final reflection response. Defaults to `True`.
"""
if model is None or defer_model_check:
self.model = model
Expand Down Expand Up @@ -207,6 +212,7 @@ def __init__(

self._default_retries = retries
self._max_result_retries = result_retries if result_retries is not None else retries
self._reflect_on_tool_call = reflect_on_tool_call
for tool in tools:
if isinstance(tool, Tool):
self._register_tool(tool)
Expand Down Expand Up @@ -314,7 +320,7 @@ async def iter(
result_type: type[RunResultDataT] | None = None,
message_history: list[_messages.ModelMessage] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDepsT = None,
deps: AgentDepsT | None = None,
model_settings: ModelSettings | None = None,
usage_limits: _usage.UsageLimits | None = None,
usage: _usage.Usage | None = None,
Expand Down Expand Up @@ -434,22 +440,17 @@ async def main():
'logfire.msg': f'{agent_name} run',
},
)

graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
user_deps=deps,
prompt=user_prompt,
new_message_index=new_message_index,
model=model_used,
model_settings=model_settings,
usage_limits=usage_limits,
max_result_retries=self._max_result_retries,
end_strategy=self.end_strategy,
result_schema=result_schema,
result_tools=self._result_schema.tool_defs() if self._result_schema else [],
result_validators=result_validators,
function_tools=self._function_tools,
run_span=run_span,
tracer=tracer,
graph_deps = self._build_graph_deps(
deps,
user_prompt,
new_message_index,
model_used,
model_settings,
usage_limits,
result_schema,
result_validators,
run_span,
tracer,
)
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
Expand Down Expand Up @@ -1124,17 +1125,54 @@ def _get_model(self, model: models.Model | models.KnownModelName | None) -> mode

return model_

def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
def _get_deps(self: Agent[T, ResultDataT], deps: T | None) -> T:
    """Get deps for a run.

    If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.

    We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.

    Args:
        deps: Dependencies to use for this run, can be None.

    Returns:
        The dependencies to use for the run.
    """
    if some_deps := self._override_deps:
        return some_deps.value
    else:
        # `deps` may legitimately be None (e.g. when AgentDepsT is None); the ignore
        # silences the `T | None` -> `T` narrowing complaint from the type checker.
        return deps  # type: ignore

def _build_graph_deps(
    self,
    deps: AgentDepsT,
    user_prompt: str | Sequence[_messages.UserContent],
    new_message_index: int,
    model_used: models.Model,
    model_settings: ModelSettings | None,
    usage_limits: _usage.UsageLimits,
    result_schema: _result.ResultSchema[RunResultDataT] | None,
    result_validators: list[_result.ResultValidator[AgentDepsT, RunResultDataT]],
    run_span: Span,
    tracer: Tracer,
) -> _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]:
    """Assemble the `GraphAgentDeps` for a run from per-run arguments plus agent configuration."""
    # Result tools come from the agent's configured result schema, when there is one.
    result_tools = self._result_schema.tool_defs() if self._result_schema else []
    deps_cls = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]
    return deps_cls(
        user_deps=deps,
        prompt=user_prompt,
        new_message_index=new_message_index,
        model=model_used,
        model_settings=model_settings,
        usage_limits=usage_limits,
        max_result_retries=self._max_result_retries,
        end_strategy=self.end_strategy,
        result_schema=result_schema,
        result_tools=result_tools,
        result_validators=result_validators,
        function_tools=self._function_tools,
        reflect_on_tool_call=self._reflect_on_tool_call,
        run_span=run_span,
        tracer=tracer,
    )

def _infer_name(self, function_frame: FrameType | None) -> None:
"""Infer the agent name from the call frame.
Expand Down
74 changes: 74 additions & 0 deletions tests/test_build_graph_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Tests for the _build_graph_deps method."""

from typing import Any, cast

import pytest
from opentelemetry.trace import NoOpTracer, Span

from pydantic_ai import Agent
from pydantic_ai._agent_graph import GraphAgentDeps
from pydantic_ai.models.test import TestModel
from pydantic_ai.usage import UsageLimits

pytestmark = pytest.mark.anyio


class TestBuildGraphDeps:
    """Tests for the _build_graph_deps method that was added in the commit."""

    @staticmethod
    def _build(agent: Agent) -> GraphAgentDeps:
        """Invoke ``agent._build_graph_deps`` with minimal stand-in arguments."""
        no_validators: list[Any] = []
        # A bare object cast to Span: the method only stores it, never calls it.
        fake_span = cast(Span, object())
        return agent._build_graph_deps(  # pyright: ignore[reportPrivateUsage]
            None,  # deps
            'test prompt',
            0,  # new_message_index
            TestModel(),  # model_used
            None,  # model_settings
            UsageLimits(),
            None,  # result_schema
            no_validators,
            fake_span,
            NoOpTracer(),
        )

    async def test_reflect_on_tool_call_passed_to_graph_deps(self) -> None:
        """Test that reflect_on_tool_call is passed correctly to GraphAgentDeps."""
        # Default (True) and explicit False must both round-trip into the deps object.
        cases = (
            (True, Agent(TestModel())),
            (False, Agent(TestModel(), reflect_on_tool_call=False)),
        )
        for expected_flag, agent in cases:
            graph_deps = self._build(agent)
            assert isinstance(graph_deps, GraphAgentDeps)
            assert graph_deps.reflect_on_tool_call is expected_flag
Loading