From 4406828c39b5b16f4f53217aeca91c2288429735 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 5 Dec 2024 18:27:47 -0800 Subject: [PATCH 1/2] test --- .../langgraph/store/postgres/aio.py | 15 +- .../tests/test_async_store.py | 101 +++++++++ .../langgraph/store/base/__init__.py | 4 + libs/checkpoint/langgraph/store/base/batch.py | 207 +++++++++++++++++- 4 files changed, 313 insertions(+), 14 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py index 4a516557a..2b4309d5e 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py @@ -19,7 +19,11 @@ Result, SearchOp, ) -from langgraph.store.base.batch import AsyncBatchedBaseStore +from langgraph.store.base.batch import ( + BatchedBaseStore, + AsyncBatchedBaseStore, + SyncBatchedBaseStore, +) from langgraph.store.postgres.base import ( _PLACEHOLDER, BasePostgresStore, @@ -156,8 +160,15 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: return results + # def batch(self, ops: Iterable[Op]) -> list[Result]: + # return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() def batch(self, ops: Iterable[Op]) -> list[Result]: - return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() + futures = [] + for op in ops: + fut = self._loop.create_future() + self._aqueue[fut] = op + futures.append(fut) + return [fut.result() for fut in asyncio.as_completed(futures)] @classmethod @asynccontextmanager diff --git a/libs/checkpoint-postgres/tests/test_async_store.py b/libs/checkpoint-postgres/tests/test_async_store.py index eda0e2820..7c9cd38ad 100644 --- a/libs/checkpoint-postgres/tests/test_async_store.py +++ b/libs/checkpoint-postgres/tests/test_async_store.py @@ -1,4 +1,5 @@ # type: ignore +import asyncio import itertools import sys import uuid @@ -63,6 +64,106 @@ async def store(request) -> AsyncIterator[AsyncPostgresStore]: await conn.execute(f"DROP DATABASE {database}") +async def test_large_batches(store: AsyncPostgresStore) -> None: + N = 100 + M = 10 + coros = [] + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=10) as executor: + for m in range(M): + ops = [] + for i in range(N): + for i in range(N): + coros.append( + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + executor.submit( + store.get, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + coros.append( + executor.submit( + store.list_namespaces, + prefix=None, + max_depth=m + 1, + ) + ) + coros.append( + executor.submit( + store.search, + ("test",), + ) + ) + coros.append( + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + None, + ) + ) + # ops.extend( + # [ + # PutOp( + # namespace=("test", "foo", "bar", "baz", str(m % 2)), + # key=f"key{i}", # {m}", + # value=None, + # ), + # GetOp(namespace=("test",), key=f"key{i}{m}"), + # ListNamespacesOp( + # match_conditions=None, max_depth=i + 1, limit=m, offset=0 + # ), + # SearchOp( + # namespace_prefix=("test",), + # filter=None, + # limit=10, + # offset=0, + # ), + # ] + # + # ops.extend( + # [ + # # PutOp( + # # namespace=("test", "foo", "bar", "baz", str(m % 2)), + # # key=f"key{i}", # {m}", + # # value={"data": f"value{i}{m}"}, + # # ), + # # GetOp(namespace=("test",), key=f"key{i}{m}"), + # # ListNamespacesOp( + # # match_conditions=None, max_depth=i + 1, limit=m, offset=0 + # # ), + # # SearchOp( + # # namespace_prefix=("test",), + # # filter={"data": f"value{i}{m}"}, + # # limit=10, + # # offset=0, + # # ), + # ] + # ) + + # coros.extend(ops) + # executor.map(store.batch, [[op] for op in coros]) + + # await asyncio.gather(*coros) # *[store.abatch(ops) for ops in coros]) + + async def test_abatch_order(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) diff --git a/libs/checkpoint/langgraph/store/base/__init__.py b/libs/checkpoint/langgraph/store/base/__init__.py index c635e43ed..06eeab9cc 100644 --- a/libs/checkpoint/langgraph/store/base/__init__.py +++ b/libs/checkpoint/langgraph/store/base/__init__.py @@ -808,6 +808,8 @@ def list_namespaces( # [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")] ``` """ + if max_depth is not None and max_depth <= 0: + raise ValueError("If provided, max_depth must be greater than 0") match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) @@ -1004,6 +1006,8 @@ async def alist_namespaces( # Returns: [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")] ``` """ + if max_depth is not None and max_depth <= 0: + raise ValueError("If provided, max_depth must be greater than 0") match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py index 33c502574..a3e2d343b 100644 --- a/libs/checkpoint/langgraph/store/base/batch.py +++ b/libs/checkpoint/langgraph/store/base/batch.py @@ -1,6 +1,9 @@ import asyncio +import threading +import time import weakref -from typing import Any, Literal, Optional, Union +from concurrent.futures import Future +from typing import Any, Iterable, Literal, Optional, Union from langgraph.store.base import ( BaseStore, @@ -17,19 +20,9 @@ ) -class AsyncBatchedBaseStore(BaseStore): +class AsyncBatchedBaseStoreMixin: """Efficiently batch operations in a background task.""" - __slots__ = ("_loop", "_aqueue", "_task") - - def __init__(self) -> None: - self._loop = asyncio.get_running_loop() - self._aqueue: dict[asyncio.Future, Op] = {} - self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) - - def __del__(self) -> None: - self._task.cancel() - async def aget( self, namespace: tuple[str, ...], @@ -99,6 +92,29 @@ async def alist_namespaces( self._aqueue[fut] = op return await fut + # def batch(self, ops: Iterable[Op]) -> list[Result]: + # futures = [] + # for op in ops: + # fut = self._loop.create_future() + # self._aqueue[fut] = op + # futures.append(fut) + # return [fut.result() for fut in asyncio.as_completed(futures)] + + +class AsyncBatchedBaseStore(AsyncBatchedBaseStoreMixin, BaseStore): + """Efficiently batch operations in a background task.""" + + __slots__ = ("_loop", "_aqueue", "_task") + + def __init__(self) -> None: + super().__init__() + self._loop = asyncio.get_running_loop() + self._aqueue: dict[asyncio.Future, Op] = {} + self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) + + def __del__(self) -> None: + self._task.cancel() + def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: """Dedupe operations while preserving order for results. @@ -174,3 +190,170 @@ async def _run( break # remove strong ref to store del s + + +class SyncBatchedBaseStoreMixin(BaseStore): + """Efficiently batch operations in a background thread.""" + + def get( + self, + namespace: tuple[str, ...], + key: str, + ) -> Optional[Item]: + fut = Future() + self._queue[fut] = GetOp(namespace, key) + return fut.result() + + def search( + self, + namespace_prefix: tuple[str, ...], + /, + *, + query: Optional[str] = None, + filter: Optional[dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchItem]: + fut = Future() + self._queue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query) + return fut.result() + + def put( + self, + namespace: tuple[str, ...], + key: str, + value: dict[str, Any], + index: Optional[Union[Literal[False], list[str]]] = None, + ) -> None: + _validate_namespace(namespace) + fut = Future() + self._queue[fut] = PutOp(namespace, key, value, index) + return fut.result() + + def delete( + self, + namespace: tuple[str, ...], + key: str, + ) -> None: + fut = Future() + self._queue[fut] = PutOp(namespace, key, None) + return fut.result() + + def list_namespaces( + self, + *, + prefix: Optional[NamespacePath] = None, + suffix: Optional[NamespacePath] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> list[tuple[str, ...]]: + fut = Future() + match_conditions = [] + if prefix: + match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) + if suffix: + match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) + + op = ListNamespacesOp( + match_conditions=tuple(match_conditions), + max_depth=max_depth, + limit=limit, + offset=offset, + ) + self._queue[fut] = op + return fut.result() + + +class SyncBatchedBaseStore(SyncBatchedBaseStoreMixin, BaseStore): + """Efficiently batch operations in a background thread.""" + + __slots__ = ("_sync_queue", "_sync_thread") + + def __init__(self) -> None: + super().__init__() + self._sync_queue: dict[Future, Op] = {} + self._sync_thread = threading.Thread( + target=_sync_run, + args=(self._sync_queue, weakref.ref(self)), + daemon=True, + ) + self._sync_thread.start() + + def __del__(self) -> None: + # Signal the thread to stop + if self._sync_thread.is_alive(): + empty_future = Future() + self._sync_queue[empty_future] = None # type: ignore + self._sync_thread.join() + + +class BatchedBaseStore( + AsyncBatchedBaseStoreMixin, SyncBatchedBaseStoreMixin, BaseStore +): + __slots__ = ( + "_sync_queue", + "_sync_thread", + "_task", + "_loop", + "_aqueue", + ) + + def __init__(self) -> None: + super().__init__() + # Setup async processing + self._loop = asyncio.get_running_loop() + self._aqueue: dict[asyncio.Future, Op] = {} + self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) + + self._sync_queue: dict[Future, Op] = {} + self._sync_thread = threading.Thread( + target=_sync_run, + args=(self._sync_queue, weakref.ref(self)), + daemon=True, + ) + self._sync_thread.start() + + def __del__(self) -> None: + # Signal the thread to stop + if self._sync_thread.is_alive(): + empty_future = Future() + self._sync_queue[empty_future] = None # type: ignore + self._sync_thread.join() + + # Signal the thread to stop + if self._task is not None: + self._task.cancel() + + +def _sync_run(queue: dict[Future, Op], store: weakref.ReferenceType[BaseStore]) -> None: + while True: + time.sleep(0.001) # Yield to other threads + if not queue: + continue + if s := store(): + # get the operations to run + taken = queue.copy() + # action each operation + try: + values = list(taken.values()) + if None in values: # Exit signal + break + listen, dedupped = _dedupe_ops(values) + results = s.batch(dedupped) # Note: Using sync batch here + if listen is not None: + results = [results[ix] for ix in listen] + + # set the results of each operation + for fut, result in zip(taken, results): + fut.set_result(result) + except Exception as e: + for fut in taken: + fut.set_exception(e) + # remove the operations from the queue + for fut in taken: + del queue[fut] + else: + break + # remove strong ref to store + del s From 054ca2f78dd274ae68edef990b0dbcba16b0ba86 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:00:57 -0800 Subject: [PATCH 2/2] Lint and fix --- .../langgraph/store/postgres/aio.py | 13 +- .../langgraph/store/postgres/base.py | 4 +- .../tests/test_async_store.py | 177 +++++++++--------- libs/checkpoint-postgres/tests/test_store.py | 98 +++++++++- libs/checkpoint/langgraph/store/base/batch.py | 56 ++++-- 5 files changed, 217 insertions(+), 131 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py index 2b4309d5e..776c81e30 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py @@ -21,8 +21,6 @@ ) from langgraph.store.base.batch import ( BatchedBaseStore, - AsyncBatchedBaseStore, - SyncBatchedBaseStore, ) from langgraph.store.postgres.base import ( _PLACEHOLDER, @@ -40,7 +38,7 @@ logger = logging.getLogger(__name__) -class AsyncPostgresStore(AsyncBatchedBaseStore, BasePostgresStore[_ainternal.Conn]): +class AsyncPostgresStore(BatchedBaseStore, BasePostgresStore[_ainternal.Conn]): """Asynchronous Postgres-backed store with optional vector search using pgvector. !!! example "Examples" @@ -160,15 +158,8 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: return results - # def batch(self, ops: Iterable[Op]) -> list[Result]: - # return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() def batch(self, ops: Iterable[Op]) -> list[Result]: - futures = [] - for op in ops: - fut = self._loop.create_future() - self._aqueue[fut] = op - futures.append(fut) - return [fut.result() for fut in asyncio.as_completed(futures)] + return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() @classmethod @asynccontextmanager diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/base.py b/libs/checkpoint-postgres/langgraph/store/postgres/base.py index 84dd70364..d464353c4 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/base.py @@ -30,7 +30,6 @@ from langgraph.checkpoint.postgres import _ainternal as _ainternal from langgraph.checkpoint.postgres import _internal as _pg_internal from langgraph.store.base import ( - BaseStore, GetOp, IndexConfig, Item, @@ -44,6 +43,7 @@ get_text_at_path, tokenize_path, ) +from langgraph.store.base.batch import SyncBatchedBaseStore if TYPE_CHECKING: from langchain_core.embeddings import Embeddings @@ -533,7 +533,7 @@ def _get_filter_condition(self, key: str, op: str, value: Any) -> tuple[str, lis raise ValueError(f"Unsupported operator: {op}") -class PostgresStore(BaseStore, BasePostgresStore[_pg_internal.Conn]): +class PostgresStore(SyncBatchedBaseStore, BasePostgresStore[_pg_internal.Conn]): """Postgres-backed store with optional vector search using pgvector. !!! example "Examples" diff --git a/libs/checkpoint-postgres/tests/test_async_store.py b/libs/checkpoint-postgres/tests/test_async_store.py index 7c9cd38ad..8dd3c0508 100644 --- a/libs/checkpoint-postgres/tests/test_async_store.py +++ b/libs/checkpoint-postgres/tests/test_async_store.py @@ -4,6 +4,7 @@ import sys import uuid from collections.abc import AsyncIterator +from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import Any, Optional @@ -64,104 +65,94 @@ async def store(request) -> AsyncIterator[AsyncPostgresStore]: await conn.execute(f"DROP DATABASE {database}") -async def test_large_batches(store: AsyncPostgresStore) -> None: - N = 100 +def test_large_batches(store: AsyncPostgresStore) -> None: + N = 1000 M = 10 - coros = [] - from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=10) as executor: for m in range(M): - ops = [] for i in range(N): - for i in range(N): - coros.append( - executor.submit( - store.put, - ("test", "foo", "bar", "baz", str(m % 2)), - f"key{i}", - value={"foo": "bar" + str(i)}, - ) - ) - coros.append( - executor.submit( - store.get, - ("test", "foo", "bar", "baz", str(m % 2)), - f"key{i}", - ) - ) - coros.append( - executor.submit( - store.list_namespaces, - prefix=None, - max_depth=m + 1, - ) - ) - coros.append( - executor.submit( - store.search, - ("test",), - ) - ) - coros.append( - executor.submit( - store.put, - ("test", "foo", "bar", "baz", str(m % 2)), - f"key{i}", - value={"foo": "bar" + str(i)}, - ) - ) - coros.append( - executor.submit( - store.put, - ("test", "foo", "bar", "baz", str(m % 2)), - f"key{i}", - None, - ) - ) - # ops.extend( - # [ - # PutOp( - # namespace=("test", "foo", "bar", "baz", str(m % 2)), - # key=f"key{i}", # {m}", - # value=None, - # ), - # GetOp(namespace=("test",), key=f"key{i}{m}"), - # ListNamespacesOp( - # match_conditions=None, max_depth=i + 1, limit=m, offset=0 - # ), - # SearchOp( - # namespace_prefix=("test",), - # filter=None, - # limit=10, - # offset=0, - # ), - # ] - # - # ops.extend( - # [ - # # PutOp( - # # namespace=("test", "foo", "bar", "baz", str(m % 2)), - # # key=f"key{i}", # {m}", - # # value={"data": f"value{i}{m}"}, - # # ), - # # GetOp(namespace=("test",), key=f"key{i}{m}"), - # # ListNamespacesOp( - # # match_conditions=None, max_depth=i + 1, limit=m, offset=0 - # # ), - # # SearchOp( - # # namespace_prefix=("test",), - # # filter={"data": f"value{i}{m}"}, - # # limit=10, - # # offset=0, - # # ), - # ] - # ) - - # coros.extend(ops) - # executor.map(store.batch, [[op] for op in coros]) - - # await asyncio.gather(*coros) # *[store.abatch(ops) for ops in coros]) + _ = [ + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.get, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ), + executor.submit( + store.list_namespaces, + prefix=None, + max_depth=m + 1, + ), + executor.submit( + store.search, + ("test",), + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + None, + ), + ] + + +async def test_large_batches_async(store: AsyncPostgresStore) -> None: + N = 1000 + M = 10 + coros = [] + for m in range(M): + for i in range(N): + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.aget( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + coros.append( + store.alist_namespaces( + prefix=None, + max_depth=m + 1, + ) + ) + coros.append( + store.asearch( + ("test",), + ) + ) + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.adelete( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + + await asyncio.gather(*coros) async def test_abatch_order(store: AsyncPostgresStore) -> None: diff --git a/libs/checkpoint-postgres/tests/test_store.py b/libs/checkpoint-postgres/tests/test_store.py index 35dfa2150..ffb1aa563 100644 --- a/libs/checkpoint-postgres/tests/test_store.py +++ b/libs/checkpoint-postgres/tests/test_store.py @@ -1,5 +1,7 @@ # type: ignore +import asyncio +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from typing import Any, Optional from uuid import uuid4 @@ -17,11 +19,7 @@ SearchOp, ) from langgraph.store.postgres import PostgresStore -from tests.conftest import ( - DEFAULT_URI, - VECTOR_TYPES, - CharacterEmbeddings, -) +from tests.conftest import DEFAULT_URI, VECTOR_TYPES, CharacterEmbeddings @pytest.fixture(scope="function", params=["default", "pipe", "pool"]) @@ -59,6 +57,96 @@ def store(request) -> PostgresStore: conn.execute(f"DROP DATABASE {database}") +def test_large_batches(store: PostgresStore) -> None: + N = 1000 + M = 10 + + with ThreadPoolExecutor(max_workers=10) as executor: + for m in range(M): + for i in range(N): + _ = [ + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.get, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ), + executor.submit( + store.list_namespaces, + prefix=None, + max_depth=m + 1, + ), + executor.submit( + store.search, + ("test",), + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + None, + ), + ] + + +async def test_large_batches_async(store: PostgresStore) -> None: + N = 1000 + M = 10 + coros = [] + for m in range(M): + for i in range(N): + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.aget( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + coros.append( + store.alist_namespaces( + prefix=None, + max_depth=m + 1, + ) + ) + coros.append( + store.asearch( + ("test",), + ) + ) + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.adelete( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + + await asyncio.gather(*coros) + + def test_batch_order(store: PostgresStore) -> None: # Setup test data store.put(("test", "foo"), "key1", {"data": "value1"}) diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py index a3e2d343b..7c8371c1e 100644 --- a/libs/checkpoint/langgraph/store/base/batch.py +++ b/libs/checkpoint/langgraph/store/base/batch.py @@ -14,6 +14,7 @@ NamespacePath, Op, PutOp, + Result, SearchItem, SearchOp, _validate_namespace, @@ -23,6 +24,10 @@ class AsyncBatchedBaseStoreMixin: """Efficiently batch operations in a background task.""" + _loop: asyncio.AbstractEventLoop + _aqueue: dict[asyncio.Future, Op] + _task: asyncio.Task + async def aget( self, namespace: tuple[str, ...], @@ -92,14 +97,6 @@ async def alist_namespaces( self._aqueue[fut] = op return await fut - # def batch(self, ops: Iterable[Op]) -> list[Result]: - # futures = [] - # for op in ops: - # fut = self._loop.create_future() - # self._aqueue[fut] = op - # futures.append(fut) - # return [fut.result() for fut in asyncio.as_completed(futures)] - class AsyncBatchedBaseStore(AsyncBatchedBaseStoreMixin, BaseStore): """Efficiently batch operations in a background task.""" @@ -115,6 +112,14 @@ def __init__(self) -> None: def __del__(self) -> None: self._task.cancel() + def batch(self, ops: Iterable[Op]) -> list[Result]: + futures = [] + for op in ops: + fut = self._loop.create_future() + self._aqueue[fut] = op + futures.append(fut) + return [fut.result() for fut in asyncio.as_completed(futures)] + def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: """Dedupe operations while preserving order for results. @@ -195,13 +200,16 @@ async def _run( class SyncBatchedBaseStoreMixin(BaseStore): """Efficiently batch operations in a background thread.""" + _sync_queue: dict[Future, Op] + _sync_thread: threading.Thread + def get( self, namespace: tuple[str, ...], key: str, ) -> Optional[Item]: - fut = Future() - self._queue[fut] = GetOp(namespace, key) + fut: Future[Optional[Item]] = Future() + self._sync_queue[fut] = GetOp(namespace, key) return fut.result() def search( @@ -214,8 +222,8 @@ def search( limit: int = 10, offset: int = 0, ) -> list[SearchItem]: - fut = Future() - self._queue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query) + fut: Future[list[SearchItem]] = Future() + self._sync_queue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query) return fut.result() def put( @@ -226,8 +234,8 @@ def put( index: Optional[Union[Literal[False], list[str]]] = None, ) -> None: _validate_namespace(namespace) - fut = Future() - self._queue[fut] = PutOp(namespace, key, value, index) + fut: Future[None] = Future() + self._sync_queue[fut] = PutOp(namespace, key, value, index) return fut.result() def delete( @@ -235,8 +243,8 @@ def delete( namespace: tuple[str, ...], key: str, ) -> None: - fut = Future() - self._queue[fut] = PutOp(namespace, key, None) + fut: Future[None] = Future() + self._sync_queue[fut] = PutOp(namespace, key, None) return fut.result() def list_namespaces( @@ -248,7 +256,7 @@ def list_namespaces( limit: int = 100, offset: int = 0, ) -> list[tuple[str, ...]]: - fut = Future() + fut: Future[list[tuple[str, ...]]] = Future() match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) @@ -261,7 +269,7 @@ def list_namespaces( limit=limit, offset=offset, ) - self._queue[fut] = op + self._sync_queue[fut] = op return fut.result() @@ -283,10 +291,18 @@ def __init__(self) -> None: def __del__(self) -> None: # Signal the thread to stop if self._sync_thread.is_alive(): - empty_future = Future() + empty_future: Future = Future() self._sync_queue[empty_future] = None # type: ignore self._sync_thread.join() + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + futures = [] + for op in ops: + fut: Future[Result] = Future() + self._sync_queue[fut] = op + futures.append(fut) + return [fut.result() for fut in futures] + class BatchedBaseStore( AsyncBatchedBaseStoreMixin, SyncBatchedBaseStoreMixin, BaseStore @@ -317,7 +333,7 @@ def __init__(self) -> None: def __del__(self) -> None: # Signal the thread to stop if self._sync_thread.is_alive(): - empty_future = Future() + empty_future: Future[None] = Future() self._sync_queue[empty_future] = None # type: ignore self._sync_thread.join()