Skip to content

Commit

Permalink
test: add #1331 test case
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaian10 committed Aug 16, 2024
1 parent a9a76f3 commit 27f622b
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion libs/langgraph/tests/test_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import itertools
from typing import Annotated as Annotated2
from typing import Any

import pytest
from langchain_core.runnables import RunnableConfig
from pydantic.v1 import BaseModel
from typing_extensions import Annotated, TypedDict

from langgraph.graph.state import _warn_invalid_state_schema
from langgraph.graph.state import StateGraph, _warn_invalid_state_schema


class State(BaseModel):
Expand Down Expand Up @@ -46,3 +48,42 @@ def test_doesnt_warn_valid_schema(schema: Any):
# Assert the function does not raise a warning
with pytest.warns(None):
_warn_invalid_state_schema(schema)


def test_state_schema_with_type_hint():
class InputState(TypedDict):
question: str

class OutputState(TypedDict):
input_state: InputState

def complete_hint(state: InputState) -> OutputState:
return {"input_state": state}

def miss_first_hint(state, config: RunnableConfig) -> OutputState:
return {"input_state": state}

def only_return_hint(state, config) -> OutputState:
return {"input_state": state}

def miss_all_hint(state, config):
return {"input_state": state}

graph = StateGraph(input=InputState, output=OutputState)
actions = [complete_hint, miss_first_hint, only_return_hint, miss_all_hint]

for action in actions:
graph.add_node(action)

graph.set_entry_point(actions[0].__name__)
for a, b in itertools.pairwise(actions):
graph.add_edge(a.__name__, b.__name__)
graph.set_finish_point(actions[-1].__name__)

graph = graph.compile()

input_state = InputState(question="Hello World!")
output_state = OutputState(input_state=input_state)
for i, c in enumerate(graph.stream(input_state, stream_mode="updates")):
node_name = actions[i].__name__
assert c[node_name] == output_state

0 comments on commit 27f622b

Please sign in to comment.