# ===============================================
# Breakpoints in Graph - Human in loop
# ===============================================
# -----------------------------------------------
# Load environment variables
# -----------------------------------------------
from dotenv import load_dotenv
load_dotenv()
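# load_dotenv() reads environment variables from a local .env file
# (presumably OPENAI_API_KEY for the ChatOpenAI model used below).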
# -----------------------------------------------
# Tools - functions
# -----------------------------------------------
from langchain_openai import ChatOpenAI
def multiply(a: int, b: int) -> int:
    """
    Multiply a and b.

    Args:
        a: first int
        b: second int
    """
    return a * b

def add(a: int, b: int) -> int:
    """
    Add a and b.

    Args:
        a: first int
        b: second int
    """
    return a + b

def divide(a: int, b: int) -> float:
    """
    Divide a by b.

    Args:
        a: first int
        b: second int
    """
    return a / b
tools = [add, multiply, divide]
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools)
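
# Optional sanity check (a sketch, not part of the original flow): invoking
# the tool-bound model directly returns an AIMessage whose `tool_calls` list
# holds the structured call that the graph's tools node would later execute.
# Uncomment to try it (this makes one extra model call):
# print(llm_with_tools.invoke("Multiply 4 by 3").tool_calls)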
# -----------------------------------------------
# Define a Graph
# -----------------------------------------------
from langgraph.graph import StateGraph, START, END
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from IPython.display import Image, display
# System Message
sys_message = SystemMessage(content="You are a helpful assistant tasked with performing arithmetic on a set of inputs.")
# Node
def assistant(state: MessagesState):
    return {"messages": [llm_with_tools.invoke([sys_message] + state["messages"])]}
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
memory = MemorySaver()
graph = builder.compile(checkpointer=memory, interrupt_before=["tools"])
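# Note: `interrupt_before=["tools"]` pauses the run just before the tools node
# executes, and the MemorySaver checkpointer persists the thread's state so the
# run can be resumed from that breakpoint later (by streaming `None` with the
# same thread config, as done below).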
# Show graph
display(Image(graph.get_graph().draw_mermaid_png()))
# -----------------------------------------------
# Give input and run the graph
# -----------------------------------------------
input_message = {"messages": [HumanMessage(content="Multiply 4 by 3")]}
thread = {"configurable": {"thread_id": "2"}}
for event in graph.stream(input_message, thread, stream_mode="values"):
    event['messages'][-1].pretty_print()
# Get the state and the next node to run
state = graph.get_state(thread)
print(state.next)  # typically ('tools',), since the run paused before the tools node
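
# A sketch of inspecting what is about to run: the last checkpointed message
# is the assistant's AIMessage, and its standard `tool_calls` attribute shows
# the pending call the user is being asked to approve.
print(state.values["messages"][-1].tool_calls)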
# User feedback
user_approval = input("Do you want to call a tool? (Yes/No): ")
# Continue from the current state
if user_approval == "Yes":
    for event in graph.stream(None, thread, stream_mode="values"):
        event['messages'][-1].pretty_print()
else:
    print("Operation cancelled by the user")
# -----------------------------------------------