diff --git a/libs/langgraph/tests/test_state.py b/libs/langgraph/tests/test_state.py index ecf35eb2aa..b69c6ddfa9 100644 --- a/libs/langgraph/tests/test_state.py +++ b/libs/langgraph/tests/test_state.py @@ -1,11 +1,13 @@ +import itertools from typing import Annotated as Annotated2 from typing import Any import pytest +from langchain_core.runnables import RunnableConfig from pydantic.v1 import BaseModel from typing_extensions import Annotated, TypedDict -from langgraph.graph.state import _warn_invalid_state_schema +from langgraph.graph.state import StateGraph, _warn_invalid_state_schema class State(BaseModel): @@ -46,3 +48,42 @@ def test_doesnt_warn_valid_schema(schema: Any): # Assert the function does not raise a warning with pytest.warns(None): _warn_invalid_state_schema(schema) + + +def test_state_schema_with_type_hint(): + class InputState(TypedDict): + question: str + + class OutputState(TypedDict): + input_state: InputState + + def complete_hint(state: InputState) -> OutputState: + return {"input_state": state} + + def miss_first_hint(state, config: RunnableConfig) -> OutputState: + return {"input_state": state} + + def only_return_hint(state, config) -> OutputState: + return {"input_state": state} + + def miss_all_hint(state, config): + return {"input_state": state} + + graph = StateGraph(input=InputState, output=OutputState) + actions = [complete_hint, miss_first_hint, only_return_hint, miss_all_hint] + + for action in actions: + graph.add_node(action) + + graph.set_entry_point(actions[0].__name__) + for a, b in itertools.pairwise(actions): + graph.add_edge(a.__name__, b.__name__) + graph.set_finish_point(actions[-1].__name__) + + graph = graph.compile() + + input_state = InputState(question="Hello World!") + output_state = OutputState(input_state=input_state) + for i, c in enumerate(graph.stream(input_state, stream_mode="updates")): + node_name = actions[i].__name__ + assert c[node_name] == output_state