From 8ef82f3578bc621f63e93bed45a819d8a3e428e5 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Dec 2024 17:40:27 -0800 Subject: [PATCH] Update --- libs/langgraph/langgraph/prebuilt/tool_node.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 3c43bfba8..2abd9ca5a 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -5,8 +5,6 @@ from typing import ( Any, Callable, - Dict, - List, Literal, Optional, Sequence, @@ -45,7 +43,7 @@ TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -def msg_content_output(output: Any) -> str | List[dict]: +def msg_content_output(output: Any) -> Union[str, list[dict]]: recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output @@ -90,7 +88,7 @@ def _handle_tool_error( return content -def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]: +def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]: sig = inspect.signature(handler) params = list(sig.parameters.values()) if params: @@ -189,9 +187,9 @@ def __init__( messages_key: str = "messages", ) -> None: super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) - self.tools_by_name: Dict[str, BaseTool] = {} - self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {} - self.tool_to_store_arg: Dict[str, Optional[str]] = {} + self.tools_by_name: dict[str, BaseTool] = {} + self.tool_to_state_args: dict[str, dict[str, Optional[str]]] = {} + self.tool_to_store_arg: dict[str, Optional[str]] = {} self.handle_tool_errors = handle_tool_errors self.messages_key = messages_key for tool_ in tools: @@ -341,7 +339,7 @@ def _parse_input( BaseModel, ], store: BaseStore, - ) -> Tuple[List[ToolCall], Literal["list", "dict"]]: + ) -> Tuple[list[ToolCall], Literal["list", "dict"]]: if isinstance(input, list): output_type = "list" message: AnyMessage = input[-1] @@ -651,9 +649,9 @@ def _is_injection( return False -def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]: +def _get_state_args(tool: BaseTool) -> dict[str, Optional[str]]: full_schema = tool.get_input_schema() - tool_args_to_state_fields: Dict = {} + tool_args_to_state_fields: dict = {} for name, type_ in get_all_basemodel_annotations(full_schema).items(): injections = [