From 8fc326a7b8a8b66baefed25a1a948e909bfcbd20 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Tue, 22 Apr 2025 13:35:14 -0400 Subject: [PATCH] Make input/new items available in the run context --- src/agents/result.py | 5 +++++ src/agents/run.py | 12 +++++++++++- src/agents/run_context.py | 22 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/agents/result.py b/src/agents/result.py index 0d8372c8..f3e26ca6 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -8,6 +8,8 @@ from typing_extensions import TypeVar +from agents.run_context import RunContextWrapper + from ._run_impl import QueueCompleteSentinel from .agent import Agent from .agent_output import AgentOutputSchemaBase @@ -50,6 +52,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + context_wrapper: RunContextWrapper[Any] + """The context wrapper that was used to run the agent.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: diff --git a/src/agents/run.py b/src/agents/run.py index 2af558d5..8f90af5a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -170,6 +170,7 @@ async def run( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context, # type: ignore + _input=copy.deepcopy(input), ) input_guardrail_results: list[InputGuardrailResult] = [] @@ -255,6 +256,9 @@ async def run( original_input = turn_result.original_input generated_items = turn_result.generated_items + context_wrapper._input = copy.deepcopy(original_input) + context_wrapper._new_items = copy.deepcopy(generated_items) + if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await cls._run_output_guardrails( current_agent.output_guardrails + (run_config.output_guardrails or []), @@ -270,6 +274,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -407,7 +412,8 @@ def run_streamed( output_schema = cls._get_output_schema(starting_agent) context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context # type: ignore + context=context, # type: ignore + _input=copy.deepcopy(input), ) streamed_result = RunResultStreaming( @@ -423,6 +429,7 @@ def run_streamed( output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, + context_wrapper=context_wrapper, ) # Kick off the actual agent loop in the background and return the streamed result object. @@ -576,6 +583,9 @@ async def _run_streamed_impl( streamed_result.input = turn_result.original_input streamed_result.new_items = turn_result.generated_items + context_wrapper._input = copy.deepcopy(streamed_result.input) + context_wrapper._new_items = copy.deepcopy(streamed_result.new_items) + if isinstance(turn_result.next_step, NextStepHandoff): current_agent = turn_result.next_step.new_agent current_span.finish(reset_current=True) diff --git a/src/agents/run_context.py b/src/agents/run_context.py index 579a215f..31c391fa 100644 --- a/src/agents/run_context.py +++ b/src/agents/run_context.py @@ -1,8 +1,12 @@ +from __future__ import annotations + +import copy from dataclasses import dataclass, field from typing import Any, Generic from typing_extensions import TypeVar +from .items import RunItem, TResponseInputItem from .usage import Usage TContext = TypeVar("TContext", default=Any) @@ -24,3 +28,21 @@ class RunContextWrapper(Generic[TContext]): """The usage of the agent run so far. For streamed responses, the usage will be stale until the last chunk of the stream is processed. """ + + _new_items: list[RunItem] = field(default_factory=list, repr=False) + """The new items created during the agent run.""" + + _input: str | list[TResponseInputItem] = field(default_factory=list, repr=False) + """The original input that you passed to `Runner.run()`.""" + + @property + def input(self) -> str | list[TResponseInputItem]: + """The original input that you passed to `Runner.run()`.""" + + return copy.deepcopy(self._input) + + @property + def new_items(self) -> list[RunItem]: + """The new items created during the agent run.""" + + return copy.deepcopy(self._new_items)