Skip to content

Change Agent.is_?_node() to use TypeIs instead #1622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

chasewalden
Copy link

@chasewalden chasewalden commented Apr 30, 2025

A small change to improve some ergonomics of the Agent.is_*_node() utility functions.

This stems from wanting to do something if it is an AgentNode (but not necessarily one of the concrete node types
Since there is no Agent.is_agent_node(), it is not as easy to check the inverse case (i.e. when using Agent.iter). The only way to this currently is to use isinstance(node, End), which will narrow for the inverse case, but requires importing End from pydantic_graph.

As an example:

from typing_extensions import ParamSpec, TypeVar, Callable, TypeGuard, TypeIs, cast, reveal_type
from pydantic_ai import Agent

P = ParamSpec("P")
R = TypeVar("R")

def rewrite_guard(input: Callable[P, TypeGuard[R]]) -> Callable[P, TypeIs[R]]:
    return cast(Callable[P, TypeIs[R]], input)

better_is_end_node = rewrite_guard(Agent.is_end_node)

my_agent = Agent(...)

async def do_thing():
    async with my_agent.iter(...) as run:
        async for node in run:
            # Type check error, even though its actually sound
            print(node.get_node_id() if not Agent.is_end_node(node) else "END")
            #          ~~~~~~~~~~~
            # > Pyright: Cannot access attribute "get_node_id" for class "End[FinalResult[Any]]"
            # >     Attribute "get_node_id" is unknown

            # Also a type check error
            print("END" if Agent.is_end_node(node) else node.get_node_id())
            #                                                ~~~~~~~~~~~
            # > Pyright: Cannot access attribute "get_node_id" for class "End[FinalResult[Any]]"
            # >     Attribute "get_node_id" is unknown

            # This produces no type check error
            print("END" if isinstance(node, End) else node.get_node_id())

            ### Checking with `reveal_type(...)`

            if Agent.is_end_node(node):
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"
            else:
                reveal_type(node) # > Type of "node" is "AgentNode[None, Any] | End[FinalResult[Any]]"
    
            # ... versus ...
    
            if isinstance(node, End):
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"
            else:
                reveal_type(node) # > Type of "node" is "AgentNode[None, Any]"
    
            ### Also the inverse case
    
            if not Agent.is_end_node(node):
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]] | AgentNode[None, Any]"
            else:
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"
    
            # ... versus ...
    
            if not isinstance(node, End):
                reveal_type(node) # > Type of "node" is "AgentNode[None, Any]"
            else:
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"
    
            ### Using the "better" is_end_node (matches the behavior of `isinstance(node, End)`)
    
            if better_is_end_node(node):
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"
            else:
                reveal_type(node) # > Type of "node" is "AgentNode[None, Any]"
            
            # also works on the inverse
    
            if not better_is_end_node(node):
                reveal_type(node) # > Type of "node" is "AgentNode[None, Any]"
            else:
                reveal_type(node) # > Type of "node" is "End[FinalResult[Any]]"

It looks like changing the signature of the other functions is inconsequential, but I guess it's consistent...
I imagine this is expected behavior anyways.

better_is_user_prompt_node = rewrite_guard(Agent.is_user_prompt_node)
better_is_model_request_node = rewrite_guard(Agent.is_model_request_node)
better_is_call_tools_node = rewrite_guard(Agent.is_call_tools_node)

if better_is_user_prompt_node(node):
    reveal_type(node) # > Type of "node" is "UserPromptNode[None, Any]
elif better_is_model_request_node(node):
    reveal_type(node) # > Type of "node" is "ModelRequestNode[None, Any]"
elif better_is_call_tools_node(node):
    reveal_type(node) # > Type of "node" is "CallToolsNode[None, Any]"
else:
    reveal_type(node) # > Type of "node" is "End[FinalResult[Any]] | AgentNode[None, Any]"

@Kludex Kludex requested a review from dmontagu April 30, 2025 21:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants