diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index cdcbdd819..2abd9ca5a 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -1,15 +1,10 @@ -from __future__ import annotations - import asyncio import inspect import json from copy import copy from typing import ( - TYPE_CHECKING, Any, Callable, - Dict, - List, Literal, Optional, Sequence, @@ -35,22 +30,20 @@ from langchain_core.tools import BaseTool, InjectedToolArg from langchain_core.tools import tool as create_tool from langchain_core.tools.base import get_all_basemodel_annotations +from pydantic import BaseModel from typing_extensions import Annotated, get_args, get_origin from langgraph.errors import GraphInterrupt from langgraph.store.base import BaseStore from langgraph.utils.runnable import RunnableCallable -if TYPE_CHECKING: - from pydantic import BaseModel - INVALID_TOOL_NAME_ERROR_TEMPLATE = ( "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]." ) TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -def msg_content_output(output: Any) -> str | List[dict]: +def msg_content_output(output: Any) -> Union[str, list[dict]]: recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output @@ -95,7 +88,7 @@ def _handle_tool_error( return content -def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]: +def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]: sig = inspect.signature(handler) params = list(sig.parameters.values()) if params: @@ -194,9 +187,9 @@ def __init__( messages_key: str = "messages", ) -> None: super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) - self.tools_by_name: Dict[str, BaseTool] = {} - self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {} - self.tool_to_store_arg: Dict[str, Optional[str]] = {} + self.tools_by_name: dict[str, BaseTool] = {} + self.tool_to_state_args: dict[str, dict[str, Optional[str]]] = {} + self.tool_to_store_arg: dict[str, Optional[str]] = {} self.handle_tool_errors = handle_tool_errors self.messages_key = messages_key for tool_ in tools: @@ -346,7 +339,7 @@ def _parse_input( BaseModel, ], store: BaseStore, - ) -> Tuple[List[ToolCall], Literal["list", "dict"]]: + ) -> Tuple[list[ToolCall], Literal["list", "dict"]]: if isinstance(input, list): output_type = "list" message: AnyMessage = input[-1] @@ -656,9 +649,9 @@ def _is_injection( return False -def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]: +def _get_state_args(tool: BaseTool) -> dict[str, Optional[str]]: full_schema = tool.get_input_schema() - tool_args_to_state_fields: Dict = {} + tool_args_to_state_fields: dict = {} for name, type_ in get_all_basemodel_annotations(full_schema).items(): injections = [