Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 5, 2024
1 parent 2d6ddd0 commit 8ef82f3
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Expand Down Expand Up @@ -45,7 +43,7 @@
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


def msg_content_output(output: Any) -> str | List[dict]:
def msg_content_output(output: Any) -> Union[str, list[dict]]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
return output
Expand Down Expand Up @@ -90,7 +88,7 @@ def _handle_tool_error(
return content


def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
if params:
Expand Down Expand Up @@ -189,9 +187,9 @@ def __init__(
messages_key: str = "messages",
) -> None:
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
self.tools_by_name: Dict[str, BaseTool] = {}
self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {}
self.tool_to_store_arg: Dict[str, Optional[str]] = {}
self.tools_by_name: dict[str, BaseTool] = {}
self.tool_to_state_args: dict[str, dict[str, Optional[str]]] = {}
self.tool_to_store_arg: dict[str, Optional[str]] = {}
self.handle_tool_errors = handle_tool_errors
self.messages_key = messages_key
for tool_ in tools:
Expand Down Expand Up @@ -341,7 +339,7 @@ def _parse_input(
BaseModel,
],
store: BaseStore,
) -> Tuple[List[ToolCall], Literal["list", "dict"]]:
) -> Tuple[list[ToolCall], Literal["list", "dict"]]:
if isinstance(input, list):
output_type = "list"
message: AnyMessage = input[-1]
Expand Down Expand Up @@ -651,9 +649,9 @@ def _is_injection(
return False


def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]:
def _get_state_args(tool: BaseTool) -> dict[str, Optional[str]]:
full_schema = tool.get_input_schema()
tool_args_to_state_fields: Dict = {}
tool_args_to_state_fields: dict = {}

for name, type_ in get_all_basemodel_annotations(full_schema).items():
injections = [
Expand Down

0 comments on commit 8ef82f3

Please sign in to comment.