
Commit

Merge pull request #2502 from langchain-ai/vb/fix-annotation
langgraph: fix issue w/ type annotations in tools_condition
nfcampos authored Dec 5, 2024
2 parents 73e3f5a + 8ef82f3 commit 759a712
Showing 1 changed file with 9 additions and 16 deletions.
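
The diff swaps the PEP 604 union `str | List[dict]` and the `typing.Dict`/`typing.List` aliases for `Union[str, list[dict]]` and the built-in `dict`/`list` generics, and widens `tuple[type[Exception]]` to `tuple[type[Exception], ...]`. A plausible reading of the annotation issue, assuming these annotations get resolved at runtime (for example by `typing.get_type_hints` or a pydantic schema build) on Python 3.9, is sketched below; `old_style` and `new_style` are illustrative names, not code from the repository.

# Illustrative sketch only, not part of this commit: on Python 3.9 a PEP 604
# union cannot be evaluated at runtime, so resolving a postponed annotation that
# uses it fails, while the Union[...] spelling and built-in generics such as
# list[dict] resolve fine.
from __future__ import annotations

from typing import Union, get_type_hints


def old_style(output) -> str | list[dict]:
    ...


def new_style(output) -> Union[str, list[dict]]:
    ...


print(get_type_hints(new_style))  # {'return': typing.Union[str, list[dict]]}
# get_type_hints(old_style) on Python 3.9 raises
# TypeError: unsupported operand type(s) for |
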
25 changes: 9 additions & 16 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
@@ -1,15 +1,10 @@
from __future__ import annotations

import asyncio
import inspect
import json
from copy import copy
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
@@ -35,22 +30,20 @@
from langchain_core.tools import BaseTool, InjectedToolArg
from langchain_core.tools import tool as create_tool
from langchain_core.tools.base import get_all_basemodel_annotations
from pydantic import BaseModel
from typing_extensions import Annotated, get_args, get_origin

from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.utils.runnable import RunnableCallable

if TYPE_CHECKING:
    from pydantic import BaseModel

INVALID_TOOL_NAME_ERROR_TEMPLATE = (
"Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
)
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


def msg_content_output(output: Any) -> str | List[dict]:
def msg_content_output(output: Any) -> Union[str, list[dict]]:
    recognized_content_block_types = ("image", "image_url", "text", "json")
    if isinstance(output, str):
        return output
@@ -95,7 +88,7 @@ def _handle_tool_error(
    return content


def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
    sig = inspect.signature(handler)
    params = list(sig.parameters.values())
    if params:
@@ -194,9 +187,9 @@ def __init__(
        messages_key: str = "messages",
    ) -> None:
        super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
        self.tools_by_name: Dict[str, BaseTool] = {}
        self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {}
        self.tool_to_store_arg: Dict[str, Optional[str]] = {}
        self.tools_by_name: dict[str, BaseTool] = {}
        self.tool_to_state_args: dict[str, dict[str, Optional[str]]] = {}
        self.tool_to_store_arg: dict[str, Optional[str]] = {}
        self.handle_tool_errors = handle_tool_errors
        self.messages_key = messages_key
        for tool_ in tools:
@@ -346,7 +339,7 @@ def _parse_input(
            BaseModel,
        ],
        store: BaseStore,
    ) -> Tuple[List[ToolCall], Literal["list", "dict"]]:
    ) -> Tuple[list[ToolCall], Literal["list", "dict"]]:
        if isinstance(input, list):
            output_type = "list"
            message: AnyMessage = input[-1]
@@ -656,9 +649,9 @@ def _is_injection(
    return False


def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]:
def _get_state_args(tool: BaseTool) -> dict[str, Optional[str]]:
    full_schema = tool.get_input_schema()
    tool_args_to_state_fields: Dict = {}
    tool_args_to_state_fields: dict = {}

    for name, type_ in get_all_basemodel_annotations(full_schema).items():
        injections = [

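The `_infer_handled_types` change above is also a correctness fix: `tuple[type[Exception]]` annotates a tuple of exactly one element, while a handler can declare several exception types, so `tuple[type[Exception], ...]` (arbitrary length) is the accurate return type. The helper below is a rough, assumption-laden approximation of that inference; names and logic are illustrative, not the library's implementation.

# Sketch only: infer which exception types a callback handles from the
# annotation of its first parameter, mirroring the idea (not the exact code)
# behind _infer_handled_types.
import inspect
from typing import Union, get_args, get_origin


def infer_handled_types(handler) -> tuple[type[Exception], ...]:
    sig = inspect.signature(handler)
    params = list(sig.parameters.values())
    if not params:
        return (Exception,)
    annotation = params[0].annotation
    if annotation is inspect.Parameter.empty:
        return (Exception,)
    if get_origin(annotation) is Union:
        return tuple(get_args(annotation))  # e.g. (ValueError, KeyError)
    return (annotation,)


def on_error(e: Union[ValueError, KeyError]) -> str:
    return repr(e)


print(infer_handled_types(on_error))  # (<class 'ValueError'>, <class 'KeyError'>)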