-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdynamic-breakpoints.py
99 lines (65 loc) · 2.42 KB
/
dynamic-breakpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# ===============================================
# Dynamic Breakpoints in a LangGraph Graph (Human-in-the-Loop)
# ===============================================
# -----------------------------------------------
# Load environment variables
# -----------------------------------------------
# NOTE(review): nothing in this script reads environment variables directly;
# presumably kept so LangChain/LangGraph settings (e.g. tracing or API keys)
# stored in a local .env file get picked up -- confirm whether it is still needed.
from dotenv import load_dotenv
load_dotenv()
# -----------------------------------------------
# Graph with NodeInterrupt
# -----------------------------------------------
from langgraph.graph import StateGraph, START, END
from langgraph.errors import NodeInterrupt
from langgraph.checkpoint.memory import MemorySaver
from typing_extensions import TypedDict
from IPython.display import display, Image
class State(TypedDict):
    """Graph state shared by every node: a single text field named ``input``."""

    input: str
# Define Node functions
def step_1(state: State) -> State:
    """First node: announce itself and hand the state on untouched."""
    label = "---Step 1---"
    print(label)
    return state
# Longest input (in characters) step_2 accepts before pausing the graph.
MAX_INPUT_LEN = 5


def step_2(state: State) -> State:
    """Second node: raise a dynamic breakpoint when the input is too long.

    Raises NodeInterrupt when ``state['input']`` has more than MAX_INPUT_LEN
    characters. Otherwise it logs and returns the state unchanged. The
    interrupt message is built from the same constant as the check, so the
    threshold and the reported limit cannot drift apart.
    """
    text = state['input']
    if len(text) > MAX_INPUT_LEN:
        raise NodeInterrupt(
            f"Received input that is longer than {MAX_INPUT_LEN} characters: {text}"
        )
    print("---Step 2---")
    return state
def step_3(state: State) -> State:
    """Final node: announce itself and hand the state on untouched."""
    label = "---Step 3---"
    print(label)
    return state
# Create Graph: a linear pipeline START -> step_1 -> step_2 -> step_3 -> END.
builder = StateGraph(state_schema=State)

# Register each node under the same name as its function, in order.
for node_name, node_fn in (("step_1", step_1), ("step_2", step_2), ("step_3", step_3)):
    builder.add_node(node_name, node_fn)

# Wire consecutive entries of the pipeline together.
pipeline = (START, "step_1", "step_2", "step_3", END)
for src, dst in zip(pipeline, pipeline[1:]):
    builder.add_edge(src, dst)

# The checkpointer stores state per thread_id. The get_state/update_state and
# stream(None, ...) calls below rely on it to inspect and resume the paused run.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
# Show Graph
# draw_mermaid_png() may need network access (langchain-core's default
# renderer calls the mermaid.ink web API -- confirm for your version).
# Rendering the diagram is optional, so on failure fall back to printing the
# Mermaid source instead of aborting the interrupt/resume demo below.
# NOTE(review): outside a notebook/IPython session display() cannot show the
# PNG inline.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:  # broad on purpose: visualization only, reported below
    print(f"Could not render graph image ({exc!r}); Mermaid source:")
    print(graph.get_graph().draw_mermaid())
# -----------------------------------------------
# Give input and run the graph
# -----------------------------------------------
input_message = dict(input="Hello World!")  # 12 characters: exceeds step_2's limit
thread = dict(configurable=dict(thread_id="2"))

# Stream full state values; step_2 raises NodeInterrupt on this input.
for event in graph.stream(input_message, thread, stream_mode="values"):
    print(event)

# Inspect the checkpoint saved for this thread.
state = graph.get_state(thread)
print(state.next)   # node(s) still pending
print(state.tasks)  # pending task log; presumably carries the interrupt info
# -----------------------------------------------
# Resume the graph: update the state so step_2 no longer interrupts
# -----------------------------------------------
shorter = dict(input="Hello")  # 5 characters, so step_2's length check passes
graph.update_state(thread, shorter)
# None as input: continue the paused run on this thread instead of starting anew.
for event in graph.stream(None, thread, stream_mode="values"):
    print(event)
# -----------------------------------------------