Skip to content

Commit

Permalink
Use batched async kv inside loop
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Aug 20, 2024
1 parent 8b58b5e commit 0ba3e4d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 7 additions & 3 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
INTERRUPT,
)
from langgraph.errors import EmptyInputError, GraphInterrupt
from langgraph.kv.base import BaseKV
from langgraph.kv.batch import AsyncBatchedKV
from langgraph.managed.base import (
AsyncManagedValuesManager,
ManagedValueMapping,
Expand Down Expand Up @@ -102,7 +104,7 @@ class PregelLoop:
]
]
graph: "Pregel"

kv: Optional[BaseKV]
submit: Submit
channels: Mapping[str, BaseChannel]
managed: ManagedValueMapping
Expand Down Expand Up @@ -418,6 +420,7 @@ def __init__(
graph: "Pregel",
) -> None:
super().__init__(input, config=config, checkpointer=checkpointer, graph=graph)
self.kv = graph.kv
self.stack = ExitStack()
self.stack.push(self._suppress_interrupt)
if checkpointer:
Expand Down Expand Up @@ -470,7 +473,7 @@ def __enter__(self) -> Self:
self.managed = self.stack.enter_context(
ManagedValuesManager(
self.graph.managed_values_dict,
patch_config(self.config, configurable={CONFIG_KEY_KV: self.graph.kv}),
patch_config(self.config, configurable={CONFIG_KEY_KV: self.kv}),
)
)
self.status = "pending"
Expand Down Expand Up @@ -501,6 +504,7 @@ def __init__(
graph: "Pregel",
) -> None:
super().__init__(input, config=config, checkpointer=checkpointer, graph=graph)
self.kv = AsyncBatchedKV(graph.kv) if graph.kv else None
self.stack = AsyncExitStack()
self.stack.push(self._suppress_interrupt)
if checkpointer:
Expand Down Expand Up @@ -557,7 +561,7 @@ async def __aenter__(self) -> Self:
self.managed = await self.stack.enter_async_context(
AsyncManagedValuesManager(
self.graph.managed_values_dict,
patch_config(self.config, configurable={CONFIG_KEY_KV: self.graph.kv}),
patch_config(self.config, configurable={CONFIG_KEY_KV: self.kv}),
)
)
self.status = "pending"
Expand Down
3 changes: 1 addition & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.graph import START
from langgraph.graph.message import MessageGraph, add_messages
from langgraph.kv.batch import AsyncBatchedKV
from langgraph.kv.memory import MemoryKV
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import (
Expand Down Expand Up @@ -4745,7 +4744,7 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State:

async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
tool_two = tool_two_graph.compile(
kv=AsyncBatchedKV(MemoryKV()),
kv=MemoryKV(),
checkpointer=saver,
interrupt_before=["tool_two_fast", "tool_two_slow"],
)
Expand Down

0 comments on commit 0ba3e4d

Please sign in to comment.