diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py
index 4a516557a..776c81e30 100644
--- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py
+++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py
@@ -19,7 +19,9 @@
     Result,
     SearchOp,
 )
-from langgraph.store.base.batch import AsyncBatchedBaseStore
+from langgraph.store.base.batch import (
+    BatchedBaseStore,
+)
 from langgraph.store.postgres.base import (
     _PLACEHOLDER,
     BasePostgresStore,
@@ -36,7 +38,7 @@
 logger = logging.getLogger(__name__)
 
 
-class AsyncPostgresStore(AsyncBatchedBaseStore, BasePostgresStore[_ainternal.Conn]):
+class AsyncPostgresStore(BatchedBaseStore, BasePostgresStore[_ainternal.Conn]):
     """Asynchronous Postgres-backed store with optional vector search using pgvector.
 
     !!! example "Examples"
diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/base.py b/libs/checkpoint-postgres/langgraph/store/postgres/base.py
index 84dd70364..d464353c4 100644
--- a/libs/checkpoint-postgres/langgraph/store/postgres/base.py
+++ b/libs/checkpoint-postgres/langgraph/store/postgres/base.py
@@ -30,7 +30,6 @@
 from langgraph.checkpoint.postgres import _ainternal as _ainternal
 from langgraph.checkpoint.postgres import _internal as _pg_internal
 from langgraph.store.base import (
-    BaseStore,
     GetOp,
     IndexConfig,
     Item,
@@ -44,6 +43,7 @@
     get_text_at_path,
     tokenize_path,
 )
+from langgraph.store.base.batch import SyncBatchedBaseStore
 
 if TYPE_CHECKING:
     from langchain_core.embeddings import Embeddings
@@ -533,7 +533,7 @@ def _get_filter_condition(self, key: str, op: str, value: Any) -> tuple[str, lis
         raise ValueError(f"Unsupported operator: {op}")
 
 
-class PostgresStore(BaseStore, BasePostgresStore[_pg_internal.Conn]):
+class PostgresStore(SyncBatchedBaseStore, BasePostgresStore[_pg_internal.Conn]):
     """Postgres-backed store with optional vector search using pgvector.
 
     !!! example "Examples"
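Review note (not part of the diff): with PostgresStore now deriving from SyncBatchedBaseStore and AsyncPostgresStore from BatchedBaseStore, synchronous calls issued from several threads are queued and flushed together as batch() round-trips instead of one connection round-trip each. A minimal sketch of the intended effect, assuming the existing from_conn_string() helper and a placeholder DB_URI:

    from concurrent.futures import ThreadPoolExecutor

    from langgraph.store.postgres import PostgresStore

    DB_URI = "postgres://user:pass@localhost:5432/db"  # placeholder
    with PostgresStore.from_conn_string(DB_URI) as store:
        store.setup()
        with ThreadPoolExecutor(max_workers=8) as pool:
            # These puts land in _sync_queue and are flushed in batches
            # by the background worker thread.
            futs = [
                pool.submit(store.put, ("docs",), f"key{i}", {"n": i})
                for i in range(100)
            ]
            for f in futs:
                f.result()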
example "Examples" diff --git a/libs/checkpoint-postgres/tests/test_async_store.py b/libs/checkpoint-postgres/tests/test_async_store.py index eda0e2820..8dd3c0508 100644 --- a/libs/checkpoint-postgres/tests/test_async_store.py +++ b/libs/checkpoint-postgres/tests/test_async_store.py @@ -1,8 +1,10 @@ # type: ignore +import asyncio import itertools import sys import uuid from collections.abc import AsyncIterator +from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import Any, Optional @@ -63,6 +65,96 @@ async def store(request) -> AsyncIterator[AsyncPostgresStore]: await conn.execute(f"DROP DATABASE {database}") +def test_large_batches(store: AsyncPostgresStore) -> None: + N = 1000 + M = 10 + + with ThreadPoolExecutor(max_workers=10) as executor: + for m in range(M): + for i in range(N): + _ = [ + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.get, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ), + executor.submit( + store.list_namespaces, + prefix=None, + max_depth=m + 1, + ), + executor.submit( + store.search, + ("test",), + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + None, + ), + ] + + +async def test_large_batches_async(store: AsyncPostgresStore) -> None: + N = 1000 + M = 10 + coros = [] + for m in range(M): + for i in range(N): + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.aget( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + coros.append( + store.alist_namespaces( + prefix=None, + max_depth=m + 1, + ) + ) + coros.append( + store.asearch( + ("test",), + ) + ) + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.adelete( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + + await asyncio.gather(*coros) + + async def test_abatch_order(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) diff --git a/libs/checkpoint-postgres/tests/test_store.py b/libs/checkpoint-postgres/tests/test_store.py index 35dfa2150..ffb1aa563 100644 --- a/libs/checkpoint-postgres/tests/test_store.py +++ b/libs/checkpoint-postgres/tests/test_store.py @@ -1,5 +1,7 @@ # type: ignore +import asyncio +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from typing import Any, Optional from uuid import uuid4 @@ -17,11 +19,7 @@ SearchOp, ) from langgraph.store.postgres import PostgresStore -from tests.conftest import ( - DEFAULT_URI, - VECTOR_TYPES, - CharacterEmbeddings, -) +from tests.conftest import DEFAULT_URI, VECTOR_TYPES, CharacterEmbeddings @pytest.fixture(scope="function", params=["default", "pipe", "pool"]) @@ -59,6 +57,96 @@ def store(request) -> PostgresStore: conn.execute(f"DROP DATABASE {database}") +def test_large_batches(store: PostgresStore) -> None: + N = 1000 + M = 10 + + with ThreadPoolExecutor(max_workers=10) as executor: + for m in range(M): + for i in range(N): + _ = [ + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + 
diff --git a/libs/checkpoint-postgres/tests/test_store.py b/libs/checkpoint-postgres/tests/test_store.py
index 35dfa2150..ffb1aa563 100644
--- a/libs/checkpoint-postgres/tests/test_store.py
+++ b/libs/checkpoint-postgres/tests/test_store.py
@@ -1,5 +1,7 @@
 # type: ignore
 
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Optional
 from uuid import uuid4
@@ -17,11 +19,7 @@
     SearchOp,
 )
 from langgraph.store.postgres import PostgresStore
-from tests.conftest import (
-    DEFAULT_URI,
-    VECTOR_TYPES,
-    CharacterEmbeddings,
-)
+from tests.conftest import DEFAULT_URI, VECTOR_TYPES, CharacterEmbeddings
 
 
 @pytest.fixture(scope="function", params=["default", "pipe", "pool"])
@@ -59,6 +57,96 @@ def store(request) -> PostgresStore:
         conn.execute(f"DROP DATABASE {database}")
 
 
+def test_large_batches(store: PostgresStore) -> None:
+    N = 1000
+    M = 10
+
+    with ThreadPoolExecutor(max_workers=10) as executor:
+        for m in range(M):
+            for i in range(N):
+                _ = [
+                    executor.submit(
+                        store.put,
+                        ("test", "foo", "bar", "baz", str(m % 2)),
+                        f"key{i}",
+                        value={"foo": "bar" + str(i)},
+                    ),
+                    executor.submit(
+                        store.get,
+                        ("test", "foo", "bar", "baz", str(m % 2)),
+                        f"key{i}",
+                    ),
+                    executor.submit(
+                        store.list_namespaces,
+                        prefix=None,
+                        max_depth=m + 1,
+                    ),
+                    executor.submit(
+                        store.search,
+                        ("test",),
+                    ),
+                    executor.submit(
+                        store.put,
+                        ("test", "foo", "bar", "baz", str(m % 2)),
+                        f"key{i}",
+                        value={"foo": "bar" + str(i)},
+                    ),
+                    executor.submit(
+                        store.put,
+                        ("test", "foo", "bar", "baz", str(m % 2)),
+                        f"key{i}",
+                        None,
+                    ),
+                ]
+
+
+async def test_large_batches_async(store: PostgresStore) -> None:
+    N = 1000
+    M = 10
+    coros = []
+    for m in range(M):
+        for i in range(N):
+            coros.append(
+                store.aput(
+                    ("test", "foo", "bar", "baz", str(m % 2)),
+                    f"key{i}",
+                    value={"foo": "bar" + str(i)},
+                )
+            )
+            coros.append(
+                store.aget(
+                    ("test", "foo", "bar", "baz", str(m % 2)),
+                    f"key{i}",
+                )
+            )
+            coros.append(
+                store.alist_namespaces(
+                    prefix=None,
+                    max_depth=m + 1,
+                )
+            )
+            coros.append(
+                store.asearch(
+                    ("test",),
+                )
+            )
+            coros.append(
+                store.aput(
+                    ("test", "foo", "bar", "baz", str(m % 2)),
+                    f"key{i}",
+                    value={"foo": "bar" + str(i)},
+                )
+            )
+            coros.append(
+                store.adelete(
+                    ("test", "foo", "bar", "baz", str(m % 2)),
+                    f"key{i}",
+                )
+            )
+
+    await asyncio.gather(*coros)
+
+
 def test_batch_order(store: PostgresStore) -> None:
     # Setup test data
     store.put(("test", "foo"), "key1", {"data": "value1"})
diff --git a/libs/checkpoint/langgraph/store/base/__init__.py b/libs/checkpoint/langgraph/store/base/__init__.py
index c635e43ed..06eeab9cc 100644
--- a/libs/checkpoint/langgraph/store/base/__init__.py
+++ b/libs/checkpoint/langgraph/store/base/__init__.py
@@ -808,6 +808,8 @@ def list_namespaces(
             # [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")]
             ```
         """
+        if max_depth is not None and max_depth <= 0:
+            raise ValueError("If provided, max_depth must be greater than 0")
         match_conditions = []
         if prefix:
             match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
@@ -1004,6 +1006,8 @@ async def alist_namespaces(
             # Returns: [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")]
             ```
         """
+        if max_depth is not None and max_depth <= 0:
+            raise ValueError("If provided, max_depth must be greater than 0")
         match_conditions = []
         if prefix:
             match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
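Review note: the same guard lands in both list_namespaces and alist_namespaces, so a non-positive depth now fails fast at the API boundary instead of reaching the backend. Illustrative only:

    store.list_namespaces(max_depth=0)
    # ValueError: If provided, max_depth must be greater than 0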
diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py
index 33c502574..7c8371c1e 100644
--- a/libs/checkpoint/langgraph/store/base/batch.py
+++ b/libs/checkpoint/langgraph/store/base/batch.py
@@ -1,6 +1,9 @@
 import asyncio
+import threading
+import time
 import weakref
-from typing import Any, Literal, Optional, Union
+from concurrent.futures import Future
+from typing import Any, Iterable, Literal, Optional, Union
 
 from langgraph.store.base import (
     BaseStore,
@@ -11,24 +14,19 @@
     NamespacePath,
     Op,
     PutOp,
+    Result,
     SearchItem,
     SearchOp,
     _validate_namespace,
 )
 
 
-class AsyncBatchedBaseStore(BaseStore):
+class AsyncBatchedBaseStoreMixin:
     """Efficiently batch operations in a background task."""
 
-    __slots__ = ("_loop", "_aqueue", "_task")
-
-    def __init__(self) -> None:
-        self._loop = asyncio.get_running_loop()
-        self._aqueue: dict[asyncio.Future, Op] = {}
-        self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
-
-    def __del__(self) -> None:
-        self._task.cancel()
+    _loop: asyncio.AbstractEventLoop
+    _aqueue: dict[asyncio.Future, Op]
+    _task: asyncio.Task
 
     async def aget(
         self,
@@ -100,6 +98,29 @@
         return await fut
 
 
+class AsyncBatchedBaseStore(AsyncBatchedBaseStoreMixin, BaseStore):
+    """Efficiently batch operations in a background task."""
+
+    __slots__ = ("_loop", "_aqueue", "_task")
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._loop = asyncio.get_running_loop()
+        self._aqueue: dict[asyncio.Future, Op] = {}
+        self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
+
+    def __del__(self) -> None:
+        self._task.cancel()
+
+    def batch(self, ops: Iterable[Op]) -> list[Result]:
+        # Callers are expected to be off the event loop thread: enqueue the
+        # ops on the store's loop and block until the background task has
+        # resolved them, preserving input order.
+        async def _enqueue() -> list[Result]:
+            futures = []
+            for op in ops:
+                fut = self._loop.create_future()
+                self._aqueue[fut] = op
+                futures.append(fut)
+            return await asyncio.gather(*futures)
+
+        return asyncio.run_coroutine_threadsafe(_enqueue(), self._loop).result()
+
+
 def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]:
     """Dedupe operations while preserving order for results.
 
@@ -174,3 +195,181 @@ async def _run(
                 break
         # remove strong ref to store
         del s
+
+
+class SyncBatchedBaseStoreMixin(BaseStore):
+    """Efficiently batch operations in a background thread."""
+
+    _sync_queue: dict[Future, Op]
+    _sync_thread: threading.Thread
+
+    def get(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+    ) -> Optional[Item]:
+        fut: Future[Optional[Item]] = Future()
+        self._sync_queue[fut] = GetOp(namespace, key)
+        return fut.result()
+
+    def search(
+        self,
+        namespace_prefix: tuple[str, ...],
+        /,
+        *,
+        query: Optional[str] = None,
+        filter: Optional[dict[str, Any]] = None,
+        limit: int = 10,
+        offset: int = 0,
+    ) -> list[SearchItem]:
+        fut: Future[list[SearchItem]] = Future()
+        self._sync_queue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query)
+        return fut.result()
+
+    def put(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+        value: dict[str, Any],
+        index: Optional[Union[Literal[False], list[str]]] = None,
+    ) -> None:
+        _validate_namespace(namespace)
+        fut: Future[None] = Future()
+        self._sync_queue[fut] = PutOp(namespace, key, value, index)
+        return fut.result()
+
+    def delete(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+    ) -> None:
+        # A delete is a put of None, matching the dedupe semantics in
+        # _dedupe_ops.
+        fut: Future[None] = Future()
+        self._sync_queue[fut] = PutOp(namespace, key, None)
+        return fut.result()
+
+    def list_namespaces(
+        self,
+        *,
+        prefix: Optional[NamespacePath] = None,
+        suffix: Optional[NamespacePath] = None,
+        max_depth: Optional[int] = None,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> list[tuple[str, ...]]:
+        fut: Future[list[tuple[str, ...]]] = Future()
+        match_conditions = []
+        if prefix:
+            match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
+        if suffix:
+            match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
+
+        op = ListNamespacesOp(
+            match_conditions=tuple(match_conditions),
+            max_depth=max_depth,
+            limit=limit,
+            offset=offset,
+        )
+        self._sync_queue[fut] = op
+        return fut.result()
+
+
+class SyncBatchedBaseStore(SyncBatchedBaseStoreMixin, BaseStore):
+    """Efficiently batch operations in a background thread."""
+
+    __slots__ = ("_sync_queue", "_sync_thread")
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._sync_queue: dict[Future, Op] = {}
+        self._sync_thread = threading.Thread(
+            target=_sync_run,
+            args=(self._sync_queue, weakref.ref(self)),
+            daemon=True,
+        )
+        self._sync_thread.start()
+
+    def __del__(self) -> None:
+        # Signal the thread to stop
+        if self._sync_thread.is_alive():
+            empty_future: Future = Future()
+            self._sync_queue[empty_future] = None  # type: ignore
+            self._sync_thread.join()
+
+    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
+        futures = []
+        for op in ops:
+            fut: Future[Result] = Future()
+            self._sync_queue[fut] = op
+            futures.append(fut)
+        # Await the thread-backed futures without blocking the event loop.
+        return await asyncio.gather(*(asyncio.wrap_future(f) for f in futures))
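Review note: BatchedBaseStore below composes the two mixins so a single instance serves both call styles: sync methods funnel through the daemon thread and _sync_run, async methods through the event-loop task and _run. A hedged sketch, assuming construction happens while a loop is running (get_running_loop() in __init__ raises RuntimeError otherwise) and reusing the placeholder DB_URI:

    import asyncio

    async def main() -> None:
        async with AsyncPostgresStore.from_conn_string(DB_URI) as store:
            await store.setup()
            await store.aput(("docs",), "k", {"v": 1})  # loop-batched
            # The sync API is also usable; call it off the loop thread so the
            # blocking Future.result() cannot stall the event loop.
            item = await asyncio.to_thread(store.get, ("docs",), "k")

    asyncio.run(main())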
+
+
+class BatchedBaseStore(
+    AsyncBatchedBaseStoreMixin, SyncBatchedBaseStoreMixin, BaseStore
+):
+    """Batch operations in both a background task and a background thread."""
+
+    __slots__ = (
+        "_sync_queue",
+        "_sync_thread",
+        "_task",
+        "_loop",
+        "_aqueue",
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Setup async processing
+        self._loop = asyncio.get_running_loop()
+        self._aqueue: dict[asyncio.Future, Op] = {}
+        self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
+
+        # Setup sync processing
+        self._sync_queue: dict[Future, Op] = {}
+        self._sync_thread = threading.Thread(
+            target=_sync_run,
+            args=(self._sync_queue, weakref.ref(self)),
+            daemon=True,
+        )
+        self._sync_thread.start()
+
+    def __del__(self) -> None:
+        # Signal the thread to stop
+        if self._sync_thread.is_alive():
+            empty_future: Future[None] = Future()
+            self._sync_queue[empty_future] = None  # type: ignore
+            self._sync_thread.join()
+
+        # Cancel the async batching task
+        if self._task is not None:
+            self._task.cancel()
+
+
+def _sync_run(queue: dict[Future, Op], store: weakref.ReferenceType[BaseStore]) -> None:
+    while True:
+        time.sleep(0.001)  # Yield to other threads
+        if not queue:
+            continue
+        if s := store():
+            # get the operations to run
+            taken = queue.copy()
+            # action each operation
+            try:
+                values = list(taken.values())
+                if None in values:  # Exit signal
+                    break
+                listen, dedupped = _dedupe_ops(values)
+                results = s.batch(dedupped)  # Note: Using sync batch here
+                if listen is not None:
+                    results = [results[ix] for ix in listen]
+
+                # set the results of each operation
+                for fut, result in zip(taken, results):
+                    fut.set_result(result)
+            except Exception as e:
+                for fut in taken:
+                    fut.set_exception(e)
+            # remove the operations from the queue
+            for fut in taken:
+                del queue[fut]
+        else:
+            break
+        # remove strong ref to store
+        del s
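Review note: shutdown relies on two signals. __del__ pushes a None sentinel into _sync_queue, which makes _sync_run break out of its polling loop so join() can return; independently, the worker exits on its own when the weakref to the store goes dead. The 1 ms sleep is a deliberate trade-off: a simple polling worker instead of a condition variable. Illustrative only (reaches into private attributes, mirroring the handshake __del__ performs):

    from concurrent.futures import Future

    stop: Future = Future()
    store._sync_queue[stop] = None  # sentinel: tells _sync_run to exit
    store._sync_thread.join()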