Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph: update get_state to handle nested subgraph state #1108

Merged
merged 56 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
261cdf8
langgraph: update get_state to handle nested subgraph state
vbarda Jul 23, 2024
2268802
Merge branch 'main' into vb/update-get-state
vbarda Jul 23, 2024
92ae8f4
async methods
vbarda Jul 23, 2024
e615aab
cleanup names + make subgraph state optional
vbarda Jul 24, 2024
b43ef64
tests
vbarda Jul 24, 2024
ae696d4
add sync history
vbarda Jul 24, 2024
6ae2c6c
Merge branch 'main' into vb/update-get-state
vbarda Jul 24, 2024
a7d4846
use .list for looking up prefix-matched checkpoints
vbarda Jul 24, 2024
56bf9c9
Merge branch 'main' into vb/update-get-state
vbarda Jul 24, 2024
c9d6a41
Merge branch 'main' into vb/update-get-state
vbarda Aug 12, 2024
322cfc4
cleanup
vbarda Aug 12, 2024
9948125
add checkpointer=INHERIT_CHECKPOINTER
vbarda Aug 12, 2024
fb05bdc
return all checkpoints from .list
vbarda Aug 12, 2024
246dd0f
Merge branch 'vb/update-get-state' of github.com:langchain-ai/langgra…
vbarda Aug 12, 2024
5033044
update checkpointer tests
vbarda Aug 12, 2024
6d4cdc9
opt-in
vbarda Aug 12, 2024
abe9b7c
lint
vbarda Aug 12, 2024
a03886b
Merge branch 'main' into vb/update-get-state
vbarda Aug 12, 2024
f65d9b2
pass subgraph nodes/channels
vbarda Aug 12, 2024
b5caf1a
Merge branch 'vb/update-get-state' of github.com:langchain-ai/langgra…
vbarda Aug 13, 2024
0456f52
Merge branch 'main' into vb/update-get-state
vbarda Aug 13, 2024
0135c6f
correct check for using parent checkpointer
vbarda Aug 13, 2024
3295274
fix empty snapshot
vbarda Aug 13, 2024
58887a5
checkpoints/interrupts for subgraphs triggered by sends
vbarda Aug 13, 2024
392891f
Merge branch 'main' into vb/update-get-state
vbarda Aug 13, 2024
d961888
update logic for latest snapshot's subgraph snapshots
vbarda Aug 13, 2024
409b915
code review
vbarda Aug 14, 2024
6531ec7
remove inherit checkpointer
vbarda Aug 14, 2024
7fa9789
correctly propagate all subgraph attributes
vbarda Aug 14, 2024
45054df
remove include_subgraph_state kwarg
vbarda Aug 15, 2024
a94168a
Merge branch 'main' into vb/update-get-state
vbarda Aug 21, 2024
f51e7ea
pass pending writes in checkpointers
vbarda Aug 21, 2024
0a87b9f
update more tests
vbarda Aug 21, 2024
e7bc74e
remove refactors
vbarda Aug 21, 2024
578ec48
Merge branch 'main' into vb/update-get-state
vbarda Aug 21, 2024
0b6088f
remove futures.clear
vbarda Aug 21, 2024
426125c
Merge branch 'main' into vb/update-get-state
vbarda Aug 21, 2024
5654d8f
Merge branch 'main' into vb/update-get-state
vbarda Aug 22, 2024
6c7d9c3
remove interrupts
vbarda Aug 22, 2024
acd8acf
lint
vbarda Aug 22, 2024
4935cf5
lint
vbarda Aug 22, 2024
72893d9
code review
vbarda Aug 22, 2024
9f6e57d
more code review
vbarda Aug 22, 2024
4162be8
optimize subgraph state lookups
vbarda Aug 23, 2024
065055e
small change
vbarda Aug 23, 2024
2bac0d0
re-trigger CI
vbarda Aug 23, 2024
1333d8b
cleanup
vbarda Aug 23, 2024
4c4d705
Merge branch 'main' into vb/update-get-state
vbarda Aug 23, 2024
1f29925
add max recursion depth
vbarda Aug 23, 2024
7144291
lint
vbarda Aug 23, 2024
15692ac
refactor to remove nested DB calls
vbarda Aug 23, 2024
cb30f68
remove more reused code
vbarda Aug 23, 2024
904a1a3
extra paranoia
vbarda Aug 23, 2024
bf4dc5d
Merge branch 'main' into vb/update-get-state
vbarda Aug 26, 2024
85e698e
filter on checkpoint NS
vbarda Aug 26, 2024
507930e
lint
vbarda Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,6 @@ def _search_where(
if config:
wheres.append("thread_id = %s ")
param_values.append(config["configurable"]["thread_id"])
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
wheres.append("checkpoint_ns = %s")
param_values.append(checkpoint_ns)

# construct predicate for metadata filter
if filter:
Expand Down
26 changes: 6 additions & 20 deletions libs/checkpoint-postgres/tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,15 @@ async def test_asearch(self):
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
assert len(search_results_4) == 0

# search by config (defaults to root graph checkpoints)
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = [
c
async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
]
assert len(search_results_5) == 1
assert search_results_5[0].config["configurable"]["checkpoint_ns"] == ""

# search by config and checkpoint_ns
search_results_6 = [
c
async for c in saver.alist(
{
"configurable": {
"thread_id": "thread-2",
"checkpoint_ns": "inner",
}
}
)
]
assert len(search_results_6) == 1
assert (
search_results_6[0].config["configurable"]["checkpoint_ns"] == "inner"
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}

# TODO: test before and limit params
25 changes: 6 additions & 19 deletions libs/checkpoint-postgres/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,27 +88,14 @@ def test_search(self):
search_results_4 = list(saver.list(None, filter=query_4))
assert len(search_results_4) == 0

# search by config (defaults to root graph checkpoints)
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 1
assert search_results_5[0].config["configurable"]["checkpoint_ns"] == ""

# search by config and checkpoint_ns
search_results_6 = list(
saver.list(
{
"configurable": {
"thread_id": "thread-2",
"checkpoint_ns": "inner",
}
}
)
)
assert len(search_results_6) == 1
assert (
search_results_6[0].config["configurable"]["checkpoint_ns"] == "inner"
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}

# TODO: test before and limit params
3 changes: 0 additions & 3 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def search_where(
if config is not None:
wheres.append("thread_id = ?")
param_values.append(config["configurable"]["thread_id"])
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
wheres.append("checkpoint_ns = ?")
param_values.append(checkpoint_ns)

# construct predicate for metadata filter
if filter:
Expand Down
26 changes: 6 additions & 20 deletions libs/checkpoint-sqlite/tests/test_aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,15 @@ async def test_asearch(self):
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
assert len(search_results_4) == 0

# search by config (defaults to root graph checkpoints)
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = [
c
async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
]
assert len(search_results_5) == 1
assert search_results_5[0].config["configurable"]["checkpoint_ns"] == ""

# search by config and checkpoint_ns
search_results_6 = [
c
async for c in saver.alist(
{
"configurable": {
"thread_id": "thread-2",
"checkpoint_ns": "inner",
}
}
)
]
assert len(search_results_6) == 1
assert (
search_results_6[0].config["configurable"]["checkpoint_ns"] == "inner"
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}

# TODO: test before and limit params
25 changes: 6 additions & 19 deletions libs/checkpoint-sqlite/tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,15 @@ def test_search(self):
search_results_4 = list(saver.list(None, filter=query_4))
assert len(search_results_4) == 0

# search by config (defaults to root graph checkpoints)
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 1
assert search_results_5[0].config["configurable"]["checkpoint_ns"] == ""

# search by config and checkpoint_ns
search_results_6 = list(
saver.list(
{
"configurable": {
"thread_id": "thread-2",
"checkpoint_ns": "inner",
}
}
)
)
assert len(search_results_6) == 1
assert (
search_results_6[0].config["configurable"]["checkpoint_ns"] == "inner"
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}

# TODO: test before and limit params

Expand Down
119 changes: 65 additions & 54 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,62 +177,66 @@ def list(
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
checkpoint_ns = (
config["configurable"].get("checkpoint_ns", "") if config else ""
)
for thread_id in thread_ids:
for checkpoint_id, (checkpoint, metadata_b, parent_checkpoint_id) in sorted(
self.storage[thread_id][checkpoint_ns].items(),
key=lambda x: x[0],
reverse=True,
):
# filter by checkpoint ID
if (
before
and (before_checkpoint_id := get_checkpoint_id(before))
and checkpoint_id >= before_checkpoint_id
):
continue

# filter by metadata
metadata = self.serde.loads_typed(metadata_b)
if filter and not all(
query_value == metadata[query_key]
for query_key, query_value in filter.items()
for checkpoint_ns in self.storage[thread_id].keys():
for checkpoint_id, (
checkpoint,
metadata_b,
parent_checkpoint_id,
) in sorted(
self.storage[thread_id][checkpoint_ns].items(),
key=lambda x: x[0],
reverse=True,
):
continue

# limit search results
if limit is not None and limit <= 0:
break
elif limit is not None:
limit -= 1

writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()

yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
# filter by checkpoint ID
if (
before
and (before_checkpoint_id := get_checkpoint_id(before))
and checkpoint_id >= before_checkpoint_id
):
continue

# filter by metadata
metadata = self.serde.loads_typed(metadata_b)
if filter and not all(
query_value == metadata.get(query_key)
for query_key, query_value in filter.items()
):
continue

# limit search results
if limit is not None and limit <= 0:
break
elif limit is not None:
limit -= 1

writes = self.writes[
(thread_id, checkpoint_ns, checkpoint_id)
].values()

yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
checkpoint=self.serde.loads_typed(checkpoint),
metadata=metadata,
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
},
checkpoint=self.serde.loads_typed(checkpoint),
metadata=metadata,
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None,
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
)
if parent_checkpoint_id
else None,
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
)

def put(
self,
Expand Down Expand Up @@ -338,7 +342,14 @@ async def alist(
"""
loop = asyncio.get_running_loop()
iter = await loop.run_in_executor(
None, partial(self.list, before=before, limit=limit, filter=filter), config
None,
partial(
self.list,
before=before,
limit=limit,
filter=filter,
),
config,
)
while True:
# handling StopIteration exception inside coroutine won't work
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _default(self, obj):
return self._encode_constructor_args(obj.__class__, args=[obj.value])
elif isinstance(obj, SendProtocol):
return self._encode_constructor_args(
obj.__class__, kwargs={"node": obj.node, "arg": obj.arg}
obj.__class__, kwargs={"node": obj.node, "arg": obj.arg, "id": obj.id}
)
elif isinstance(obj, (bytes, bytearray)):
return self._encode_constructor_args(
Expand Down
1 change: 1 addition & 0 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SendProtocol(Protocol):
# Mirrors langgraph.constants.Send
node: str
arg: Any
id: str

def __hash__(self) -> int:
...
Expand Down
23 changes: 9 additions & 14 deletions libs/checkpoint/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,20 @@ async def test_search(self):
assert search_results_2[0].metadata == self.metadata_2

search_results_3 = list(self.memory_saver.list(None, filter=query_3))
assert len(search_results_3) == 2
assert len(search_results_3) == 3

search_results_4 = list(self.memory_saver.list(None, filter=query_4))
assert len(search_results_4) == 0

# search by config (defaults to root graph checkpoints)
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
self.memory_saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 1
assert search_results_5[0].config["configurable"]["checkpoint_ns"] == ""

# search by config and checkpoint_ns
search_results_6 = list(
self.memory_saver.list(
{"configurable": {"thread_id": "thread-2", "checkpoint_ns": "inner"}}
)
)
assert len(search_results_6) == 1
assert search_results_6[0].config["configurable"]["checkpoint_ns"] == "inner"
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}

# TODO: test before and limit params

Expand All @@ -110,6 +104,7 @@ async def test_asearch(self):
# save checkpoints
self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})

# call method / assertions
query_1: CheckpointMetadata = {"source": "input"} # search by 1 key
Expand All @@ -135,7 +130,7 @@ async def test_asearch(self):
search_results_3 = [
c async for c in self.memory_saver.alist(None, filter=query_3)
]
assert len(search_results_3) == 2
assert len(search_results_3) == 3

search_results_4 = [
c async for c in self.memory_saver.alist(None, filter=query_4)
Expand Down
Loading
Loading