-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathroute_in_graph.py
108 lines (72 loc) · 3.07 KB
/
route_in_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# ===============================================
# Route in Graph
# ===============================================
# 1. Add a node that will call our tool.
# 2. Add a conditional edge that will look at the chat model's output,
# and route to our tool calling node or simply end if no tool call is performed.
# -----------------------------------------------
# Let's load our environment
from dotenv import load_dotenv

# Pull settings from a local .env file (presumably the OpenAI API key that
# ChatOpenAI needs below — confirm). override=False is python-dotenv's default:
# variables already present in the environment win over the file.
load_dotenv(override=False)
# -----------------------------------------------
# Let's define a function to act as our tool
# -----------------------------------------------
# LangChain builds the tool schema from this signature and Google-style
# docstring. The blank line before "Args:" is required for it to recognize the
# Args section and attach per-argument descriptions.
def multiply(a: int, b: int) -> int:
    """Multiply a and b.

    Args:
        a: first int
        b: second int

    Returns:
        The product of a and b.
    """
    return a * b
# -----------------------------------------------
# Let's bind our tool to our chat model
# -----------------------------------------------
from langchain_openai import ChatOpenAI

# Base chat model used by the graph's LLM node.
llm = ChatOpenAI(model="gpt-4o-mini")
# Expose `multiply` to the model as a callable tool. The result is kept under a
# separate name (llm_with_tools), which is what the graph node invokes.
llm_with_tools = llm.bind_tools(tools=[multiply])
# -----------------------------------------------
# Let's build our graph
# -----------------------------------------------
from langgraph.graph import StateGraph, START, END
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# -----------------------------------------------
# Node function to call our tool
def tool_calling_llm(state: MessagesState):
return {"messages": [llm_with_tools.invoke(state["messages"])]}
# -----------------------------------------------
# Build graph
# State schema: MessagesState; the nodes below read and write its "messages" key.
builder = StateGraph(MessagesState)
# Node that calls the chat model (with `multiply` bound as a tool).
builder.add_node("tool_calling_llm", tool_calling_llm)
# Prebuilt ToolNode wrapping `multiply`; the conditional edge below routes here
# when the model's reply is a tool call.
builder.add_node("tools", ToolNode([multiply]))
# Every run starts at the model node.
builder.add_edge(START, "tool_calling_llm")
builder.add_conditional_edges(
"tool_calling_llm",
# If the latest message (result) from the assistant is a tool call -> tools_condition routes to "tools"
# If the latest message (result) from the assistant is not a tool call -> tools_condition routes to END
tools_condition,
)
# After the tool runs, the graph ends. There is no edge from "tools" back to
# "tool_calling_llm", so the tool result is not sent back to the model
# (a router, not an agent loop).
builder.add_edge("tools", END)
# -----------------------------------------------
# Compile graph
# The compiled graph is what the demos below call .invoke() on.
graph = builder.compile()
# -----------------------------------------------
# View
# Rendering is optional. It needs IPython, and draw_mermaid_png() may rely on an
# external rendering service (NOTE(review): confirm the draw method in use).
# If either is unavailable, print the Mermaid source instead of aborting the
# script before the demo invocations below.
try:
    from IPython.display import Image, display

    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:  # best-effort visualization only; report and continue
    print(f"Could not render graph image ({exc!r}); Mermaid source follows:")
    print(graph.get_graph().draw_mermaid())
# -----------------------------------------------
# Let's invoke the graph: once with small talk and once with a
# multiplication request, printing the full transcript each time.
# -----------------------------------------------
from langchain_core.messages import HumanMessage


def _invoke_and_print(prompt: str) -> dict:
    """Run the compiled graph on a single human message and pretty-print the result.

    Args:
        prompt: text of the human message sent to the graph.

    Returns:
        The final state returned by ``graph.invoke``; its "messages" entry is
        the conversation that was printed.
    """
    # The input is given as a list of messages, the shape MessagesState's
    # "messages" key holds.
    state = graph.invoke({"messages": [HumanMessage(content=prompt)]})
    for msg in state["messages"]:
        msg.pretty_print()
    return state


# Small talk: the model is expected to answer directly (routed to END).
# The module-level name `messages` is kept for compatibility with the original
# script; it holds the whole final state, not just the message list.
messages = _invoke_and_print("Hello! How are you?")
# -----------------------------------------------
# Math request: the model is expected to call `multiply` (routed to "tools").
# -----------------------------------------------
messages = _invoke_and_print("What is 4 multiplied by 3")
# -----------------------------------------------