# ===============================================
# Memory Collection Schema in LangGraph
# ===============================================
# -----------------------------------------------
# Load environment variables
# -----------------------------------------------
from dotenv import load_dotenv
load_dotenv()
# -----------------------------------------------
# Defining a Collection Schema
# -----------------------------------------------
from pydantic import BaseModel, Field
class Memory(BaseModel):
    content: str = Field(description="The main content of the memory. For example: User expressed interest in learning about French.")

class MemoryCollection(BaseModel):
    memories: list[Memory] = Field(description="A collection of memories.")
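# Quick sanity check (illustrative, not part of the original flow):
# Pydantic validates the nested structure and rejects malformed payloads.
example = MemoryCollection(memories=[Memory(content="User likes to drive.")])
print(example.model_dump())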
# -----------------------------------------------
# ChatModel with Structured Output
# -----------------------------------------------
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
# Initialize the model
model = ChatOpenAI(model="gpt-4o", temperature=0)
# Bind schema to model
model_with_structure = model.with_structured_output(MemoryCollection)
# Invoke the model to produce structured output that matches the schema
memory_collection = model_with_structure.invoke([HumanMessage("My name is Sushant. I like to drive.")])
print(memory_collection.memories)
# -------------------------------------
# Save dictionary representation of each memory to the store.
import uuid
from langgraph.store.memory import InMemoryStore
# Initialize the in-memory store
in_memory_store = InMemoryStore()
# Namespace for the memory to save
user_id = "1"
namespace_for_memory = (user_id, "memories")
# Save a memory to namespace as key and value
key = str(uuid.uuid4())
value = memory_collection.memories[0].model_dump()
in_memory_store.put(namespace_for_memory, key, value)
key = str(uuid.uuid4())
value = memory_collection.memories[1].model_dump()
in_memory_store.put(namespace_for_memory, key, value)
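# Read one memory back by key to confirm the round trip
# (store.get returns an Item whose .value is the dict we stored)
item = in_memory_store.get(namespace_for_memory, key)
print(item.value)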
# # Search
# for m in in_memory_store.search(namespace_for_memory):
#     print(m.dict())
# -----------------------------------------------
# Updating Collection Schema
# -----------------------------------------------
from trustcall import create_extractor
# Create the extractor
trustcall_extractor = create_extractor(
    model,
    tools=[Memory],
    tool_choice="Memory",
    enable_inserts=True,
)
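# enable_inserts=True lets the extractor create brand-new Memory entries
# in addition to patching the ones passed in via "existing"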
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
# Instruction
instruction = """Extract memories from the following conversation:"""
# Conversation
conversation = [HumanMessage(content="Hi, I'm Sushant."),
                AIMessage(content="Nice to meet you, Sushant."),
                HumanMessage(content="This morning I had a nice drive from Pune to Mumbai.")]
# Invoke the extractor
result = trustcall_extractor.invoke({"messages": [SystemMessage(content=instruction)] + conversation})
# # Messages contain the tool calls
# for m in result["messages"]:
#     m.pretty_print()
# # Responses contain the memories that adhere to the schema
# for m in result["responses"]:
#     print(m)
# # Metadata contains the tool call
# for m in result["response_metadata"]:
#     print(m)
# -----------------------------------------------
# Update the conversation
updated_conversation = [AIMessage(content="That's great, what did you do after?"),
                        HumanMessage(content="I went to Marzorin and ate a croissant."),
                        AIMessage(content="What else is on your mind?"),
                        HumanMessage(content="I was thinking about my trip to New Zealand, and going back this winter!")]
# Update the instruction
system_msg = """Update existing memories and create new ones based on the following conversation:"""
# We'll save existing memories, giving them an ID, key (tool name), and value
tool_name = "Memory"
existing_memories = ([(str(i), tool_name, memory.model_dump())
                      for i, memory in enumerate(result["responses"])]
                     if result["responses"]
                     else None)
print(existing_memories)
# Invoke the extractor with our updated conversation and existing memories
result = trustcall_extractor.invoke({"messages": updated_conversation,
                                     "existing": existing_memories})
# # Messages from the model indicate two tool calls were made
# for m in result["messages"]:
#     m.pretty_print()
# # Responses contain the memories that adhere to the schema
# for m in result["responses"]:
#     print(m)
# # Metadata contains the tool call
# for m in result["response_metadata"]:
#     print(m)
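# Updated memories carry the original doc ID under "json_doc_id" in their
# response metadata, while brand-new inserts do not; that is how updates
# get matched back to existing entries (illustrative):
for r, rmeta in zip(result["responses"], result["response_metadata"]):
    print(rmeta.get("json_doc_id", "new"), "->", r.content)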
# -----------------------------------------------
# Chatbot with collection schema updating
# -----------------------------------------------
from IPython.display import Image, display
import uuid
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.base import BaseStore
# Initialize the model
model = ChatOpenAI(model="gpt-4o", temperature=0)
# Memory schema
class Memory(BaseModel):
    content: str = Field(description="The main content of the memory. For example: User expressed interest in learning about French.")
# Create the Trustcall extractor
trustcall_extractor = create_extractor(
    model,
    tools=[Memory],
    tool_choice="Memory",
    enable_inserts=True,
)
# Chatbot instruction
MODEL_SYSTEM_MESSAGE = """
You are a helpful chatbot. You are designed to be a companion to a user.
You have a long term memory which keeps track of information you learn about the user over time.
Current Memory (may include updated memories from this conversation): {memory}
"""
# Trustcall instruction
TRUSTCALL_INSTRUCTION = """
Reflect on following interaction.
Use the provided tools to retain any necessary memories about the user.
Use parallel tool calling to handle updates and insertions simultaneously:
"""
# -----------------------------------------------
def call_model(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Load memories from the store and use them to personalize the chatbot's response."""
    # Get the user ID from the config
    user_id = config["configurable"]["user_id"]
    # Retrieve memories from the store
    namespace = (user_id, "memories")
    memories = store.search(namespace)
    # Format the memories for the system prompt
    info = "\n".join(f"- {mem.value['content']}" for mem in memories)
    system_msg = MODEL_SYSTEM_MESSAGE.format(memory=info)
    # Respond using memory as well as the chat history
    response = model.invoke([SystemMessage(content=system_msg)] + state["messages"])
    return {"messages": response}
# -----------------------------------------------
def write_memory(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Reflect on the chat history and update the memory collection."""
    # Get the user ID from the config
    user_id = config["configurable"]["user_id"]
    # Define the namespace for the memories
    namespace = (user_id, "memories")
    # Retrieve the most recent memories for context
    existing_items = store.search(namespace)
    # Format the existing memories for the Trustcall extractor
    tool_name = "Memory"
    existing_memories = ([(existing_item.key, tool_name, existing_item.value)
                          for existing_item in existing_items]
                         if existing_items
                         else None)
    # Merge the chat history and the instruction
    updated_messages = list(merge_message_runs(messages=[SystemMessage(content=TRUSTCALL_INSTRUCTION)] + state["messages"]))
    # Invoke the extractor
    result = trustcall_extractor.invoke({"messages": updated_messages,
                                         "existing": existing_memories})
    # Save the memories from Trustcall to the store
    for r, rmeta in zip(result["responses"], result["response_metadata"]):
        store.put(namespace,
                  rmeta.get("json_doc_id", str(uuid.uuid4())),
                  r.model_dump(mode="json"))
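    # Note: this node only writes to the store; it returns no state update,
    # so the graph state passes through unchanged.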
# -----------------------------------------------
# Define a graph
builder = StateGraph(MessagesState)
# Nodes
builder.add_node("call_model", call_model)
builder.add_node("write_memory", write_memory)
# Edges
builder.add_edge(START, "call_model")
builder.add_edge("call_model", "write_memory")
builder.add_edge("write_memory", END)
# Store for long-term (across-thread) memory
across_thread_memory = InMemoryStore()
# Checkpointer for short-term (within-thread) memory
within_thread_memory = MemorySaver()
# Compile the graph with the checkpointer and store
graph = builder.compile(checkpointer=within_thread_memory, store=across_thread_memory)
# Visualize
display(Image(graph.get_graph().draw_mermaid_png()))
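# Running as a plain script (no IPython)? Print the Mermaid source instead:
# print(graph.get_graph().draw_mermaid())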
# -----------------------------------------------
# Run the graph
# -----------------------------------------------
# We supply a thread ID for short-term (within-thread) memory
# We supply a user ID for long-term (across-thread) memory
config = {"configurable": {"thread_id": "1", "user_id": "1"}}
# User input
input_messages = [HumanMessage(content="Hi, my name is Sushant.")]
# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()
# -------------------------------------
# User input
input_messages = [HumanMessage(content="I like to drive around Pune")]
# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()
# -------------------------------------
# Namespace where the memories were saved (must match write_memory's ordering)
user_id = "1"
namespace = (user_id, "memories")
memories = across_thread_memory.search(namespace)
for m in memories:
    print(m.dict())
# -------------------------------------
# Continue the conversation in a new thread.
# We supply a thread ID for short-term (within-thread) memory
# We supply a user ID for long-term (across-thread) memory
config = {"configurable": {"thread_id": "2", "user_id": "1"}}
# User input
input_messages = [HumanMessage(content="What bakeries do you recommend for me?")]
# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()
# -----------------------------------------------