From 0ba3e4d08522e3a5229fbfa4a1f00662e7a27a95 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 20 Aug 2024 15:03:09 -0700 Subject: [PATCH] Use batched async kv inside loop --- libs/langgraph/langgraph/pregel/loop.py | 10 +++++++--- libs/langgraph/tests/test_pregel_async.py | 3 +-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index a684185061..5d175597ac 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -49,6 +49,8 @@ INTERRUPT, ) from langgraph.errors import EmptyInputError, GraphInterrupt +from langgraph.kv.base import BaseKV +from langgraph.kv.batch import AsyncBatchedKV from langgraph.managed.base import ( AsyncManagedValuesManager, ManagedValueMapping, @@ -102,7 +104,7 @@ class PregelLoop: ] ] graph: "Pregel" - + kv: Optional[BaseKV] submit: Submit channels: Mapping[str, BaseChannel] managed: ManagedValueMapping @@ -418,6 +420,7 @@ def __init__( graph: "Pregel", ) -> None: super().__init__(input, config=config, checkpointer=checkpointer, graph=graph) + self.kv = graph.kv self.stack = ExitStack() self.stack.push(self._suppress_interrupt) if checkpointer: @@ -470,7 +473,7 @@ def __enter__(self) -> Self: self.managed = self.stack.enter_context( ManagedValuesManager( self.graph.managed_values_dict, - patch_config(self.config, configurable={CONFIG_KEY_KV: self.graph.kv}), + patch_config(self.config, configurable={CONFIG_KEY_KV: self.kv}), ) ) self.status = "pending" @@ -501,6 +504,7 @@ def __init__( graph: "Pregel", ) -> None: super().__init__(input, config=config, checkpointer=checkpointer, graph=graph) + self.kv = AsyncBatchedKV(graph.kv) if graph.kv else None self.stack = AsyncExitStack() self.stack.push(self._suppress_interrupt) if checkpointer: @@ -557,7 +561,7 @@ async def __aenter__(self) -> Self: self.managed = await self.stack.enter_async_context( AsyncManagedValuesManager( self.graph.managed_values_dict, - patch_config(self.config, configurable={CONFIG_KEY_KV: self.graph.kv}), + patch_config(self.config, configurable={CONFIG_KEY_KV: self.kv}), ) ) self.status = "pending" diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index c01da5bc8b..ae4df21fa7 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -52,7 +52,6 @@ from langgraph.graph import END, Graph, StateGraph from langgraph.graph.graph import START from langgraph.graph.message import MessageGraph, add_messages -from langgraph.kv.batch import AsyncBatchedKV from langgraph.kv.memory import MemoryKV from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import ( @@ -4745,7 +4744,7 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: async with AsyncSqliteSaver.from_conn_string(":memory:") as saver: tool_two = tool_two_graph.compile( - kv=AsyncBatchedKV(MemoryKV()), + kv=MemoryKV(), checkpointer=saver, interrupt_before=["tool_two_fast", "tool_two_slow"], )