From 6da5a051937a7dacd51b7adc3e763d856b501395 Mon Sep 17 00:00:00 2001 From: Chase Walden Date: Wed, 30 Apr 2025 13:29:04 -0600 Subject: [PATCH] update Agent.is_*_node(...) to use TypeIs --- pydantic_ai_slim/pydantic_ai/agent.py | 30 ++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d9eb1000e..7b0462785 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -11,7 +11,7 @@ from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, TypeGuard, TypeVar, deprecated +from typing_extensions import Literal, Never, TypeIs, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -444,7 +444,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -453,7 +453,23 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None, + *, + output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') @@ -1561,7 +1577,7 @@ def _prepare_output_schema( @staticmethod def is_model_request_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]: + ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: """Check if the node is a `ModelRequestNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1571,7 +1587,7 @@ def is_model_request_node( @staticmethod def is_call_tools_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]: + ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: """Check if the node is a `CallToolsNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1581,7 +1597,7 @@ def is_call_tools_node( @staticmethod def is_user_prompt_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]: + ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: """Check if the node is a `UserPromptNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1591,7 +1607,7 @@ def is_user_prompt_node( @staticmethod def is_end_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[End[result.FinalResult[S]]]: + ) -> TypeIs[End[result.FinalResult[S]]]: """Check if the node is a `End`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.