From 90bd12ace7d1b35ee92327296fe79061976c7b92 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 24 Sep 2024 17:10:51 -0700 Subject: [PATCH 01/26] Implement new Store interface --- .../langgraph/managed/shared_value.py | 22 ++--- libs/langgraph/langgraph/store/base.py | 60 +++++++++-- libs/langgraph/langgraph/store/batch.py | 88 ++++++++++------- libs/langgraph/langgraph/store/memory.py | 71 ++++++++++--- libs/langgraph/langgraph/store/store.py | 99 +++++++++++++++++++ libs/langgraph/langgraph/utils/runnable.py | 31 ++++-- 6 files changed, 295 insertions(+), 76 deletions(-) create mode 100644 libs/langgraph/langgraph/store/store.py diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index 0f94dd1a1..607c2363f 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -21,7 +21,7 @@ ConfiguredManagedValue, WritableManagedValue, ) -from langgraph.store.base import BaseStore +from langgraph.store.base import BaseStore, PutOp, SearchOp V = dict[str, Any] @@ -58,8 +58,8 @@ def on(scope: str) -> ConfiguredManagedValue: def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: with super().enter(config, **kwargs) as value: if value.store is not None: - saved = value.store.list([value.ns]) - value.value = saved[value.ns] or {} + saved = value.store.search([SearchOp(value.ns)]) + value.value = {it.id: it.value for it in saved[0]} yield value @classmethod @@ -67,8 +67,8 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]: async with super().aenter(config, **kwargs) as value: if value.store is not None: - saved = await value.store.alist([value.ns]) - value.value = saved[value.ns] or {} + saved = await value.store.asearch([SearchOp(value.ns)]) + value.value = {it.id: it.value for it in saved[0]} yield value def __init__( @@ -87,7 
+87,7 @@ def __init__( if self.store is None: pass elif scope_value := config[CONF].get(self.scope): - self.ns = f"scoped:{scope}:{key}:{scope_value}" + self.ns = ("scoped", scope, key, scope_value) else: raise ValueError( f"Scope {scope} for shared state key not in config.configurable" @@ -96,21 +96,19 @@ def __init__( def __call__(self, step: int) -> Value: return self.value.copy() - def _process_update( - self, values: Sequence[Update] - ) -> list[tuple[str, str, Optional[dict[str, Any]]]]: - writes: list[tuple[str, str, Optional[dict[str, Any]]]] = [] + def _process_update(self, values: Sequence[Update]) -> list[PutOp]: + writes: list[PutOp] = [] for vv in values: for k, v in vv.items(): if v is None: if k in self.value: del self.value[k] - writes.append((self.ns, k, None)) + writes.append(PutOp(self.ns, k, None)) elif not isinstance(v, dict): raise InvalidUpdateError("Received a non-dict value") else: self.value[k] = v - writes.append((self.ns, k, v)) + writes.append(PutOp(self.ns, k, v)) return writes def update(self, values: Sequence[Update]) -> None: diff --git a/libs/langgraph/langgraph/store/base.py b/libs/langgraph/langgraph/store/base.py index 046483f2e..31c2a07ca 100644 --- a/libs/langgraph/langgraph/store/base.py +++ b/libs/langgraph/langgraph/store/base.py @@ -1,21 +1,61 @@ -from typing import Any, List, Optional +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple, Optional -V = Any +SCORE_RECENCY = "recency" +SCORE_RELEVANCE = "relevance" + + +@dataclass +class Item: + value: dict[str, Any] + # search metadata + scores: dict[str, float] + # item metadata + id: str + namespace: tuple[str, ...] + created_at: datetime + updated_at: datetime + last_accessed_at: datetime + + +class GetOp(NamedTuple): + namespace: tuple[str, ...] + id: str + + +class SearchOp(NamedTuple): + namespace_prefix: tuple[str, ...] 
+ query: Optional[str] = None + filter: Optional[dict[str, Any]] = None + weights: Optional[dict[str, float]] = None + limit: int = 10 + offset: int = 0 + + +class PutOp(NamedTuple): + namespace: tuple[str, ...] + id: str + value: Optional[dict[str, Any]] class BaseStore: - def list(self, prefixes: List[str]) -> dict[str, dict[str, V]]: - # list[namespace] -> dict[namespace, list[value]] + __slots__ = ("__weakref__",) + + def get(self, ops: list[GetOp]) -> list[Optional[Item]]: + raise NotImplementedError + + def search(self, ops: list[SearchOp]) -> list[list[Item]]: + raise NotImplementedError + + def put(self, ops: list[PutOp]) -> None: raise NotImplementedError - def put(self, writes: List[tuple[str, str, Optional[V]]]) -> None: - # list[(namespace, key, value | none)] -> None + async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: raise NotImplementedError - async def alist(self, prefixes: List[str]) -> dict[str, dict[str, V]]: - # list[namespace] -> dict[namespace, list[value]] + async def asearch(self, ops: list[SearchOp]) -> list[list[Item]]: raise NotImplementedError - async def aput(self, writes: List[tuple[str, str, Optional[V]]]) -> None: - # list[(namespace, key, value | none)] -> None + async def aput(self, ops: list[PutOp]) -> None: raise NotImplementedError diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/langgraph/langgraph/store/batch.py index 54eb20d47..2aa3a7fb0 100644 --- a/libs/langgraph/langgraph/store/batch.py +++ b/libs/langgraph/langgraph/store/batch.py @@ -1,65 +1,85 @@ import asyncio -from typing import NamedTuple, Optional, Union +from typing import Optional, Union, cast -from langgraph.store.base import BaseStore, V +from langgraph.store.base import BaseStore, GetOp, Item, PutOp, SearchOp - -class ListOp(NamedTuple): - prefixes: list[str] - - -class PutOp(NamedTuple): - writes: list[tuple[str, str, Optional[V]]] +Ops = Union[list[GetOp], list[SearchOp], list[PutOp]] class AsyncBatchedStore(BaseStore): def 
__init__(self, store: BaseStore) -> None: - self.store = store - self.aqueue: dict[asyncio.Future, Union[ListOp, PutOp]] = {} - self.task = asyncio.create_task(_run(self.aqueue, self.store)) + self._store = store + self._loop = asyncio.get_running_loop() + self._aqueue: dict[asyncio.Future, Ops] = {} + self._task = self._loop.create_task(_run(self._aqueue, self._store)) def __del__(self) -> None: - self.task.cancel() + self._task.cancel() + + async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: + fut = self._loop.create_future() + self._aqueue[fut] = ops + return await fut - async def alist(self, prefixes: list[str]) -> dict[str, dict[str, V]]: - fut = asyncio.get_running_loop().create_future() - self.aqueue[fut] = ListOp(prefixes) + async def asearch(self, ops: list[SearchOp]) -> list[list[Item]]: + fut = self._loop.create_future() + self._aqueue[fut] = ops return await fut - async def aput(self, writes: list[tuple[str, str, Optional[V]]]) -> None: - fut = asyncio.get_running_loop().create_future() - self.aqueue[fut] = PutOp(writes) + async def aput(self, ops: list[PutOp]) -> None: + fut = self._loop.create_future() + self._aqueue[fut] = ops return await fut -async def _run( - aqueue: dict[asyncio.Future, Union[ListOp, PutOp]], store: BaseStore -) -> None: +async def _run(aqueue: dict[asyncio.Future, Ops], store: BaseStore) -> None: while True: await asyncio.sleep(0) if not aqueue: continue - # this could use a lock, if we want thread safety + # get the operations to run taken = aqueue.copy() - aqueue.clear() # action each operation - lists = {f: o for f, o in taken.items() if isinstance(o, ListOp)} - if lists: + if gets := { + f: cast(list[GetOp], ops) + for f, ops in taken.items() + if isinstance(ops[0], GetOp) + }: + try: + results = await store.aget([g for ops in gets.values() for g in ops]) + for fut, ops in gets.items(): + fut.set_result(results[: len(ops)]) + results = results[len(ops) :] + except Exception as e: + for fut in gets: + 
fut.set_exception(e) + if searches := { + f: cast(list[SearchOp], ops) + for f, ops in taken.items() + if isinstance(ops[0], SearchOp) + }: try: - results = await store.alist( - [p for op in lists.values() for p in op.prefixes] + results = await store.asearch( + [s for ops in searches.values() for s in ops] ) - for fut, op in lists.items(): - fut.set_result({k: results.get(k) for k in op.prefixes}) + for fut, ops in searches.items(): + fut.set_result(results[: len(ops)]) + results = results[len(ops) :] except Exception as e: - for fut in lists: + for fut in searches: fut.set_exception(e) - puts = {f: o for f, o in taken.items() if isinstance(o, PutOp)} - if puts: + if puts := { + f: cast(list[PutOp], ops) + for f, ops in taken.items() + if isinstance(ops[0], PutOp) + }: try: - await store.aput([w for op in puts.values() for w in op.writes]) + await store.aput([p for ops in puts.values() for p in ops]) for fut in puts: fut.set_result(None) except Exception as e: for fut in puts: fut.set_exception(e) + # remove the operations from the queue + for fut in taken: + del aqueue[fut] diff --git a/libs/langgraph/langgraph/store/memory.py b/libs/langgraph/langgraph/store/memory.py index 48fa2884f..66ba80555 100644 --- a/libs/langgraph/langgraph/store/memory.py +++ b/libs/langgraph/langgraph/store/memory.py @@ -1,25 +1,70 @@ from collections import defaultdict +from datetime import datetime, timezone from typing import List, Optional -from langgraph.store.base import BaseStore, V +from langgraph.store.base import BaseStore, GetOp, Item, PutOp, SearchOp class MemoryStore(BaseStore): def __init__(self) -> None: - self.data: dict[str, dict[str, V]] = defaultdict(dict) + self.data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) - def list(self, prefixes: List[str]) -> dict[str, dict[str, V]]: - return {prefix: self.data[prefix] for prefix in prefixes} + def get(self, ops: List[GetOp]) -> List[Optional[Item]]: + items = [self.data[op.namespace].get(op.id) for op in 
ops] + for item in items: + if item is not None: + item.last_accessed_at = datetime.now(timezone.utc) + return items - async def alist(self, prefixes: List[str]) -> dict[str, dict[str, V]]: - return self.list(prefixes) + def search(self, ops: List[SearchOp]) -> List[List[Item]]: + results: list[list[Item]] = [] + for op in ops: + candidates = [ + item + for namespace, items in self.data.items() + if ( + namespace[: len(op.namespace_prefix)] == op.namespace_prefix + if len(namespace) >= len(op.namespace_prefix) + else False + ) + for item in items.values() + ] + if op.query is not None: + raise NotImplementedError("Search queries are not supported") + if op.filter: + candidates = [ + item + for item in candidates + if item.value.items() >= op.filter.items() + ] + if op.weights: + raise NotImplementedError("Search weights are not supported") + results.append(candidates[op.offset : op.offset + op.limit]) + return results - def put(self, writes: List[tuple[str, str, Optional[V]]]) -> None: - for namespace, key, value in writes: - if value is None: - self.data[namespace].pop(key, None) + def put(self, ops: List[PutOp]) -> None: + for op in ops: + if op.value is None: + self.data[op.namespace].pop(op.id, None) + elif op.id in self.data[op.namespace]: + self.data[op.namespace][op.id].value = op.value + self.data[op.namespace][op.id].updated_at = datetime.now(timezone.utc) else: - self.data[namespace][key] = value + self.data[op.namespace][op.id] = Item( + value=op.value, + scores={}, + id=op.id, + namespace=op.namespace, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + last_accessed_at=datetime.now(timezone.utc), + ) - async def aput(self, writes: List[tuple[str, str, Optional[V]]]) -> None: - return self.put(writes) + async def aget(self, ops: List[GetOp]) -> List[Optional[Item]]: + return self.get(ops) + + async def asearch(self, ops: List[SearchOp]) -> List[List[Item]]: + return self.search(ops) + + async def aput(self, ops: 
List[PutOp]) -> None: + return self.put(ops) diff --git a/libs/langgraph/langgraph/store/store.py b/libs/langgraph/langgraph/store/store.py new file mode 100644 index 000000000..fdfb4821d --- /dev/null +++ b/libs/langgraph/langgraph/store/store.py @@ -0,0 +1,99 @@ +from typing import Any, Optional +from weakref import WeakKeyDictionary + +from langgraph.store.base import BaseStore, GetOp, Item, SearchOp + +CACHE: WeakKeyDictionary[BaseStore, "Store"] = WeakKeyDictionary() + + +def _get_store(store: BaseStore) -> "Store": + if store not in CACHE: + CACHE[store] = Store(store) + return CACHE[store] + + +class Store: + __slots__ = ("_store",) + + def __init__(self, store: BaseStore) -> None: + self._store = store + + def get( + self, + namespace: tuple[str, ...], + id: str, + ) -> Optional[Item]: + return self._store.get([GetOp(namespace, id)])[0] + + def search( + self, + namesapce_prefix: tuple[str, ...], + /, + *, + query: Optional[str], + filter: Optional[dict[str, Any]], + weights: Optional[dict[str, float]], + limit: int = 10, + offset: int = 0, + ) -> list[Item]: + return self._store.search( + [ + SearchOp(namesapce_prefix, query, filter, weights, limit, offset), + ] + )[0] + + def put( + self, + namespace: tuple[str, ...], + id: str, + value: dict[str, Any], + ) -> None: + self._store.put([GetOp(namespace, id, value)]) + + def delete( + self, + namespace: tuple[str, ...], + id: str, + ) -> None: + self._store.put([GetOp(namespace, id, None)]) + + async def aget( + self, + namespace: tuple[str, ...], + id: str, + ) -> Optional[Item]: + return await self._store.aget([GetOp(namespace, id)])[0] + + async def asearch( + self, + namesapce_prefix: tuple[str, ...], + /, + *, + query: Optional[str], + filter: Optional[dict[str, Any]], + weights: Optional[dict[str, float]], + limit: int = 10, + offset: int = 0, + ) -> list[Item]: + return ( + await self._store.asearch( + [ + SearchOp(namesapce_prefix, query, filter, weights, limit, offset), + ] + ) + )[0] + + async def 
aput( + self, + namespace: tuple[str, ...], + id: str, + value: dict[str, Any], + ) -> None: + await self._store.aput([GetOp(namespace, id, value)]) + + async def adelete( + self, + namespace: tuple[str, ...], + id: str, + ) -> None: + await self._store.aput([GetOp(namespace, id, None)]) diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 0a90217a0..29b38cac4 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -34,7 +34,8 @@ from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard -from langgraph.constants import CONF, CONFIG_KEY_STREAM_WRITER +from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER +from langgraph.store.store import Store, _get_store from langgraph.types import StreamWriter from langgraph.utils.config import ( ensure_config, @@ -59,12 +60,22 @@ class StrEnum(str, enum.Enum): ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11) -KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = ( +KWARGS_CONFIG_KEYS: tuple[ + tuple[str, tuple[Any, ...], str, Any, Optional[Callable[[Any], Any]]], ... 
+] = ( ( sys.intern("writer"), (StreamWriter, inspect.Parameter.empty), CONFIG_KEY_STREAM_WRITER, lambda _: None, + None, + ), + ( + sys.intern("store"), + (Store, inspect.Parameter.empty), + CONFIG_KEY_STORE, + inspect.Parameter.empty, + _get_store, ), ) """List of kwargs that can be passed to functions, and their corresponding @@ -112,7 +123,7 @@ def __init__( params = inspect.signature(cast(Callable, func or afunc)).parameters self.func_accepts_config = "config" in params self.func_accepts: dict[str, bool] = {} - for kw, typ, _, _ in KWARGS_CONFIG_KEYS: + for kw, typ, _, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) self.func_accepts[kw] = ( p is not None and p.annotation in typ and p.kind in VALID_KINDS @@ -140,9 +151,12 @@ def invoke( kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, ck, defv, map in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - kwargs[kw] = config[CONF].get(ck, defv) + if map is not None: + kwargs[kw] = map(config[CONF].get(ck, defv)) + else: + kwargs[kw] = config[CONF].get(ck, defv) context = copy_context() if self.trace: callback_manager = get_callback_manager_for_config(config, self.tags) @@ -179,9 +193,12 @@ async def ainvoke( kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, ck, defv, map in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - kwargs[kw] = config[CONF].get(ck, defv) + if map is not None: + kwargs[kw] = map(config[CONF].get(ck, defv)) + else: + kwargs[kw] = config[CONF].get(ck, defv) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) From a72afdcd53b6ea407a86d492fa01e0bbb173b713 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 24 Sep 2024 17:30:28 -0700 Subject: [PATCH 02/26] Fix test --- libs/langgraph/langgraph/store/batch.py | 6 +++ 
libs/langgraph/tests/test_store.py | 67 ++++++++++++++++++------- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/langgraph/langgraph/store/batch.py index 2aa3a7fb0..0d65bb709 100644 --- a/libs/langgraph/langgraph/store/batch.py +++ b/libs/langgraph/langgraph/store/batch.py @@ -17,16 +17,22 @@ def __del__(self) -> None: self._task.cancel() async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: + if any(not isinstance(op, GetOp) for op in ops): + raise TypeError("All operations must be GetOp") fut = self._loop.create_future() self._aqueue[fut] = ops return await fut async def asearch(self, ops: list[SearchOp]) -> list[list[Item]]: + if any(not isinstance(op, SearchOp) for op in ops): + raise TypeError("All operations must be SearchOp") fut = self._loop.create_future() self._aqueue[fut] = ops return await fut async def aput(self, ops: list[PutOp]) -> None: + if any(not isinstance(op, PutOp) for op in ops): + raise TypeError("All operations must be PutOp") fut = self._loop.create_future() self._aqueue[fut] = ops return await fut diff --git a/libs/langgraph/tests/test_store.py b/libs/langgraph/tests/test_store.py index 71494adaf..b95b5e828 100644 --- a/libs/langgraph/tests/test_store.py +++ b/libs/langgraph/tests/test_store.py @@ -1,10 +1,11 @@ import asyncio -from typing import Any, Optional +from datetime import datetime +from typing import Optional import pytest from pytest_mock import MockerFixture -from langgraph.store.base import BaseStore +from langgraph.store.base import BaseStore, GetOp, Item from langgraph.store.batch import AsyncBatchedStore pytestmark = pytest.mark.anyio @@ -12,27 +13,59 @@ async def test_async_batch_store(mocker: MockerFixture) -> None: aget = mocker.stub() - alist = mocker.stub() class MockStore(BaseStore): - async def aget( - self, pairs: list[tuple[str, str]] - ) -> dict[tuple[str, str], Optional[dict[str, Any]]]: - aget(pairs) - return {pair: 1 for pair in pairs} - 
- async def alist(self, prefixes: list[str]) -> dict[str, dict[str, Any]]: - alist(prefixes) - return {prefix: {prefix: 1} for prefix in prefixes} + async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: + aget(ops) + return [ + Item( + value={}, + scores={}, + id=op.id, + namespace=op.namespace, + created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + last_accessed_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + ) + for op in ops + ] store = AsyncBatchedStore(MockStore()) # concurrent calls are batched results = await asyncio.gather( - store.alist(["a", "b"]), - store.alist(["c", "d"]), + store.aget([GetOp(("a",), "b")]), + store.aget([GetOp(("c",), "d")]), ) - assert results == [{"a": {"a": 1}, "b": {"b": 1}}, {"c": {"c": 1}, "d": {"d": 1}}] - assert [c.args for c in alist.call_args_list] == [ - (["a", "b", "c", "d"],), + assert results == [ + [ + Item( + {}, + {}, + "b", + ("a",), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + ) + ], + [ + Item( + {}, + {}, + "d", + ("c",), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + ) + ], + ] + assert [c.args for c in aget.call_args_list] == [ + ( + [ + GetOp(("a",), "b"), + GetOp(("c",), "d"), + ], + ), ] From 595c5ea65a79ca102532a1ef913fa84b0f081dd8 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 24 Sep 2024 17:38:33 -0700 Subject: [PATCH 03/26] Lint --- libs/langgraph/langgraph/store/batch.py | 16 ++++++++-------- libs/langgraph/langgraph/store/store.py | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/langgraph/langgraph/store/batch.py index 0d65bb709..52e53eab9 100644 --- a/libs/langgraph/langgraph/store/batch.py +++ b/libs/langgraph/langgraph/store/batch.py @@ -52,10 +52,10 @@ async def 
_run(aqueue: dict[asyncio.Future, Ops], store: BaseStore) -> None: if isinstance(ops[0], GetOp) }: try: - results = await store.aget([g for ops in gets.values() for g in ops]) - for fut, ops in gets.items(): - fut.set_result(results[: len(ops)]) - results = results[len(ops) :] + gresults = await store.aget([g for ops in gets.values() for g in ops]) + for fut, gops in gets.items(): + fut.set_result(gresults[: len(gops)]) + gresults = gresults[len(gops) :] except Exception as e: for fut in gets: fut.set_exception(e) @@ -65,12 +65,12 @@ async def _run(aqueue: dict[asyncio.Future, Ops], store: BaseStore) -> None: if isinstance(ops[0], SearchOp) }: try: - results = await store.asearch( + sresults = await store.asearch( [s for ops in searches.values() for s in ops] ) - for fut, ops in searches.items(): - fut.set_result(results[: len(ops)]) - results = results[len(ops) :] + for fut, sops in searches.items(): + fut.set_result(sresults[: len(sops)]) + sresults = sresults[len(sops) :] except Exception as e: for fut in searches: fut.set_exception(e) diff --git a/libs/langgraph/langgraph/store/store.py b/libs/langgraph/langgraph/store/store.py index fdfb4821d..d6dbfbb36 100644 --- a/libs/langgraph/langgraph/store/store.py +++ b/libs/langgraph/langgraph/store/store.py @@ -1,7 +1,7 @@ from typing import Any, Optional from weakref import WeakKeyDictionary -from langgraph.store.base import BaseStore, GetOp, Item, SearchOp +from langgraph.store.base import BaseStore, GetOp, Item, PutOp, SearchOp CACHE: WeakKeyDictionary[BaseStore, "Store"] = WeakKeyDictionary() @@ -48,21 +48,21 @@ def put( id: str, value: dict[str, Any], ) -> None: - self._store.put([GetOp(namespace, id, value)]) + self._store.put([PutOp(namespace, id, value)]) def delete( self, namespace: tuple[str, ...], id: str, ) -> None: - self._store.put([GetOp(namespace, id, None)]) + self._store.put([PutOp(namespace, id, None)]) async def aget( self, namespace: tuple[str, ...], id: str, ) -> Optional[Item]: - return await 
self._store.aget([GetOp(namespace, id)])[0] + return (await self._store.aget([GetOp(namespace, id)]))[0] async def asearch( self, @@ -89,11 +89,11 @@ async def aput( id: str, value: dict[str, Any], ) -> None: - await self._store.aput([GetOp(namespace, id, value)]) + await self._store.aput([PutOp(namespace, id, value)]) async def adelete( self, namespace: tuple[str, ...], id: str, ) -> None: - await self._store.aput([GetOp(namespace, id, None)]) + await self._store.aput([PutOp(namespace, id, None)]) From cab8a2444570886428a2b465bd55ac385e8ac4ed Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:30:56 -0700 Subject: [PATCH 04/26] Spelling --- libs/langgraph/langgraph/store/store.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langgraph/langgraph/store/store.py b/libs/langgraph/langgraph/store/store.py index d6dbfbb36..c5a96424d 100644 --- a/libs/langgraph/langgraph/store/store.py +++ b/libs/langgraph/langgraph/store/store.py @@ -27,7 +27,7 @@ def get( def search( self, - namesapce_prefix: tuple[str, ...], + namespace_prefix: tuple[str, ...], /, *, query: Optional[str], @@ -38,7 +38,7 @@ def search( ) -> list[Item]: return self._store.search( [ - SearchOp(namesapce_prefix, query, filter, weights, limit, offset), + SearchOp(namespace_prefix, query, filter, weights, limit, offset), ] )[0] @@ -66,7 +66,7 @@ async def aget( async def asearch( self, - namesapce_prefix: tuple[str, ...], + namespace_prefix: tuple[str, ...], /, *, query: Optional[str], @@ -78,7 +78,7 @@ async def asearch( return ( await self._store.asearch( [ - SearchOp(namesapce_prefix, query, filter, weights, limit, offset), + SearchOp(namespace_prefix, query, filter, weights, limit, offset), ] ) )[0] From d449ab533db557022bc75fe2470f97db31875650 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:35:29 -0700 Subject: [PATCH 05/26] 
Support stringized annotations --- libs/langgraph/langgraph/utils/runnable.py | 5 ++- libs/langgraph/tests/test_runnable.py | 43 ++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 libs/langgraph/tests/test_runnable.py diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 29b38cac4..2c5222985 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -120,7 +120,10 @@ def __init__( # check signature if func is None and afunc is None: raise ValueError("At least one of func or afunc must be provided.") - params = inspect.signature(cast(Callable, func or afunc)).parameters + params = inspect.signature( + cast(Callable, func or afunc), eval_str=True + ).parameters + self.func_accepts_config = "config" in params self.func_accepts: dict[str, bool] = {} for kw, typ, _, _, _ in KWARGS_CONFIG_KEYS: diff --git a/libs/langgraph/tests/test_runnable.py b/libs/langgraph/tests/test_runnable.py new file mode 100644 index 000000000..3f318ed0b --- /dev/null +++ b/libs/langgraph/tests/test_runnable.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Any + +from langgraph.store.store import Store +from langgraph.types import StreamWriter +from langgraph.utils.runnable import RunnableCallable + + +def test_runnable_callable_func_accepts(): + def sync_func(x: Any) -> str: + return f"{x}" + + async def async_func(x: Any) -> str: + return f"{x}" + + def func_with_store(x: Any, store: Store) -> str: + return f"{x}" + + def func_with_writer(x: Any, writer: StreamWriter) -> str: + return f"{x}" + + async def afunc_with_store(x: Any, store: Store) -> str: + return f"{x}" + + async def afunc_with_writer(x: Any, writer: StreamWriter) -> str: + return f"{x}" + + runnables = { + "sync": RunnableCallable(sync_func), + "async": RunnableCallable(func=None, afunc=async_func), + "with_store": RunnableCallable(func_with_store), + 
"with_writer": RunnableCallable(func_with_writer), + "awith_store": RunnableCallable(afunc_with_store), + "awith_writer": RunnableCallable(afunc_with_writer), + } + + expected_store = {"with_store": True, "awith_store": True} + expected_writer = {"with_writer": True, "awith_writer": True} + + for name, runnable in runnables.items(): + assert runnable.func_accepts["writer"] == expected_writer.get(name, False) + assert runnable.func_accepts["store"] == expected_store.get(name, False) From 6b1bf06fb80787f1664a3d251b56afcf6318e987 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:40:38 -0700 Subject: [PATCH 06/26] String Match --- libs/langgraph/langgraph/utils/runnable.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 2c5222985..4aa83b16f 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -65,14 +65,14 @@ class StrEnum(str, enum.Enum): ] = ( ( sys.intern("writer"), - (StreamWriter, inspect.Parameter.empty), + (StreamWriter, "StreamWriter", inspect.Parameter.empty), CONFIG_KEY_STREAM_WRITER, lambda _: None, None, ), ( sys.intern("store"), - (Store, inspect.Parameter.empty), + (Store, "Store", inspect.Parameter.empty), CONFIG_KEY_STORE, inspect.Parameter.empty, _get_store, @@ -120,9 +120,7 @@ def __init__( # check signature if func is None and afunc is None: raise ValueError("At least one of func or afunc must be provided.") - params = inspect.signature( - cast(Callable, func or afunc), eval_str=True - ).parameters + params = inspect.signature(cast(Callable, func or afunc)).parameters self.func_accepts_config = "config" in params self.func_accepts: dict[str, bool] = {} From 3bf4270cf8edc2bb105c69a0b8e3d6c21134f642 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 25 Sep 2024 08:15:31 -0700 Subject: [PATCH 07/26] Inject 
store --- libs/langgraph/langgraph/pregel/algo.py | 11 +++++++++++ libs/langgraph/langgraph/pregel/loop.py | 2 ++ libs/langgraph/tests/test_algo.py | 11 ++++++++++- .../langgraph/scheduler/kafka/executor.py | 1 + 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 99de1f130..0969d029c 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -35,6 +35,7 @@ CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_SEND, + CONFIG_KEY_STORE, CONFIG_KEY_TASK_ID, EMPTY_SEQ, INTERRUPT, @@ -54,6 +55,7 @@ from langgraph.pregel.log import logger from langgraph.pregel.manager import ChannelsManager from langgraph.pregel.read import PregelNode +from langgraph.store.base import BaseStore from langgraph.types import All, PregelExecutableTask, PregelTask from langgraph.utils.config import merge_configs, patch_config @@ -274,6 +276,7 @@ def prepare_next_tasks( step: int, *, for_execution: Literal[False], + store: Literal[None] = None, checkpointer: Literal[None] = None, manager: Literal[None] = None, ) -> dict[str, PregelTask]: ... @@ -289,6 +292,7 @@ def prepare_next_tasks( step: int, *, for_execution: Literal[True], + store: Optional[BaseStore], checkpointer: Optional[BaseCheckpointSaver], manager: Union[None, ParentRunManager, AsyncParentRunManager], ) -> dict[str, PregelExecutableTask]: ... 
@@ -303,6 +307,7 @@ def prepare_next_tasks( step: int, *, for_execution: bool, + store: Optional[BaseStore] = None, checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]: @@ -322,6 +327,7 @@ def prepare_next_tasks( config=config, step=step, for_execution=for_execution, + store=store, checkpointer=checkpointer, manager=manager, ): @@ -339,6 +345,7 @@ def prepare_next_tasks( config=config, step=step, for_execution=for_execution, + store=store, checkpointer=checkpointer, manager=manager, ): @@ -357,6 +364,7 @@ def prepare_single_task( config: RunnableConfig, step: int, for_execution: bool, + store: Optional[BaseStore] = None, checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[None, PregelTask, PregelExecutableTask]: @@ -438,6 +446,9 @@ def prepare_single_task( PregelTaskWrites(packet.node, writes, triggers), config, ), + CONFIG_KEY_STORE: ( + store or configurable.get(CONFIG_KEY_STORE) + ), CONFIG_KEY_CHECKPOINTER: ( checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 071f798d7..634dc440a 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -367,6 +367,7 @@ def tick( self.step, for_execution=True, manager=manager, + store=self.store, checkpointer=self.checkpointer, ) # we don't need to save the writes for the last task that completes @@ -495,6 +496,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: self.config, self.step, for_execution=True, + store=None, checkpointer=None, manager=None, ) diff --git a/libs/langgraph/tests/test_algo.py b/libs/langgraph/tests/test_algo.py index 203280488..4e259f29e 100644 --- a/libs/langgraph/tests/test_algo.py +++ 
b/libs/langgraph/tests/test_algo.py @@ -17,7 +17,16 @@ def test_prepare_next_tasks() -> None: ) assert ( prepare_next_tasks( - checkpoint, processes, channels, managed, config, 0, for_execution=True + checkpoint, + processes, + channels, + managed, + config, + 0, + for_execution=True, + checkpointer=None, + store=None, + manager=None, ) == {} ) diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py index c803239e8..d8e150f5b 100644 --- a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py +++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py @@ -194,6 +194,7 @@ async def attempt(self, msg: MessageToExecutor) -> None: step=saved.metadata["step"] + 1, for_execution=True, checkpointer=self.graph.checkpointer, + store=self.graph.store, ): # execute task, saving writes runner = PregelRunner( From 0a9cd3ae309d96bca41cdc1cd6d1e43a00409767 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 25 Sep 2024 09:00:38 -0700 Subject: [PATCH 08/26] Merge BaseStore and Store --- .../langgraph/managed/shared_value.py | 14 +- libs/langgraph/langgraph/pregel/loop.py | 2 - libs/langgraph/langgraph/store/base.py | 119 +++++++++++++--- libs/langgraph/langgraph/store/batch.py | 127 +++++++++--------- libs/langgraph/langgraph/store/memory.py | 103 +++++++------- libs/langgraph/langgraph/store/store.py | 99 -------------- libs/langgraph/langgraph/utils/runnable.py | 34 ++--- libs/langgraph/tests/test_runnable.py | 6 +- libs/langgraph/tests/test_store.py | 72 +++++----- 9 files changed, 275 insertions(+), 301 deletions(-) delete mode 100644 libs/langgraph/langgraph/store/store.py diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index 607c2363f..12a15fc93 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -21,7 +21,7 @@ ConfiguredManagedValue, WritableManagedValue, ) 
-from langgraph.store.base import BaseStore, PutOp, SearchOp +from langgraph.store.base import BaseStore, PutOp V = dict[str, Any] @@ -58,8 +58,8 @@ def on(scope: str) -> ConfiguredManagedValue: def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: with super().enter(config, **kwargs) as value: if value.store is not None: - saved = value.store.search([SearchOp(value.ns)]) - value.value = {it.id: it.value for it in saved[0]} + saved = value.store.search(value.ns) + value.value = {it.id: it.value for it in saved} yield value @classmethod @@ -67,8 +67,8 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Self]: async with super().aenter(config, **kwargs) as value: if value.store is not None: - saved = await value.store.asearch([SearchOp(value.ns)]) - value.value = {it.id: it.value for it in saved[0]} + saved = await value.store.asearch(value.ns) + value.value = {it.id: it.value for it in saved} yield value def __init__( @@ -115,10 +115,10 @@ def update(self, values: Sequence[Update]) -> None: if self.store is None: self._process_update(values) else: - return self.store.put(self._process_update(values)) + return self.store.batch(self._process_update(values)) async def aupdate(self, writes: Sequence[Update]) -> None: if self.store is None: self._process_update(writes) else: - return await self.store.aput(self._process_update(writes)) + return await self.store.abatch(self._process_update(writes)) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 634dc440a..6353680c3 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -101,7 +101,6 @@ from langgraph.pregel.read import PregelNode from langgraph.pregel.utils import get_new_channel_versions from langgraph.store.base import BaseStore -from langgraph.store.batch import AsyncBatchedStore from langgraph.types 
import All, PregelExecutableTask, StreamMode from langgraph.utils.config import patch_configurable @@ -785,7 +784,6 @@ def __init__( check_subgraphs=check_subgraphs, debug=debug, ) - self.store = AsyncBatchedStore(self.store) if self.store else None self.stack = AsyncExitStack() if checkpointer: self.checkpointer_get_next_version = checkpointer.get_next_version diff --git a/libs/langgraph/langgraph/store/base.py b/libs/langgraph/langgraph/store/base.py index 31c2a07ca..f9b027ecb 100644 --- a/libs/langgraph/langgraph/store/base.py +++ b/libs/langgraph/langgraph/store/base.py @@ -1,6 +1,7 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, NamedTuple, Optional +from typing import Any, Iterable, NamedTuple, Optional, Sequence, Union SCORE_RECENCY = "recency" SCORE_RELEVANCE = "relevance" @@ -39,23 +40,105 @@ class PutOp(NamedTuple): value: Optional[dict[str, Any]] -class BaseStore: - __slots__ = ("__weakref__",) - - def get(self, ops: list[GetOp]) -> list[Optional[Item]]: - raise NotImplementedError - - def search(self, ops: list[SearchOp]) -> list[list[Item]]: - raise NotImplementedError +Op = Union[GetOp, SearchOp, PutOp] +Result = Union[Item, list[Item], None] - def put(self, ops: list[PutOp]) -> None: - raise NotImplementedError - async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: - raise NotImplementedError - - async def asearch(self, ops: list[SearchOp]) -> list[list[Item]]: - raise NotImplementedError +class BaseStore(ABC): + __slots__ = ("__weakref__",) - async def aput(self, ops: list[PutOp]) -> None: - raise NotImplementedError + # abstract methods + + @abstractmethod + def batch( + self, + ops: Iterable[Op], + ) -> Sequence[Result]: ... + + @abstractmethod + async def abatch( + self, + ops: Iterable[Op], + ) -> Sequence[Result]: ... 
+ + # convenience methods + + def get( + self, + namespace: tuple[str, ...], + id: str, + ) -> Optional[Item]: + return self.batch([GetOp(namespace, id)])[0] + + def search( + self, + namespace_prefix: tuple[str, ...], + /, + *, + query: Optional[str] = None, + filter: Optional[dict[str, Any]] = None, + weights: Optional[dict[str, float]] = None, + limit: int = 10, + offset: int = 0, + ) -> list[Item]: + return self.batch( + [ + SearchOp(namespace_prefix, query, filter, weights, limit, offset), + ] + )[0] + + def put( + self, + namespace: tuple[str, ...], + id: str, + value: dict[str, Any], + ) -> None: + self.batch([PutOp(namespace, id, value)]) + + def delete( + self, + namespace: tuple[str, ...], + id: str, + ) -> None: + self.batch([PutOp(namespace, id, None)]) + + async def aget( + self, + namespace: tuple[str, ...], + id: str, + ) -> Optional[Item]: + return (await self.abatch([GetOp(namespace, id)]))[0] + + async def asearch( + self, + namespace_prefix: tuple[str, ...], + /, + *, + query: Optional[str] = None, + filter: Optional[dict[str, Any]] = None, + weights: Optional[dict[str, float]] = None, + limit: int = 10, + offset: int = 0, + ) -> list[Item]: + return ( + await self.abatch( + [ + SearchOp(namespace_prefix, query, filter, weights, limit, offset), + ] + ) + )[0] + + async def aput( + self, + namespace: tuple[str, ...], + id: str, + value: dict[str, Any], + ) -> None: + await self.abatch([PutOp(namespace, id, value)]) + + async def adelete( + self, + namespace: tuple[str, ...], + id: str, + ) -> None: + await self.abatch([PutOp(namespace, id, None)]) diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/langgraph/langgraph/store/batch.py index 52e53eab9..75f73915e 100644 --- a/libs/langgraph/langgraph/store/batch.py +++ b/libs/langgraph/langgraph/store/batch.py @@ -1,91 +1,90 @@ import asyncio -from typing import Optional, Union, cast +import weakref +from typing import Any, Optional -from langgraph.store.base import BaseStore, GetOp, Item, 
PutOp, SearchOp +from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, SearchOp -Ops = Union[list[GetOp], list[SearchOp], list[PutOp]] +class AsyncBatchedBaseStore(BaseStore): + __slots__ = ("_loop", "_aqueue", "_task") -class AsyncBatchedStore(BaseStore): - def __init__(self, store: BaseStore) -> None: - self._store = store + def __init__(self) -> None: self._loop = asyncio.get_running_loop() - self._aqueue: dict[asyncio.Future, Ops] = {} - self._task = self._loop.create_task(_run(self._aqueue, self._store)) + self._aqueue: dict[asyncio.Future, Op] = {} + self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) def __del__(self) -> None: self._task.cancel() - async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: - if any(not isinstance(op, GetOp) for op in ops): - raise TypeError("All operations must be GetOp") + async def aget( + self, + namespace: tuple[str, ...], + id: str, + ) -> Optional[Item]: fut = self._loop.create_future() - self._aqueue[fut] = ops + self._aqueue[fut] = GetOp(namespace, id) return await fut - async def asearch(self, ops: list[SearchOp]) -> list[list[Item]]: - if any(not isinstance(op, SearchOp) for op in ops): - raise TypeError("All operations must be SearchOp") + async def asearch( + self, + namespace_prefix: tuple[str, ...], + /, + *, + query: Optional[str] = None, + filter: Optional[dict[str, Any]] = None, + weights: Optional[dict[str, float]] = None, + limit: int = 10, + offset: int = 0, + ) -> list[Item]: fut = self._loop.create_future() - self._aqueue[fut] = ops + self._aqueue[fut] = SearchOp( + namespace_prefix, query, filter, weights, limit, offset + ) return await fut - async def aput(self, ops: list[PutOp]) -> None: - if any(not isinstance(op, PutOp) for op in ops): - raise TypeError("All operations must be PutOp") + async def aput( + self, + namespace: tuple[str, ...], + id: str, + value: dict[str, Any], + ) -> None: fut = self._loop.create_future() - self._aqueue[fut] = ops + 
self._aqueue[fut] = PutOp(namespace, id, value) + return await fut + + async def adelete( + self, + namespace: tuple[str, ...], + id: str, + ) -> None: + fut = self._loop.create_future() + self._aqueue[fut] = PutOp(namespace, id, None) return await fut -async def _run(aqueue: dict[asyncio.Future, Ops], store: BaseStore) -> None: +async def _run( + aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore] +) -> None: while True: await asyncio.sleep(0) if not aqueue: continue - # get the operations to run - taken = aqueue.copy() - # action each operation - if gets := { - f: cast(list[GetOp], ops) - for f, ops in taken.items() - if isinstance(ops[0], GetOp) - }: - try: - gresults = await store.aget([g for ops in gets.values() for g in ops]) - for fut, gops in gets.items(): - fut.set_result(gresults[: len(gops)]) - gresults = gresults[len(gops) :] - except Exception as e: - for fut in gets: - fut.set_exception(e) - if searches := { - f: cast(list[SearchOp], ops) - for f, ops in taken.items() - if isinstance(ops[0], SearchOp) - }: - try: - sresults = await store.asearch( - [s for ops in searches.values() for s in ops] - ) - for fut, sops in searches.items(): - fut.set_result(sresults[: len(sops)]) - sresults = sresults[len(sops) :] - except Exception as e: - for fut in searches: - fut.set_exception(e) - if puts := { - f: cast(list[PutOp], ops) - for f, ops in taken.items() - if isinstance(ops[0], PutOp) - }: + if s := store(): + # get the operations to run + taken = aqueue.copy() + # action each operation try: - await store.aput([p for ops in puts.values() for p in ops]) - for fut in puts: - fut.set_result(None) + results = await s.abatch(taken.values()) + # set the results of each operation + for fut, result in zip(taken, results): + fut.set_result(result) except Exception as e: - for fut in puts: + for fut in taken: fut.set_exception(e) - # remove the operations from the queue - for fut in taken: - del aqueue[fut] + # remove the operations from the 
queue + for fut in taken: + del aqueue[fut] + else: + break + # remove strong ref to store + del s diff --git a/libs/langgraph/langgraph/store/memory.py b/libs/langgraph/langgraph/store/memory.py index 66ba80555..2b18a9433 100644 --- a/libs/langgraph/langgraph/store/memory.py +++ b/libs/langgraph/langgraph/store/memory.py @@ -1,70 +1,63 @@ from collections import defaultdict from datetime import datetime, timezone -from typing import List, Optional -from langgraph.store.base import BaseStore, GetOp, Item, PutOp, SearchOp +from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result, SearchOp class MemoryStore(BaseStore): def __init__(self) -> None: self.data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) - def get(self, ops: List[GetOp]) -> List[Optional[Item]]: - items = [self.data[op.namespace].get(op.id) for op in ops] - for item in items: - if item is not None: - item.last_accessed_at = datetime.now(timezone.utc) - return items - - def search(self, ops: List[SearchOp]) -> List[List[Item]]: - results: list[list[Item]] = [] + def batch(self, ops: list[Op]) -> list[Result]: + results: list[Result] = [] for op in ops: - candidates = [ - item - for namespace, items in self.data.items() - if ( - namespace[: len(op.namespace_prefix)] == op.namespace_prefix - if len(namespace) >= len(op.namespace_prefix) - else False - ) - for item in items.values() - ] - if op.query is not None: - raise NotImplementedError("Search queries are not supported") - if op.filter: + if isinstance(op, GetOp): + item = self.data[op.namespace].get(op.id) + if item is not None: + item.last_accessed_at = datetime.now(timezone.utc) + results.append(item) + elif isinstance(op, SearchOp): candidates = [ item - for item in candidates - if item.value.items() >= op.filter.items() + for namespace, items in self.data.items() + if ( + namespace[: len(op.namespace_prefix)] == op.namespace_prefix + if len(namespace) >= len(op.namespace_prefix) + else False + ) + for item in 
items.values() ] - if op.weights: - raise NotImplementedError("Search weights are not supported") - results.append(candidates[op.offset : op.offset + op.limit]) + if op.query is not None: + raise NotImplementedError("Search queries are not supported") + if op.filter: + candidates = [ + item + for item in candidates + if item.value.items() >= op.filter.items() + ] + if op.weights: + raise NotImplementedError("Search weights are not supported") + results.append(candidates[op.offset : op.offset + op.limit]) + elif isinstance(op, PutOp): + if op.value is None: + self.data[op.namespace].pop(op.id, None) + elif op.id in self.data[op.namespace]: + self.data[op.namespace][op.id].value = op.value + self.data[op.namespace][op.id].updated_at = datetime.now( + timezone.utc + ) + else: + self.data[op.namespace][op.id] = Item( + value=op.value, + scores={}, + id=op.id, + namespace=op.namespace, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + last_accessed_at=datetime.now(timezone.utc), + ) + results.append(None) return results - def put(self, ops: List[PutOp]) -> None: - for op in ops: - if op.value is None: - self.data[op.namespace].pop(op.id, None) - elif op.id in self.data[op.namespace]: - self.data[op.namespace][op.id].value = op.value - self.data[op.namespace][op.id].updated_at = datetime.now(timezone.utc) - else: - self.data[op.namespace][op.id] = Item( - value=op.value, - scores={}, - id=op.id, - namespace=op.namespace, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - last_accessed_at=datetime.now(timezone.utc), - ) - - async def aget(self, ops: List[GetOp]) -> List[Optional[Item]]: - return self.get(ops) - - async def asearch(self, ops: List[SearchOp]) -> List[List[Item]]: - return self.search(ops) - - async def aput(self, ops: List[PutOp]) -> None: - return self.put(ops) + async def abatch(self, ops: list[Op]) -> list[Result]: + return self.batch(ops) diff --git 
a/libs/langgraph/langgraph/store/store.py b/libs/langgraph/langgraph/store/store.py deleted file mode 100644 index c5a96424d..000000000 --- a/libs/langgraph/langgraph/store/store.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Any, Optional -from weakref import WeakKeyDictionary - -from langgraph.store.base import BaseStore, GetOp, Item, PutOp, SearchOp - -CACHE: WeakKeyDictionary[BaseStore, "Store"] = WeakKeyDictionary() - - -def _get_store(store: BaseStore) -> "Store": - if store not in CACHE: - CACHE[store] = Store(store) - return CACHE[store] - - -class Store: - __slots__ = ("_store",) - - def __init__(self, store: BaseStore) -> None: - self._store = store - - def get( - self, - namespace: tuple[str, ...], - id: str, - ) -> Optional[Item]: - return self._store.get([GetOp(namespace, id)])[0] - - def search( - self, - namespace_prefix: tuple[str, ...], - /, - *, - query: Optional[str], - filter: Optional[dict[str, Any]], - weights: Optional[dict[str, float]], - limit: int = 10, - offset: int = 0, - ) -> list[Item]: - return self._store.search( - [ - SearchOp(namespace_prefix, query, filter, weights, limit, offset), - ] - )[0] - - def put( - self, - namespace: tuple[str, ...], - id: str, - value: dict[str, Any], - ) -> None: - self._store.put([PutOp(namespace, id, value)]) - - def delete( - self, - namespace: tuple[str, ...], - id: str, - ) -> None: - self._store.put([PutOp(namespace, id, None)]) - - async def aget( - self, - namespace: tuple[str, ...], - id: str, - ) -> Optional[Item]: - return (await self._store.aget([GetOp(namespace, id)]))[0] - - async def asearch( - self, - namespace_prefix: tuple[str, ...], - /, - *, - query: Optional[str], - filter: Optional[dict[str, Any]], - weights: Optional[dict[str, float]], - limit: int = 10, - offset: int = 0, - ) -> list[Item]: - return ( - await self._store.asearch( - [ - SearchOp(namespace_prefix, query, filter, weights, limit, offset), - ] - ) - )[0] - - async def aput( - self, - namespace: tuple[str, ...], 
- id: str, - value: dict[str, Any], - ) -> None: - await self._store.aput([PutOp(namespace, id, value)]) - - async def adelete( - self, - namespace: tuple[str, ...], - id: str, - ) -> None: - await self._store.aput([PutOp(namespace, id, None)]) diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 4aa83b16f..2545eb49e 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -35,7 +35,7 @@ from typing_extensions import TypeGuard from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER -from langgraph.store.store import Store, _get_store +from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.config import ( ensure_config, @@ -60,22 +60,18 @@ class StrEnum(str, enum.Enum): ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11) -KWARGS_CONFIG_KEYS: tuple[ - tuple[str, tuple[Any, ...], str, Any, Optional[Callable[[Any], Any]]], ... -] = ( +KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] 
= ( ( sys.intern("writer"), (StreamWriter, "StreamWriter", inspect.Parameter.empty), CONFIG_KEY_STREAM_WRITER, lambda _: None, - None, ), ( sys.intern("store"), - (Store, "Store", inspect.Parameter.empty), + (BaseStore, "BaseStore", inspect.Parameter.empty), CONFIG_KEY_STORE, inspect.Parameter.empty, - _get_store, ), ) """List of kwargs that can be passed to functions, and their corresponding @@ -124,7 +120,7 @@ def __init__( self.func_accepts_config = "config" in params self.func_accepts: dict[str, bool] = {} - for kw, typ, _, _, _ in KWARGS_CONFIG_KEYS: + for kw, typ, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) self.func_accepts[kw] = ( p is not None and p.annotation in typ and p.kind in VALID_KINDS @@ -152,12 +148,15 @@ def invoke( kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config - for kw, _, ck, defv, map in KWARGS_CONFIG_KEYS: + _conf = config[CONF] + for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - if map is not None: - kwargs[kw] = map(config[CONF].get(ck, defv)) + if defv is inspect.Parameter.empty and ck not in _conf: + raise ValueError( + f"Missing required config key '{ck}' for '{self.name}'." + ) else: - kwargs[kw] = config[CONF].get(ck, defv) + kwargs[kw] = _conf.get(ck, defv) context = copy_context() if self.trace: callback_manager = get_callback_manager_for_config(config, self.tags) @@ -194,12 +193,15 @@ async def ainvoke( kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config - for kw, _, ck, defv, map in KWARGS_CONFIG_KEYS: + _conf = config[CONF] + for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - if map is not None: - kwargs[kw] = map(config[CONF].get(ck, defv)) + if defv is inspect.Parameter.empty and ck not in _conf: + raise ValueError( + f"Missing required config key '{ck}' for '{self.name}'." 
+ ) else: - kwargs[kw] = config[CONF].get(ck, defv) + kwargs[kw] = _conf.get(ck, defv) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) diff --git a/libs/langgraph/tests/test_runnable.py b/libs/langgraph/tests/test_runnable.py index 3f318ed0b..64d858fd4 100644 --- a/libs/langgraph/tests/test_runnable.py +++ b/libs/langgraph/tests/test_runnable.py @@ -2,7 +2,7 @@ from typing import Any -from langgraph.store.store import Store +from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.runnable import RunnableCallable @@ -14,13 +14,13 @@ def sync_func(x: Any) -> str: async def async_func(x: Any) -> str: return f"{x}" - def func_with_store(x: Any, store: Store) -> str: + def func_with_store(x: Any, store: BaseStore) -> str: return f"{x}" def func_with_writer(x: Any, writer: StreamWriter) -> str: return f"{x}" - async def afunc_with_store(x: Any, store: Store) -> str: + async def afunc_with_store(x: Any, store: BaseStore) -> str: return f"{x}" async def afunc_with_writer(x: Any, writer: StreamWriter) -> str: diff --git a/libs/langgraph/tests/test_store.py b/libs/langgraph/tests/test_store.py index b95b5e828..807a5798e 100644 --- a/libs/langgraph/tests/test_store.py +++ b/libs/langgraph/tests/test_store.py @@ -1,22 +1,25 @@ import asyncio from datetime import datetime -from typing import Optional import pytest from pytest_mock import MockerFixture -from langgraph.store.base import BaseStore, GetOp, Item -from langgraph.store.batch import AsyncBatchedStore +from langgraph.store.base import GetOp, Item, Op, Result +from langgraph.store.batch import AsyncBatchedBaseStore pytestmark = pytest.mark.anyio async def test_async_batch_store(mocker: MockerFixture) -> None: - aget = mocker.stub() + abatch = mocker.stub() - class MockStore(BaseStore): - async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: - aget(ops) + class MockStore(AsyncBatchedBaseStore): + 
def batch(self, ops: list[Op]) -> list[Result]: + raise NotImplementedError + + async def abatch(self, ops: list[Op]) -> list[Result]: + assert all(isinstance(op, GetOp) for op in ops) + abatch(ops) return [ Item( value={}, @@ -30,42 +33,37 @@ async def aget(self, ops: list[GetOp]) -> list[Optional[Item]]: for op in ops ] - store = AsyncBatchedStore(MockStore()) + store = MockStore() # concurrent calls are batched results = await asyncio.gather( - store.aget([GetOp(("a",), "b")]), - store.aget([GetOp(("c",), "d")]), + store.aget(namespace=("a",), id="b"), + store.aget(namespace=("c",), id="d"), ) assert results == [ - [ - Item( - {}, - {}, - "b", - ("a",), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - ) - ], - [ - Item( - {}, - {}, - "d", - ("c",), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - ) - ], + Item( + {}, + {}, + "b", + ("a",), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + ), + Item( + {}, + {}, + "d", + ("c",), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + datetime(2024, 9, 24, 17, 29, 10, 128397), + ), ] - assert [c.args for c in aget.call_args_list] == [ + assert abatch.call_count == 1 + assert [tuple(c.args[0]) for c in abatch.call_args_list] == [ ( - [ - GetOp(("a",), "b"), - GetOp(("c",), "d"), - ], + GetOp(("a",), "b"), + GetOp(("c",), "d"), ), ] From 4ab823fe1dabc197c7620ca9f79362d49856de39 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 25 Sep 2024 10:19:06 -0700 Subject: [PATCH 09/26] Slots for MemoryStore --- libs/langgraph/langgraph/store/memory.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/libs/langgraph/langgraph/store/memory.py b/libs/langgraph/langgraph/store/memory.py index 
2b18a9433..1dd51d2b6 100644 --- a/libs/langgraph/langgraph/store/memory.py +++ b/libs/langgraph/langgraph/store/memory.py @@ -5,21 +5,23 @@ class MemoryStore(BaseStore): + __slots__ = ("_data",) + def __init__(self) -> None: - self.data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) + self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) def batch(self, ops: list[Op]) -> list[Result]: results: list[Result] = [] for op in ops: if isinstance(op, GetOp): - item = self.data[op.namespace].get(op.id) + item = self._data[op.namespace].get(op.id) if item is not None: item.last_accessed_at = datetime.now(timezone.utc) results.append(item) elif isinstance(op, SearchOp): candidates = [ item - for namespace, items in self.data.items() + for namespace, items in self._data.items() if ( namespace[: len(op.namespace_prefix)] == op.namespace_prefix if len(namespace) >= len(op.namespace_prefix) @@ -40,14 +42,14 @@ def batch(self, ops: list[Op]) -> list[Result]: results.append(candidates[op.offset : op.offset + op.limit]) elif isinstance(op, PutOp): if op.value is None: - self.data[op.namespace].pop(op.id, None) - elif op.id in self.data[op.namespace]: - self.data[op.namespace][op.id].value = op.value - self.data[op.namespace][op.id].updated_at = datetime.now( + self._data[op.namespace].pop(op.id, None) + elif op.id in self._data[op.namespace]: + self._data[op.namespace][op.id].value = op.value + self._data[op.namespace][op.id].updated_at = datetime.now( timezone.utc ) else: - self.data[op.namespace][op.id] = Item( + self._data[op.namespace][op.id] = Item( value=op.value, scores={}, id=op.id, From ebdc9896bc05ac9915eb4112b6f8cc90ce9fef52 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:42:49 -0700 Subject: [PATCH 10/26] Add tests for MemoryStore in graph --- libs/langgraph/tests/test_pregel.py | 61 ++++++++++++++++++ libs/langgraph/tests/test_pregel_async.py | 78 
+++++++++++++++++++---- 2 files changed, 127 insertions(+), 12 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 87a684ec7..6bfe225de 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -2,6 +2,7 @@ import operator import re import time +import uuid import warnings from collections import Counter from concurrent.futures import ThreadPoolExecutor @@ -70,6 +71,7 @@ StateSnapshot, ) from langgraph.pregel.retry import RetryPolicy +from langgraph.store.base import BaseStore from langgraph.store.memory import MemoryStore from langgraph.types import Interrupt, PregelTask, Send, StreamWriter from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence @@ -11329,3 +11331,62 @@ def child_node_b(state: ChildState): app = parent.compile(checkpointer=checkpointer) with pytest.raises(RandomError): app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}}) + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_store_injected(request: pytest.FixtureRequest, checkpointer_name: str) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + count: Annotated[int, operator.add] + + doc_id = str(uuid.uuid4()) + doc = {"some-key": "this-is-a-val"} + + def node(input: State, config: RunnableConfig, store: BaseStore): + assert isinstance(store, BaseStore) + assert isinstance(store, MemoryStore) + store.put( + ("foo", "bar"), + doc_id, + { + **doc, + "from_thread": config["configurable"]["thread_id"], + "some_val": input["count"], + }, + ) + return {"count": 1} + + graph = StateGraph(State) + graph.add_node("node", node) + graph.add_edge("__start__", "node") + the_store = MemoryStore() + app = graph.compile(store=the_store, checkpointer=checkpointer) + + thread_1 = str(uuid.uuid4()) + result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) + assert result 
== {"count": 1} + returned_doc = the_store.get(("foo", "bar"), doc_id).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} + assert len(the_store.search(("foo", "bar"), doc_id)) == 1 + + # Check update on existing thread + result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) + assert result == {"count": 2} + returned_doc = the_store.get(("foo", "bar"), doc_id).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} + assert len(the_store.search(("foo", "bar"), doc_id)) == 1 + + thread_2 = str(uuid.uuid4()) + + result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_2}}) + assert result == {"count": 1} + returned_doc = the_store.get(("foo", "bar"), doc_id).value + assert returned_doc == { + **doc, + "from_thread": thread_2, + "some_val": 1, + } # Overwrites the whole doc + assert ( + len(the_store.search(("foo", "bar"), doc_id)) == 1 + ) # still overwriting the same one diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 0fe7210c8..f620304b2 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2,6 +2,7 @@ import operator import re import sys +import uuid from collections import Counter from contextlib import asynccontextmanager, contextmanager from time import perf_counter @@ -24,9 +25,7 @@ import httpx import pytest -from langchain_core.messages import ( - ToolCall, -) +from langchain_core.messages import ToolCall from langchain_core.runnables import ( RunnableConfig, RunnableLambda, @@ -57,17 +56,11 @@ from langgraph.graph.graph import START from langgraph.graph.message import MessageGraph, add_messages from langgraph.managed.shared_value import SharedValue -from langgraph.prebuilt.chat_agent_executor import ( - create_tool_calling_executor, -) +from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor from langgraph.prebuilt.tool_node import 
ToolNode -from langgraph.pregel import ( - Channel, - GraphRecursionError, - Pregel, - StateSnapshot, -) +from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot from langgraph.pregel.retry import RetryPolicy +from langgraph.store.base import BaseStore from langgraph.store.memory import MemoryStore from langgraph.types import Interrupt, PregelTask, Send, StreamWriter from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence @@ -9614,3 +9607,64 @@ def __call__(self, state): assert (await graph.ainvoke([], {"configurable": {"thread_id": "foo"}})) == [ "1" ] * 4 + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_store_injected_async( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + count: Annotated[int, operator.add] + + doc_id = str(uuid.uuid4()) + doc = {"some-key": "this-is-a-val"} + + async def node(input: State, config: RunnableConfig, store: BaseStore): + assert isinstance(store, BaseStore) + assert isinstance(store, MemoryStore) + await store.aput( + ("foo", "bar"), + doc_id, + { + **doc, + "from_thread": config["configurable"]["thread_id"], + "some_val": input["count"], + }, + ) + return {"count": 1} + + graph = StateGraph(State) + graph.add_node("node", node) + graph.add_edge("__start__", "node") + the_store = MemoryStore() + app = graph.compile(store=the_store, checkpointer=checkpointer) + + thread_1 = str(uuid.uuid4()) + result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) + assert result == {"count": 1} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} + assert len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + + # Check update on existing thread + result = await app.ainvoke({"count": 0}, {"configurable": 
{"thread_id": thread_1}}) + assert result == {"count": 2} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} + assert len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + + thread_2 = str(uuid.uuid4()) + + result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_2}}) + assert result == {"count": 1} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == { + **doc, + "from_thread": thread_2, + "some_val": 1, + } # Overwrites the whole doc + assert ( + len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + ) # still overwriting the same one From 6af537b3fa1fa97d8fbe996b727e0da68cbf23db Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 26 Sep 2024 10:51:37 -0700 Subject: [PATCH 11/26] Fixes --- libs/langgraph/langgraph/pregel/__init__.py | 12 ++++++++++-- libs/langgraph/langgraph/pregel/algo.py | 3 +++ libs/langgraph/langgraph/store/base.py | 10 ++++++++-- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index cd3c8150d..bee52f455 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -61,6 +61,7 @@ CONFIG_KEY_READ, CONFIG_KEY_RESUMING, CONFIG_KEY_SEND, + CONFIG_KEY_STORE, CONFIG_KEY_STREAM, CONFIG_KEY_STREAM_WRITER, CONFIG_KEY_TASK_ID, @@ -1096,6 +1097,10 @@ def _defaults( raise ValueError( f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}" ) + if CONFIG_KEY_STORE in config.get(CONF, {}): + store: Optional[BaseStore] = config[CONF][CONFIG_KEY_STORE] + else: + store = self.store return ( debug, set(stream_mode), @@ -1103,6 +1108,7 @@ def _defaults( interrupt_before, interrupt_after, checkpointer, + store, ) def stream( @@ -1219,6 +1225,7 @@ def output() -> Iterator: 
interrupt_before_, interrupt_after_, checkpointer, + store, ) = self._defaults( config, stream_mode=stream_mode, @@ -1241,7 +1248,7 @@ def output() -> Iterator: input, stream=StreamProtocol(stream.put, stream_modes), config=config, - store=self.store, + store=store, checkpointer=checkpointer, nodes=self.nodes, specs=self.channels, @@ -1434,6 +1441,7 @@ def output() -> Iterator: interrupt_before_, interrupt_after_, checkpointer, + store, ) = self._defaults( config, stream_mode=stream_mode, @@ -1456,7 +1464,7 @@ def output() -> Iterator: input, stream=StreamProtocol(stream.put_nowait, stream_modes), config=config, - store=self.store, + store=store, checkpointer=checkpointer, nodes=self.nodes, specs=self.channels, diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 0969d029c..daa942a39 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -556,6 +556,9 @@ def prepare_single_task( PregelTaskWrites(name, writes, triggers), config, ), + CONFIG_KEY_STORE: ( + store or configurable.get(CONFIG_KEY_STORE) + ), CONFIG_KEY_CHECKPOINTER: ( checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER) diff --git a/libs/langgraph/langgraph/store/base.py b/libs/langgraph/langgraph/store/base.py index f9b027ecb..bb5195fef 100644 --- a/libs/langgraph/langgraph/store/base.py +++ b/libs/langgraph/langgraph/store/base.py @@ -1,12 +1,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Iterable, NamedTuple, Optional, Sequence, Union +from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence, Union SCORE_RECENCY = "recency" SCORE_RELEVANCE = "relevance" +class Weight(NamedTuple): + field: Union[str, Literal["recency"], Literal["relevance"]] + weight: float + default: float = 0.0 + + @dataclass class Item: value: dict[str, Any] @@ -29,7 +35,7 @@ class SearchOp(NamedTuple): namespace_prefix: tuple[str, 
...] query: Optional[str] = None filter: Optional[dict[str, Any]] = None - weights: Optional[dict[str, float]] = None + weights: Optional[list[Weight]] = None limit: int = 10 offset: int = 0 From c056300bcf82115a63c533c36ae9f2f7b80ec462 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 26 Sep 2024 10:52:40 -0700 Subject: [PATCH 12/26] Remove unimplemented args --- libs/langgraph/langgraph/store/base.py | 10 ++-------- libs/langgraph/langgraph/store/memory.py | 4 ---- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/libs/langgraph/langgraph/store/base.py b/libs/langgraph/langgraph/store/base.py index bb5195fef..d848dd321 100644 --- a/libs/langgraph/langgraph/store/base.py +++ b/libs/langgraph/langgraph/store/base.py @@ -33,9 +33,7 @@ class GetOp(NamedTuple): class SearchOp(NamedTuple): namespace_prefix: tuple[str, ...] - query: Optional[str] = None filter: Optional[dict[str, Any]] = None - weights: Optional[list[Weight]] = None limit: int = 10 offset: int = 0 @@ -81,15 +79,13 @@ def search( namespace_prefix: tuple[str, ...], /, *, - query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, - weights: Optional[dict[str, float]] = None, limit: int = 10, offset: int = 0, ) -> list[Item]: return self.batch( [ - SearchOp(namespace_prefix, query, filter, weights, limit, offset), + SearchOp(namespace_prefix, filter, limit, offset), ] )[0] @@ -120,16 +116,14 @@ async def asearch( namespace_prefix: tuple[str, ...], /, *, - query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, - weights: Optional[dict[str, float]] = None, limit: int = 10, offset: int = 0, ) -> list[Item]: return ( await self.abatch( [ - SearchOp(namespace_prefix, query, filter, weights, limit, offset), + SearchOp(namespace_prefix, filter, limit, offset), ] ) )[0] diff --git a/libs/langgraph/langgraph/store/memory.py b/libs/langgraph/langgraph/store/memory.py index 1dd51d2b6..2b16c893f 100644 --- a/libs/langgraph/langgraph/store/memory.py +++ 
b/libs/langgraph/langgraph/store/memory.py @@ -29,16 +29,12 @@ def batch(self, ops: list[Op]) -> list[Result]: ) for item in items.values() ] - if op.query is not None: - raise NotImplementedError("Search queries are not supported") if op.filter: candidates = [ item for item in candidates if item.value.items() >= op.filter.items() ] - if op.weights: - raise NotImplementedError("Search weights are not supported") results.append(candidates[op.offset : op.offset + op.limit]) elif isinstance(op, PutOp): if op.value is None: From 3ee8758bf9825c18dd692ef5d3a3bd536f95d471 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:52:19 -0700 Subject: [PATCH 13/26] Update test signature --- libs/langgraph/tests/test_pregel.py | 8 +++----- libs/langgraph/tests/test_pregel_async.py | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 6bfe225de..2e55946e2 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -11368,14 +11368,14 @@ def node(input: State, config: RunnableConfig, store: BaseStore): assert result == {"count": 1} returned_doc = the_store.get(("foo", "bar"), doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} - assert len(the_store.search(("foo", "bar"), doc_id)) == 1 + assert len(the_store.search(("foo", "bar"))) == 1 # Check update on existing thread result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) assert result == {"count": 2} returned_doc = the_store.get(("foo", "bar"), doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} - assert len(the_store.search(("foo", "bar"), doc_id)) == 1 + assert len(the_store.search(("foo", "bar"))) == 1 thread_2 = str(uuid.uuid4()) @@ -11387,6 +11387,4 @@ def node(input: State, config: RunnableConfig, store: BaseStore): "from_thread": 
thread_2, "some_val": 1, } # Overwrites the whole doc - assert ( - len(the_store.search(("foo", "bar"), doc_id)) == 1 - ) # still overwriting the same one + assert len(the_store.search(("foo", "bar"))) == 1 # still overwriting the same one diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index f620304b2..bd8e03b0d 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -9646,14 +9646,14 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): assert result == {"count": 1} returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} - assert len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + assert len((await the_store.asearch(("foo", "bar")))) == 1 # Check update on existing thread result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) assert result == {"count": 2} returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} - assert len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + assert len((await the_store.asearch(("foo", "bar")))) == 1 thread_2 = str(uuid.uuid4()) @@ -9666,5 +9666,5 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): "some_val": 1, } # Overwrites the whole doc assert ( - len((await the_store.asearch(("foo", "bar"), doc_id))) == 1 + len((await the_store.asearch(("foo", "bar")))) == 1 ) # still overwriting the same one From 7225cc278f829f6c88628c146a19affd3d816b9c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 26 Sep 2024 10:56:20 -0700 Subject: [PATCH 14/26] Lint --- libs/langgraph/langgraph/pregel/__init__.py | 1 + libs/langgraph/langgraph/store/batch.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py 
b/libs/langgraph/langgraph/pregel/__init__.py index bee52f455..3d7414494 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1071,6 +1071,7 @@ def _defaults( Union[All, Sequence[str]], Union[All, Sequence[str]], Optional[BaseCheckpointSaver], + Optional[BaseStore], ]: if config["recursion_limit"] < 1: raise ValueError("recursion_limit must be at least 1") diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/langgraph/langgraph/store/batch.py index 75f73915e..3f0d85101 100644 --- a/libs/langgraph/langgraph/store/batch.py +++ b/libs/langgraph/langgraph/store/batch.py @@ -37,9 +37,7 @@ async def asearch( offset: int = 0, ) -> list[Item]: fut = self._loop.create_future() - self._aqueue[fut] = SearchOp( - namespace_prefix, query, filter, weights, limit, offset - ) + self._aqueue[fut] = SearchOp(namespace_prefix, filter, limit, offset) return await fut async def aput( From 615b06b444dac25e5d48bade899a2c2bd8fc94c1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 26 Sep 2024 10:57:19 -0700 Subject: [PATCH 15/26] Remove --- libs/langgraph/langgraph/store/base.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/libs/langgraph/langgraph/store/base.py b/libs/langgraph/langgraph/store/base.py index d848dd321..4477f5eef 100644 --- a/libs/langgraph/langgraph/store/base.py +++ b/libs/langgraph/langgraph/store/base.py @@ -1,16 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence, Union - -SCORE_RECENCY = "recency" -SCORE_RELEVANCE = "relevance" - - -class Weight(NamedTuple): - field: Union[str, Literal["recency"], Literal["relevance"]] - weight: float - default: float = 0.0 +from typing import Any, Iterable, NamedTuple, Optional, Sequence, Union @dataclass From f97db1b4ef033b276523ccd971069eea043b9f3a Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn 
<13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:01:46 -0700 Subject: [PATCH 16/26] Update test --- libs/langgraph/tests/test_pregel.py | 16 ++--- libs/langgraph/tests/test_pregel_async.py | 77 ++++++++++++----------- 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 2e55946e2..7cfaf5335 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -11357,21 +11357,21 @@ def node(input: State, config: RunnableConfig, store: BaseStore): ) return {"count": 1} - graph = StateGraph(State) - graph.add_node("node", node) - graph.add_edge("__start__", "node") + builder = StateGraph(State) + builder.add_node("node", node) + builder.add_edge("__start__", "node") the_store = MemoryStore() - app = graph.compile(store=the_store, checkpointer=checkpointer) + graph = builder.compile(store=the_store, checkpointer=checkpointer) thread_1 = str(uuid.uuid4()) - result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) + result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) assert result == {"count": 1} returned_doc = the_store.get(("foo", "bar"), doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} assert len(the_store.search(("foo", "bar"))) == 1 # Check update on existing thread - result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) + result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) assert result == {"count": 2} returned_doc = the_store.get(("foo", "bar"), doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} @@ -11379,12 +11379,12 @@ def node(input: State, config: RunnableConfig, store: BaseStore): thread_2 = str(uuid.uuid4()) - result = app.invoke({"count": 0}, {"configurable": {"thread_id": thread_2}}) + result = graph.invoke({"count": 0}, {"configurable": {"thread_id":
thread_2}}) assert result == {"count": 1} returned_doc = the_store.get(("foo", "bar"), doc_id).value assert returned_doc == { **doc, "from_thread": thread_2, - "some_val": 1, + "some_val": 0, } # Overwrites the whole doc assert len(the_store.search(("foo", "bar"))) == 1 # still overwriting the same one diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index bd8e03b0d..6dfdff955 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -9610,11 +9610,7 @@ def __call__(self, state): @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) -async def test_store_injected_async( - request: pytest.FixtureRequest, checkpointer_name: str -) -> None: - checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") - +async def test_store_injected_async(checkpointer_name: str) -> None: class State(TypedDict): count: Annotated[int, operator.add] @@ -9635,36 +9631,43 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): ) return {"count": 1} - graph = StateGraph(State) - graph.add_node("node", node) - graph.add_edge("__start__", "node") + builder = StateGraph(State) + builder.add_node("node", node) + builder.add_edge("__start__", "node") the_store = MemoryStore() - app = graph.compile(store=the_store, checkpointer=checkpointer) - - thread_1 = str(uuid.uuid4()) - result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) - assert result == {"count": 1} - returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value - assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} - assert len((await the_store.asearch(("foo", "bar")))) == 1 - - # Check update on existing thread - result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) - assert result == {"count": 2} - returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value - assert returned_doc == {**doc, 
"from_thread": thread_1, "some_val": 1} - assert len((await the_store.asearch(("foo", "bar")))) == 1 - - thread_2 = str(uuid.uuid4()) - - result = await app.ainvoke({"count": 0}, {"configurable": {"thread_id": thread_2}}) - assert result == {"count": 1} - returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value - assert returned_doc == { - **doc, - "from_thread": thread_2, - "some_val": 1, - } # Overwrites the whole doc - assert ( - len((await the_store.asearch(("foo", "bar")))) == 1 - ) # still overwriting the same one + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(store=the_store, checkpointer=checkpointer) + + thread_1 = str(uuid.uuid4()) + result = await graph.ainvoke( + {"count": 0}, {"configurable": {"thread_id": thread_1}} + ) + assert result == {"count": 1} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} + assert len((await the_store.asearch(("foo", "bar")))) == 1 + + # Check update on existing thread + result = await graph.ainvoke( + {"count": 0}, {"configurable": {"thread_id": thread_1}} + ) + assert result == {"count": 2} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 1} + assert len((await the_store.asearch(("foo", "bar")))) == 1 + + thread_2 = str(uuid.uuid4()) + + result = await graph.ainvoke( + {"count": 0}, {"configurable": {"thread_id": thread_2}} + ) + assert result == {"count": 1} + returned_doc = (await the_store.aget(("foo", "bar"), doc_id)).value + assert returned_doc == { + **doc, + "from_thread": thread_2, + "some_val": 0, + } # Overwrites the whole doc + assert ( + len((await the_store.asearch(("foo", "bar")))) == 1 + ) # still overwriting the same one From 944676a7a514310a8cc2ad98c6cc9c1cb7f1fd61 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: 
Thu, 26 Sep 2024 12:47:53 -0700 Subject: [PATCH 17/26] Update kafka test --- libs/scheduler-kafka/tests/test_subgraph.py | 6 ++++++ libs/scheduler-kafka/tests/test_subgraph_sync.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/libs/scheduler-kafka/tests/test_subgraph.py b/libs/scheduler-kafka/tests/test_subgraph.py index 54172ef9b..fd2530843 100644 --- a/libs/scheduler-kafka/tests/test_subgraph.py +++ b/libs/scheduler-kafka/tests/test_subgraph.py @@ -194,6 +194,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + '__pregel_store': None, "__pregel_task_id": history[0].tasks[0].id, "checkpoint_id": None, "checkpoint_map": { @@ -257,6 +258,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + '__pregel_store': None, "__pregel_task_id": history[0].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { @@ -350,6 +352,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + '__pregel_store': None, "__pregel_task_id": history[0].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { @@ -453,6 +456,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + '__pregel_store': None, "__pregel_task_id": history[1].tasks[0].id, "checkpoint_id": None, "checkpoint_map": { @@ -511,6 +515,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + '__pregel_store': None, "__pregel_task_id": history[1].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { @@ -625,6 +630,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, 
"__pregel_resuming": True, + '__pregel_store': None, "__pregel_task_id": history[1].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { diff --git a/libs/scheduler-kafka/tests/test_subgraph_sync.py b/libs/scheduler-kafka/tests/test_subgraph_sync.py index 16576b343..32d0ceea0 100644 --- a/libs/scheduler-kafka/tests/test_subgraph_sync.py +++ b/libs/scheduler-kafka/tests/test_subgraph_sync.py @@ -193,6 +193,7 @@ def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + '__pregel_store': None, "__pregel_task_id": history[0].tasks[0].id, "checkpoint_id": None, "checkpoint_map": { @@ -254,6 +255,7 @@ def test_subgraph_w_interrupt( "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, + '__pregel_store': None, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_task_id": history[0].tasks[0].id, @@ -348,6 +350,7 @@ def test_subgraph_w_interrupt( "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, + '__pregel_store': None, "__pregel_resuming": False, "__pregel_task_id": history[0].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], @@ -450,6 +453,7 @@ def test_subgraph_w_interrupt( "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, + '__pregel_store': None, "__pregel_resuming": True, "__pregel_task_id": history[1].tasks[0].id, "checkpoint_id": None, @@ -508,6 +512,7 @@ def test_subgraph_w_interrupt( "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, + '__pregel_store': None, "__pregel_resuming": True, "__pregel_task_id": history[1].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], @@ -623,6 +628,7 @@ def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + '__pregel_store': None, "__pregel_task_id": 
history[1].tasks[0].id, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { From a6b4c0c6df3525b2d3feb720893844b023ff37fb Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:04:44 -0700 Subject: [PATCH 18/26] Mv to checkpoint lib --- libs/{langgraph => checkpoint}/langgraph/store/base.py | 0 libs/{langgraph => checkpoint}/langgraph/store/batch.py | 0 libs/{langgraph => checkpoint}/langgraph/store/memory.py | 0 libs/{langgraph => checkpoint}/tests/test_store.py | 0 libs/langgraph/langgraph/store/__init__.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename libs/{langgraph => checkpoint}/langgraph/store/base.py (100%) rename libs/{langgraph => checkpoint}/langgraph/store/batch.py (100%) rename libs/{langgraph => checkpoint}/langgraph/store/memory.py (100%) rename libs/{langgraph => checkpoint}/tests/test_store.py (100%) delete mode 100644 libs/langgraph/langgraph/store/__init__.py diff --git a/libs/langgraph/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py similarity index 100% rename from libs/langgraph/langgraph/store/base.py rename to libs/checkpoint/langgraph/store/base.py diff --git a/libs/langgraph/langgraph/store/batch.py b/libs/checkpoint/langgraph/store/batch.py similarity index 100% rename from libs/langgraph/langgraph/store/batch.py rename to libs/checkpoint/langgraph/store/batch.py diff --git a/libs/langgraph/langgraph/store/memory.py b/libs/checkpoint/langgraph/store/memory.py similarity index 100% rename from libs/langgraph/langgraph/store/memory.py rename to libs/checkpoint/langgraph/store/memory.py diff --git a/libs/langgraph/tests/test_store.py b/libs/checkpoint/tests/test_store.py similarity index 100% rename from libs/langgraph/tests/test_store.py rename to libs/checkpoint/tests/test_store.py diff --git a/libs/langgraph/langgraph/store/__init__.py b/libs/langgraph/langgraph/store/__init__.py deleted file mode 100644 index 
e69de29bb..000000000 From ba7f7063892d321d90837664b967b5d341c8d12c Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:18:47 -0700 Subject: [PATCH 19/26] Move to checkpoint package --- libs/checkpoint/langgraph/store/base.py | 168 ++++++++++++++-------- libs/checkpoint/langgraph/store/batch.py | 2 + libs/checkpoint/langgraph/store/memory.py | 8 +- libs/langgraph/tests/test_pregel.py | 8 +- libs/langgraph/tests/test_pregel_async.py | 8 +- 5 files changed, 129 insertions(+), 65 deletions(-) diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index 4477f5eef..dd65353b2 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -1,3 +1,9 @@ +"""Base classes and types for persistent key-value stores. + +Stores enable persistence and memory that can be shared across threads, +scoped to user IDs, assistant IDs, or other arbitrary namespaces. +""" + from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime @@ -6,33 +12,55 @@ @dataclass class Item: + """Represents a stored item with metadata.""" + value: dict[str, Any] - # search metadata + """The stored data.""" scores: dict[str, float] - # item metadata + """Relevance scores for the item.""" id: str + """Unique identifier within the namespace.""" namespace: tuple[str, ...] + """Hierarchical path for organizing items.""" created_at: datetime + """Timestamp of item creation.""" updated_at: datetime + """Timestamp of last update.""" last_accessed_at: datetime + """Timestamp of last access.""" class GetOp(NamedTuple): + """Operation to retrieve an item by namespace and ID.""" + namespace: tuple[str, ...] + """Hierarchical path for the item.""" id: str + """Unique identifier within the namespace.""" class SearchOp(NamedTuple): + """Operation to search for items within a namespace prefix.""" + namespace_prefix: tuple[str, ...] 
+ """Hierarchical path prefix to search within.""" filter: Optional[dict[str, Any]] = None + """Key-value pairs to filter results.""" limit: int = 10 + """Maximum number of items to return.""" offset: int = 0 + """Number of items to skip before returning results.""" class PutOp(NamedTuple): + """Operation to store or update an item.""" + namespace: tuple[str, ...] + """Hierarchical path for the item.""" id: str + """Unique identifier within the namespace.""" value: Optional[dict[str, Any]] + """Data to be stored, or None to delete.""" Op = Union[GetOp, SearchOp, PutOp] @@ -40,29 +68,28 @@ class PutOp(NamedTuple): class BaseStore(ABC): - __slots__ = ("__weakref__",) + """Abstract base class for key-value stores.""" - # abstract methods + __slots__ = ("__weakref__",) @abstractmethod - def batch( - self, - ops: Iterable[Op], - ) -> Sequence[Result]: ... + def batch(self, ops: Iterable[Op]) -> Sequence[Result]: + """Execute a batch of operations synchronously.""" @abstractmethod - async def abatch( - self, - ops: Iterable[Op], - ) -> Sequence[Result]: ... + async def abatch(self, ops: Iterable[Op]) -> Sequence[Result]: + """Execute a batch of operations asynchronously.""" - # convenience methods + def get(self, namespace: tuple[str, ...], id: str) -> Optional[Item]: + """Retrieve a single item. - def get( - self, - namespace: tuple[str, ...], - id: str, - ) -> Optional[Item]: + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + + Returns: + The retrieved item or None if not found. + """ return self.batch([GetOp(namespace, id)])[0] def search( @@ -74,32 +101,48 @@ def search( limit: int = 10, offset: int = 0, ) -> list[Item]: - return self.batch( - [ - SearchOp(namespace_prefix, filter, limit, offset), - ] - )[0] - - def put( - self, - namespace: tuple[str, ...], - id: str, - value: dict[str, Any], - ) -> None: + """Search for items within a namespace prefix. 
+ + Args: + namespace_prefix: Hierarchical path prefix to search within. + filter: Key-value pairs to filter results. + limit: Maximum number of items to return. + offset: Number of items to skip before returning results. + + Returns: + List of items matching the search criteria. + """ + return self.batch([SearchOp(namespace_prefix, filter, limit, offset)])[0] + + def put(self, namespace: tuple[str, ...], id: str, value: dict[str, Any]) -> None: + """Store or update an item. + + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + value: Dictionary containing the item's data. + """ self.batch([PutOp(namespace, id, value)]) - def delete( - self, - namespace: tuple[str, ...], - id: str, - ) -> None: + def delete(self, namespace: tuple[str, ...], id: str) -> None: + """Delete an item. + + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + """ self.batch([PutOp(namespace, id, None)]) - async def aget( - self, - namespace: tuple[str, ...], - id: str, - ) -> Optional[Item]: + async def aget(self, namespace: tuple[str, ...], id: str) -> Optional[Item]: + """Asynchronously retrieve a single item. + + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + + Returns: + The retrieved item or None if not found. + """ return (await self.abatch([GetOp(namespace, id)]))[0] async def asearch( @@ -111,25 +154,38 @@ async def asearch( limit: int = 10, offset: int = 0, ) -> list[Item]: - return ( - await self.abatch( - [ - SearchOp(namespace_prefix, filter, limit, offset), - ] - ) - )[0] + """Asynchronously search for items within a namespace prefix. + + Args: + namespace_prefix: Hierarchical path prefix to search within. + filter: Key-value pairs to filter results. + limit: Maximum number of items to return. + offset: Number of items to skip before returning results. + + Returns: + List of items matching the search criteria. 
+ """ + return (await self.abatch([SearchOp(namespace_prefix, filter, limit, offset)]))[ + 0 + ] async def aput( - self, - namespace: tuple[str, ...], - id: str, - value: dict[str, Any], + self, namespace: tuple[str, ...], id: str, value: dict[str, Any] ) -> None: + """Asynchronously store or update an item. + + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + value: Dictionary containing the item's data. + """ await self.abatch([PutOp(namespace, id, value)]) - async def adelete( - self, - namespace: tuple[str, ...], - id: str, - ) -> None: + async def adelete(self, namespace: tuple[str, ...], id: str) -> None: + """Asynchronously delete an item. + + Args: + namespace: Hierarchical path for the item. + id: Unique identifier within the namespace. + """ await self.abatch([PutOp(namespace, id, None)]) diff --git a/libs/checkpoint/langgraph/store/batch.py b/libs/checkpoint/langgraph/store/batch.py index 3f0d85101..148acb8d1 100644 --- a/libs/checkpoint/langgraph/store/batch.py +++ b/libs/checkpoint/langgraph/store/batch.py @@ -6,6 +6,8 @@ class AsyncBatchedBaseStore(BaseStore): + """Efficiently batch operations in a background task.""" + __slots__ = ("_loop", "_aqueue", "_task") def __init__(self) -> None: diff --git a/libs/checkpoint/langgraph/store/memory.py b/libs/checkpoint/langgraph/store/memory.py index 2b16c893f..8b0807989 100644 --- a/libs/checkpoint/langgraph/store/memory.py +++ b/libs/checkpoint/langgraph/store/memory.py @@ -4,7 +4,13 @@ from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result, SearchOp -class MemoryStore(BaseStore): +class InMemoryStore(BaseStore): + """A KV store backed by an in-memory python dictionary. + + Useful for testing/experimentation and lightweight PoC's. + For actual persistence, use a Store backed by a proper database. 
+ """ + __slots__ = ("_data",) def __init__(self) -> None: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 7cfaf5335..a7125fc43 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -72,7 +72,7 @@ ) from langgraph.pregel.retry import RetryPolicy from langgraph.store.base import BaseStore -from langgraph.store.memory import MemoryStore +from langgraph.store.memory import InMemoryStore from langgraph.types import Interrupt, PregelTask, Send, StreamWriter from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence from tests.conftest import ALL_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS @@ -6808,7 +6808,7 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: } tool_two = tool_two_graph.compile( - store=MemoryStore(), + store=InMemoryStore(), checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"], ) @@ -11345,7 +11345,7 @@ class State(TypedDict): def node(input: State, config: RunnableConfig, store: BaseStore): assert isinstance(store, BaseStore) - assert isinstance(store, MemoryStore) + assert isinstance(store, InMemoryStore) store.put( ("foo", "bar"), doc_id, @@ -11360,7 +11360,7 @@ def node(input: State, config: RunnableConfig, store: BaseStore): builder = StateGraph(State) builder.add_node("node", node) builder.add_edge("__start__", "node") - the_store = MemoryStore() + the_store = InMemoryStore() graph = builder.compile(store=the_store, checkpointer=checkpointer) thread_1 = str(uuid.uuid4()) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 6dfdff955..1ba0335ce 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -61,7 +61,7 @@ from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot from langgraph.pregel.retry import RetryPolicy from langgraph.store.base import BaseStore -from 
langgraph.store.memory import MemoryStore +from langgraph.store.memory import InMemoryStore from langgraph.types import Interrupt, PregelTask, Send, StreamWriter from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence from tests.conftest import ( @@ -5416,7 +5416,7 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: async with awith_checkpointer(checkpointer_name) as checkpointer: tool_two = tool_two_graph.compile( - store=MemoryStore(), + store=InMemoryStore(), checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"], ) @@ -9619,7 +9619,7 @@ class State(TypedDict): async def node(input: State, config: RunnableConfig, store: BaseStore): assert isinstance(store, BaseStore) - assert isinstance(store, MemoryStore) + assert isinstance(store, InMemoryStore) await store.aput( ("foo", "bar"), doc_id, @@ -9634,7 +9634,7 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): builder = StateGraph(State) builder.add_node("node", node) builder.add_edge("__start__", "node") - the_store = MemoryStore() + the_store = InMemoryStore() async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(store=the_store, checkpointer=checkpointer) From 99a88eb7304981518025995bfb8c73416accf5d0 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:25:41 -0700 Subject: [PATCH 20/26] Lint fixes --- libs/checkpoint/langgraph/store/memory.py | 5 +++-- libs/checkpoint/tests/test_store.py | 12 +++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/libs/checkpoint/langgraph/store/memory.py b/libs/checkpoint/langgraph/store/memory.py index 8b0807989..2d002de9d 100644 --- a/libs/checkpoint/langgraph/store/memory.py +++ b/libs/checkpoint/langgraph/store/memory.py @@ -1,5 +1,6 @@ from collections import defaultdict from datetime import datetime, timezone +from typing import Iterable from 
langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result, SearchOp @@ -16,7 +17,7 @@ class InMemoryStore(BaseStore): def __init__(self) -> None: self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) - def batch(self, ops: list[Op]) -> list[Result]: + def batch(self, ops: Iterable[Op]) -> list[Result]: results: list[Result] = [] for op in ops: if isinstance(op, GetOp): @@ -63,5 +64,5 @@ def batch(self, ops: list[Op]) -> list[Result]: results.append(None) return results - async def abatch(self, ops: list[Op]) -> list[Result]: + async def abatch(self, ops: Iterable[Op]) -> list[Result]: return self.batch(ops) diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index 807a5798e..fecc54dc1 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -1,31 +1,29 @@ import asyncio from datetime import datetime +from typing import Iterable -import pytest from pytest_mock import MockerFixture from langgraph.store.base import GetOp, Item, Op, Result from langgraph.store.batch import AsyncBatchedBaseStore -pytestmark = pytest.mark.anyio - async def test_async_batch_store(mocker: MockerFixture) -> None: abatch = mocker.stub() class MockStore(AsyncBatchedBaseStore): - def batch(self, ops: list[Op]) -> list[Result]: + def batch(self, ops: Iterable[Op]) -> list[Result]: raise NotImplementedError - async def abatch(self, ops: list[Op]) -> list[Result]: + async def abatch(self, ops: Iterable[Op]) -> list[Result]: assert all(isinstance(op, GetOp) for op in ops) abatch(ops) return [ Item( value={}, scores={}, - id=op.id, - namespace=op.namespace, + id=getattr(op, "id", ""), + namespace=getattr(op, "namespace", ()), created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), last_accessed_at=datetime(2024, 9, 24, 17, 29, 10, 128397), From d309f580424af4bb44f30d8209bab91e6cb5e688 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn 
<13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:39:52 -0700 Subject: [PATCH 21/26] py.typed --- libs/checkpoint/langgraph/store/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 libs/checkpoint/langgraph/store/py.typed diff --git a/libs/checkpoint/langgraph/store/py.typed b/libs/checkpoint/langgraph/store/py.typed new file mode 100644 index 000000000..e69de29bb From 3daae2db62219d9186e51f69ea9c3fb9c61ca061 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:47:25 -0700 Subject: [PATCH 22/26] Update --- libs/checkpoint/langgraph/store/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index dd65353b2..33c40a0cb 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Iterable, NamedTuple, Optional, Sequence, Union +from typing import Any, Iterable, NamedTuple, Optional, Union @dataclass @@ -73,11 +73,11 @@ class BaseStore(ABC): __slots__ = ("__weakref__",) @abstractmethod - def batch(self, ops: Iterable[Op]) -> Sequence[Result]: + def batch(self, ops: Iterable[Op]) -> list[Result]: """Execute a batch of operations synchronously.""" @abstractmethod - async def abatch(self, ops: Iterable[Op]) -> Sequence[Result]: + async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute a batch of operations asynchronously.""" def get(self, namespace: tuple[str, ...], id: str) -> Optional[Item]: From b9cf088ff4b2afab7c5d655964c5893b8339ec9e Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:57:47 -0700 Subject: [PATCH 23/26] doc --- libs/checkpoint/langgraph/store/base.py | 48 
+++++++++++++++++++++---- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index 33c40a0cb..4f0aac7c9 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -15,17 +15,35 @@ class Item: """Represents a stored item with metadata.""" value: dict[str, Any] - """The stored data.""" + """The stored data as a dictionary. + + Keys are filterable. + """ + scores: dict[str, float] - """Relevance scores for the item.""" + """Relevance scores for the item. + + Keys can include built-in scores like 'recency' and 'relevance', + as well as any key present in the 'value' dictionary. This allows + for multi-dimensional scoring of items. + """ + id: str """Unique identifier within the namespace.""" + namespace: tuple[str, ...] - """Hierarchical path for organizing items.""" + """Hierarchical path defining the collection in which this document resides. + + Represented as a tuple of strings, allowing for nested categorization. + For example: ("documents", 'user123') + """ + created_at: datetime """Timestamp of item creation.""" + updated_at: datetime """Timestamp of last update.""" + last_accessed_at: datetime """Timestamp of last access.""" @@ -53,14 +71,30 @@ class SearchOp(NamedTuple): class PutOp(NamedTuple): - """Operation to store or update an item.""" + """Operation to store, update, or delete an item.""" namespace: tuple[str, ...] - """Hierarchical path for the item.""" + """Hierarchical path for the item. + + Represented as a tuple of strings, allowing for nested categorization. + For example: ("documents", "user123") + """ + id: str - """Unique identifier within the namespace.""" + """Unique identifier for the document. + + Should be distinct within its namespace. + """ + value: Optional[dict[str, Any]] - """Data to be stored, or None to delete.""" + """Data to be stored, or None to delete the item. 
+ + Schema: + - Should be a dictionary where: + - Keys are strings representing field names + - Values can be of any serializable type + - If None, it indicates that the item should be deleted + """ Op = Union[GetOp, SearchOp, PutOp] From c352b608b6af80fddd4a9e199a1e708c73ab9dac Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:31:12 -0700 Subject: [PATCH 24/26] Add namespace listing (#1889) --- .../langgraph/checkpoint/serde/jsonplus.py | 13 + libs/checkpoint/langgraph/store/base.py | 264 ++++++++++++++---- libs/checkpoint/langgraph/store/batch.py | 14 +- libs/checkpoint/langgraph/store/memory.py | 75 ++++- libs/checkpoint/tests/test_jsonplus.py | 8 + libs/checkpoint/tests/test_store.py | 232 +++++++++++++-- .../langgraph/managed/shared_value.py | 4 +- 7 files changed, 512 insertions(+), 98 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py index 707442d34..30276e47e 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py @@ -26,6 +26,7 @@ from langgraph.checkpoint.serde.base import SerializerProtocol from langgraph.checkpoint.serde.types import SendProtocol +from langgraph.store.base import Item LC_REVIVER = Reviver() @@ -403,6 +404,18 @@ def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]: ), ), ) + elif isinstance(obj, Item): + return msgpack.ExtType( + EXT_CONSTRUCTOR_KW_ARGS, + _msgpack_enc( + ( + obj.__class__.__module__, + obj.__class__.__name__, + {k: getattr(obj, k) for k in obj.__slots__}, + ), + ), + ) + elif isinstance(obj, BaseException): return repr(obj) else: diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index 4f0aac7c9..c4b6bd552 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -5,55 +5,61 @@ """ from abc import ABC, 
abstractmethod -from dataclasses import dataclass from datetime import datetime -from typing import Any, Iterable, NamedTuple, Optional, Union +from typing import Any, Iterable, Literal, NamedTuple, Optional, Union -@dataclass class Item: - """Represents a stored item with metadata.""" - - value: dict[str, Any] - """The stored data as a dictionary. - - Keys are filterable. + """Represents a stored item with metadata. + + Args: + value (dict[str, Any]): The stored data as a dictionary. Keys are filterable. + (str): Unique identifier within the namespace. + namespace (tuple[str, ...]): Hierarchical path defining the collection in which this document resides. + Represented as a tuple of strings, allowing for nested categorization. + For example: ("documents", 'user123') + created_at (datetime): Timestamp of item creation. + updated_at (datetime): Timestamp of last update. """ - scores: dict[str, float] - """Relevance scores for the item. - - Keys can include built-in scores like 'recency' and 'relevance', - as well as any key present in the 'value' dictionary. This allows - for multi-dimensional scoring of items. - """ - - id: str - """Unique identifier within the namespace.""" - - namespace: tuple[str, ...] - """Hierarchical path defining the collection in which this document resides. - - Represented as a tuple of strings, allowing for nested categorization. 
- For example: ("documents", 'user123') - """ - - created_at: datetime - """Timestamp of item creation.""" - - updated_at: datetime - """Timestamp of last update.""" + __slots__ = ("value", "key", "namespace", "created_at", "updated_at") - last_accessed_at: datetime - """Timestamp of last access.""" + def __init__( + self, + *, + value: dict[str, Any], + key: str, + namespace: tuple[str, ...], + created_at: datetime, + updated_at: datetime, + ): + self.value = value + self.key = key + self.namespace = tuple(namespace) + self.created_at = created_at + self.updated_at = updated_at + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Item): + return False + return ( + self.value == other.value + and self.key == other.key + and self.namespace == other.namespace + and self.created_at == other.created_at + and self.updated_at == other.updated_at + ) + + def __hash__(self) -> int: + return hash((self.namespace, self.key)) class GetOp(NamedTuple): - """Operation to retrieve an item by namespace and ID.""" + """Operation to retrieve an item by namespace and key.""" namespace: tuple[str, ...] """Hierarchical path for the item.""" - id: str + key: str """Unique identifier within the namespace.""" @@ -80,7 +86,7 @@ class PutOp(NamedTuple): For example: ("documents", "user123") """ - id: str + key: str """Unique identifier for the document. Should be distinct within its namespace. @@ -97,8 +103,48 @@ class PutOp(NamedTuple): """ -Op = Union[GetOp, SearchOp, PutOp] -Result = Union[Item, list[Item], None] +NameSpacePath = tuple[Union[str, Literal["*"]], ...] 
+ +NamespaceMatchType = Literal["prefix", "suffix"] + + +class MatchCondition(NamedTuple): + """Represents a single match condition.""" + + match_type: NamespaceMatchType + path: NameSpacePath + + +class ListNamespacesOp(NamedTuple): + """Operation to list namespaces with optional match conditions.""" + + match_conditions: Optional[tuple[MatchCondition, ...]] = None + """A tuple of match conditions to apply to namespaces.""" + + max_depth: Optional[int] = None + """Return namespaces up to this depth in the hierarchy.""" + + limit: int = 100 + """Maximum number of namespaces to return.""" + + offset: int = 0 + """Number of namespaces to skip before returning results.""" + + +Op = Union[GetOp, SearchOp, PutOp, ListNamespacesOp] +Result = Union[Item, list[Item], list[tuple[str, ...]], None] + + +class InvalidNamespaceError(ValueError): + """Provided namespace is invalid.""" + + +def _validate_namespace(namespace: tuple[str, ...]) -> None: + for label in namespace: + if "." in label: + raise InvalidNamespaceError( + f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." + ) class BaseStore(ABC): @@ -114,17 +160,17 @@ def batch(self, ops: Iterable[Op]) -> list[Result]: async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute a batch of operations asynchronously.""" - def get(self, namespace: tuple[str, ...], id: str) -> Optional[Item]: + def get(self, namespace: tuple[str, ...], key: str) -> Optional[Item]: """Retrieve a single item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. Returns: The retrieved item or None if not found. 
""" - return self.batch([GetOp(namespace, id)])[0] + return self.batch([GetOp(namespace, key)])[0] def search( self, @@ -148,36 +194,88 @@ def search( """ return self.batch([SearchOp(namespace_prefix, filter, limit, offset)])[0] - def put(self, namespace: tuple[str, ...], id: str, value: dict[str, Any]) -> None: + def put(self, namespace: tuple[str, ...], key: str, value: dict[str, Any]) -> None: """Store or update an item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. value: Dictionary containing the item's data. """ - self.batch([PutOp(namespace, id, value)]) + _validate_namespace(namespace) + self.batch([PutOp(namespace, key, value)]) - def delete(self, namespace: tuple[str, ...], id: str) -> None: + def delete(self, namespace: tuple[str, ...], key: str) -> None: """Delete an item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. """ - self.batch([PutOp(namespace, id, None)]) + self.batch([PutOp(namespace, key, None)]) + + def list_namespaces( + self, + *, + prefix: Optional[NameSpacePath] = None, + suffix: Optional[NameSpacePath] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> list[tuple[str, ...]]: + """List and filter namespaces in the store. + + Used to explore the organization of data, + find specific collections, or navigate the namespace hierarchy. + + Args: + prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path. + suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path. + max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy. + Namespaces deeper than this level will be truncated to this depth. + limit (int): Maximum number of namespaces to return (default 100). + offset (int): Number of namespaces to skip for pagination (default 0). 
- async def aget(self, namespace: tuple[str, ...], id: str) -> Optional[Item]: + Returns: + List[Tuple[str, ...]]: A list of namespace tuples that match the criteria. + Each tuple represents a full namespace path up to `max_depth`. + + Examples: + + Setting max_depth=3. Given the namespaces: + # ("a", "b", "c") + # ("a", "b", "d", "e") + # ("a", "b", "d", "i") + # ("a", "b", "f") + # ("a", "c", "f") + store.list_namespaces(prefix=("a", "b"), max_depth=3) + # [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")] + """ + match_conditions = [] + if prefix: + match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) + if suffix: + match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) + + op = ListNamespacesOp( + match_conditions=tuple(match_conditions), + max_depth=max_depth, + limit=limit, + offset=offset, + ) + return self.batch([op])[0] + + async def aget(self, namespace: tuple[str, ...], key: str) -> Optional[Item]: """Asynchronously retrieve a single item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. Returns: The retrieved item or None if not found. """ - return (await self.abatch([GetOp(namespace, id)]))[0] + return (await self.abatch([GetOp(namespace, key)]))[0] async def asearch( self, @@ -204,22 +302,74 @@ async def asearch( ] async def aput( - self, namespace: tuple[str, ...], id: str, value: dict[str, Any] + self, namespace: tuple[str, ...], key: str, value: dict[str, Any] ) -> None: """Asynchronously store or update an item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. value: Dictionary containing the item's data. 
""" - await self.abatch([PutOp(namespace, id, value)]) + _validate_namespace(namespace) + await self.abatch([PutOp(namespace, key, value)]) - async def adelete(self, namespace: tuple[str, ...], id: str) -> None: + async def adelete(self, namespace: tuple[str, ...], key: str) -> None: """Asynchronously delete an item. Args: namespace: Hierarchical path for the item. - id: Unique identifier within the namespace. + key: Unique identifier within the namespace. + """ + await self.abatch([PutOp(namespace, key, None)]) + + async def alist_namespaces( + self, + *, + prefix: Optional[NameSpacePath] = None, + suffix: Optional[NameSpacePath] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> list[tuple[str, ...]]: + """List and filter namespaces in the store asynchronously. + + Used to explore the organization of data, + find specific collections, or navigate the namespace hierarchy. + + Args: + prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path. + suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path. + max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy. + Namespaces deeper than this level will be truncated to this depth. + limit (int): Maximum number of namespaces to return (default 100). + offset (int): Number of namespaces to skip for pagination (default 0). + + Returns: + List[Tuple[str, ...]]: A list of namespace tuples that match the criteria. + Each tuple represents a full namespace path up to `max_depth`. + + Examples: + + Setting max_depth=3. 
Given the namespaces: + # ("a", "b", "c") + # ("a", "b", "d", "e") + # ("a", "b", "d", "i") + # ("a", "b", "f") + # ("a", "c", "f") + await store.alist_namespaces(prefix=("a", "b"), max_depth=3) + # [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")] """ - await self.abatch([PutOp(namespace, id, None)]) + match_conditions = [] + if prefix: + match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) + if suffix: + match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) + + op = ListNamespacesOp( + match_conditions=tuple(match_conditions), + max_depth=max_depth, + limit=limit, + offset=offset, + ) + return (await self.abatch([op]))[0] diff --git a/libs/checkpoint/langgraph/store/batch.py b/libs/checkpoint/langgraph/store/batch.py index 148acb8d1..8283a7a66 100644 --- a/libs/checkpoint/langgraph/store/batch.py +++ b/libs/checkpoint/langgraph/store/batch.py @@ -21,10 +21,10 @@ def __del__(self) -> None: async def aget( self, namespace: tuple[str, ...], - id: str, + key: str, ) -> Optional[Item]: fut = self._loop.create_future() - self._aqueue[fut] = GetOp(namespace, id) + self._aqueue[fut] = GetOp(namespace, key) return await fut async def asearch( @@ -32,9 +32,7 @@ async def asearch( namespace_prefix: tuple[str, ...], /, *, - query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, - weights: Optional[dict[str, float]] = None, limit: int = 10, offset: int = 0, ) -> list[Item]: @@ -45,20 +43,20 @@ async def asearch( async def aput( self, namespace: tuple[str, ...], - id: str, + key: str, value: dict[str, Any], ) -> None: fut = self._loop.create_future() - self._aqueue[fut] = PutOp(namespace, id, value) + self._aqueue[fut] = PutOp(namespace, key, value) return await fut async def adelete( self, namespace: tuple[str, ...], - id: str, + key: str, ) -> None: fut = self._loop.create_future() - self._aqueue[fut] = PutOp(namespace, id, None) + self._aqueue[fut] = PutOp(namespace, key, None) return await fut diff --git 
a/libs/checkpoint/langgraph/store/memory.py b/libs/checkpoint/langgraph/store/memory.py index 2d002de9d..69a315096 100644 --- a/libs/checkpoint/langgraph/store/memory.py +++ b/libs/checkpoint/langgraph/store/memory.py @@ -2,7 +2,17 @@ from datetime import datetime, timezone from typing import Iterable -from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result, SearchOp +from langgraph.store.base import ( + BaseStore, + GetOp, + Item, + ListNamespacesOp, + MatchCondition, + Op, + PutOp, + Result, + SearchOp, +) class InMemoryStore(BaseStore): @@ -21,9 +31,7 @@ def batch(self, ops: Iterable[Op]) -> list[Result]: results: list[Result] = [] for op in ops: if isinstance(op, GetOp): - item = self._data[op.namespace].get(op.id) - if item is not None: - item.last_accessed_at = datetime.now(timezone.utc) + item = self._data[op.namespace].get(op.key) results.append(item) elif isinstance(op, SearchOp): candidates = [ @@ -45,24 +53,67 @@ def batch(self, ops: Iterable[Op]) -> list[Result]: results.append(candidates[op.offset : op.offset + op.limit]) elif isinstance(op, PutOp): if op.value is None: - self._data[op.namespace].pop(op.id, None) - elif op.id in self._data[op.namespace]: - self._data[op.namespace][op.id].value = op.value - self._data[op.namespace][op.id].updated_at = datetime.now( + self._data[op.namespace].pop(op.key, None) + elif op.key in self._data[op.namespace]: + self._data[op.namespace][op.key].value = op.value + self._data[op.namespace][op.key].updated_at = datetime.now( timezone.utc ) else: - self._data[op.namespace][op.id] = Item( + self._data[op.namespace][op.key] = Item( value=op.value, - scores={}, - id=op.id, + key=op.key, namespace=op.namespace, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), - last_accessed_at=datetime.now(timezone.utc), ) results.append(None) + elif isinstance(op, ListNamespacesOp): + results.append(self._handle_list_namespaces(op)) return results async def abatch(self, ops: 
Iterable[Op]) -> list[Result]: return self.batch(ops) + + def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]: + all_namespaces = list( + self._data.keys() + ) # Avoid collection size changing while iterating + namespaces = all_namespaces + if op.match_conditions: + namespaces = [ + ns + for ns in namespaces + if all(_does_match(condition, ns) for condition in op.match_conditions) + ] + + if op.max_depth is not None: + namespaces = sorted({ns[: op.max_depth] for ns in namespaces}) + else: + namespaces = sorted(namespaces) + return namespaces[op.offset : op.offset + op.limit] + + +def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool: + match_type = match_condition.match_type + path = match_condition.path + + if len(key) < len(path): + return False + + if match_type == "prefix": + for k_elem, p_elem in zip(key, path): + if p_elem == "*": + continue # Wildcard matches any element + if k_elem != p_elem: + return False + return True + elif match_type == "suffix": + for k_elem, p_elem in zip(reversed(key), reversed(path)): + if p_elem == "*": + continue # Wildcard matches any element + if k_elem != p_elem: + return False + return True + else: + raise ValueError(f"Unsupported match type: {match_type}") diff --git a/libs/checkpoint/tests/test_jsonplus.py b/libs/checkpoint/tests/test_jsonplus.py index e52531414..e504e80b1 100644 --- a/libs/checkpoint/tests/test_jsonplus.py +++ b/libs/checkpoint/tests/test_jsonplus.py @@ -15,6 +15,7 @@ from zoneinfo import ZoneInfo from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.store.base import Item class InnerPydantic(BaseModel): @@ -121,6 +122,13 @@ def test_serde_jsonplus() -> None: "a_float": 1.1, "a_bytes": b"my bytes", "a_bytearray": bytearray([42]), + "my_item": Item( + value={}, + key="my-key", + namespace=("a", "name", " "), + created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + ), } 
serde = JsonPlusSerializer() diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index fecc54dc1..317d9a720 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -6,6 +6,7 @@ from langgraph.store.base import GetOp, Item, Op, Result from langgraph.store.batch import AsyncBatchedBaseStore +from langgraph.store.memory import InMemoryStore async def test_async_batch_store(mocker: MockerFixture) -> None: @@ -21,12 +22,10 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: return [ Item( value={}, - scores={}, - id=getattr(op, "id", ""), + key=getattr(op, "key", ""), namespace=getattr(op, "namespace", ()), created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), - last_accessed_at=datetime(2024, 9, 24, 17, 29, 10, 128397), ) for op in ops ] @@ -35,27 +34,23 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: # concurrent calls are batched results = await asyncio.gather( - store.aget(namespace=("a",), id="b"), - store.aget(namespace=("c",), id="d"), + store.aget(namespace=("a",), key="b"), + store.aget(namespace=("c",), key="d"), ) assert results == [ Item( - {}, - {}, - "b", - ("a",), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), + value={}, + key="b", + namespace=("a",), + created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), ), Item( - {}, - {}, - "d", - ("c",), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), - datetime(2024, 9, 24, 17, 29, 10, 128397), + value={}, + key="d", + namespace=("c",), + created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), + updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), ), ] assert abatch.call_count == 1 @@ -65,3 +60,202 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: GetOp(("c",), "d"), ), ] 
+ + +def test_list_namespaces_basic() -> None: + store = InMemoryStore() + + namespaces = [ + ("a", "b", "c"), + ("a", "b", "d", "e"), + ("a", "b", "d", "i"), + ("a", "b", "f"), + ("a", "c", "f"), + ("b", "a", "f"), + ("users", "123"), + ("users", "456", "settings"), + ("admin", "users", "789"), + ] + + for i, ns in enumerate(namespaces): + store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) + + result = store.list_namespaces(prefix=("a", "b")) + expected = [ + ("a", "b", "c"), + ("a", "b", "d", "e"), + ("a", "b", "d", "i"), + ("a", "b", "f"), + ] + assert sorted(result) == sorted(expected) + + result = store.list_namespaces(suffix=("f",)) + expected = [ + ("a", "b", "f"), + ("a", "c", "f"), + ("b", "a", "f"), + ] + assert sorted(result) == sorted(expected) + + result = store.list_namespaces(prefix=("a",), suffix=("f",)) + expected = [ + ("a", "b", "f"), + ("a", "c", "f"), + ] + assert sorted(result) == sorted(expected) + + # Test max_depth + result = store.list_namespaces(prefix=("a", "b"), max_depth=3) + expected = [ + ("a", "b", "c"), + ("a", "b", "d"), + ("a", "b", "f"), + ] + assert sorted(result) == sorted(expected) + + # Test limit and offset + result = store.list_namespaces(prefix=("a", "b"), limit=2) + expected = [ + ("a", "b", "c"), + ("a", "b", "d", "e"), + ] + assert result == expected + + result = store.list_namespaces(prefix=("a", "b"), offset=2) + expected = [ + ("a", "b", "d", "i"), + ("a", "b", "f"), + ] + assert result == expected + + result = store.list_namespaces(prefix=("a", "*", "f")) + expected = [ + ("a", "b", "f"), + ("a", "c", "f"), + ] + assert sorted(result) == sorted(expected) + + result = store.list_namespaces(suffix=("*", "f")) + expected = [ + ("a", "b", "f"), + ("a", "c", "f"), + ("b", "a", "f"), + ] + assert sorted(result) == sorted(expected) + + result = store.list_namespaces(prefix=("nonexistent",)) + assert result == [] + + result = store.list_namespaces(prefix=("users", "123")) + expected = [("users", 
"123")] + assert result == expected + + +def test_list_namespaces_with_wildcards() -> None: + store = InMemoryStore() + + namespaces = [ + ("users", "123"), + ("users", "456"), + ("users", "789", "settings"), + ("admin", "users", "789"), + ("guests", "123"), + ("guests", "456", "preferences"), + ] + + for i, ns in enumerate(namespaces): + store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) + + result = store.list_namespaces(prefix=("users", "*")) + expected = [ + ("users", "123"), + ("users", "456"), + ("users", "789", "settings"), + ] + assert sorted(result) == sorted(expected) + + result = store.list_namespaces(suffix=("*", "preferences")) + expected = [ + ("guests", "456", "preferences"), + ] + assert result == expected + + result = store.list_namespaces(prefix=("*", "users"), suffix=("*", "settings")) + + assert result == [] + + store.put( + namespace=("admin", "users", "settings", "789"), + key="foo", + value={"data": "some_val"}, + ) + expected = [ + ("admin", "users", "settings", "789"), + ] + + +def test_list_namespaces_pagination() -> None: + store = InMemoryStore() + + for i in range(20): + ns = ("namespace", f"sub_{i:02d}") + store.put(namespace=ns, key=f"id_{i:02d}", value={"data": f"value_{i:02d}"}) + + result = store.list_namespaces(prefix=("namespace",), limit=5, offset=0) + expected = [("namespace", f"sub_{i:02d}") for i in range(5)] + assert result == expected + + result = store.list_namespaces(prefix=("namespace",), limit=5, offset=5) + expected = [("namespace", f"sub_{i:02d}") for i in range(5, 10)] + assert result == expected + + result = store.list_namespaces(prefix=("namespace",), limit=5, offset=15) + expected = [("namespace", f"sub_{i:02d}") for i in range(15, 20)] + assert result == expected + + +def test_list_namespaces_max_depth() -> None: + store = InMemoryStore() + + namespaces = [ + ("a", "b", "c", "d"), + ("a", "b", "c", "e"), + ("a", "b", "f"), + ("a", "g"), + ("h", "i", "j", "k"), + ] + + for i, ns in 
enumerate(namespaces): + store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) + + result = store.list_namespaces(max_depth=2) + expected = [ + ("a", "b"), + ("a", "g"), + ("h", "i"), + ] + assert sorted(result) == sorted(expected) + + +def test_list_namespaces_no_conditions() -> None: + store = InMemoryStore() + + namespaces = [ + ("a", "b"), + ("c", "d"), + ("e", "f", "g"), + ] + + for i, ns in enumerate(namespaces): + store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) + + result = store.list_namespaces() + expected = namespaces + assert sorted(result) == sorted(expected) + + +def test_list_namespaces_empty_store() -> None: + store = InMemoryStore() + + result = store.list_namespaces() + assert result == [] diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index 12a15fc93..7c8f45b14 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -59,7 +59,7 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: with super().enter(config, **kwargs) as value: if value.store is not None: saved = value.store.search(value.ns) - value.value = {it.id: it.value for it in saved} + value.value = {it.key: it.value for it in saved} yield value @classmethod @@ -68,7 +68,7 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se async with super().aenter(config, **kwargs) as value: if value.store is not None: saved = await value.store.asearch(value.ns) - value.value = {it.id: it.value for it in saved} + value.value = {it.key: it.value for it in saved} yield value def __init__( From 121587be82748fde66a8ccb22193adec7b6174ff Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sun, 29 Sep 2024 16:51:13 -0700 Subject: [PATCH 25/26] Add SDK (#1892) --- libs/checkpoint/langgraph/store/base.py | 25 +- libs/sdk-py/langgraph_sdk/client.py | 
414 +++++++++++++++++++++++- libs/sdk-py/langgraph_sdk/schema.py | 24 ++ 3 files changed, 456 insertions(+), 7 deletions(-) diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index c4b6bd552..c51eba14a 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Iterable, Literal, NamedTuple, Optional, Union +from typing import Any, Iterable, Literal, NamedTuple, Optional, Union, cast class Item: @@ -35,9 +35,19 @@ def __init__( ): self.value = value self.key = key + # The casting from json-like types is for if this object is + # deserialized. self.namespace = tuple(namespace) - self.created_at = created_at - self.updated_at = updated_at + self.created_at = ( + datetime.fromisoformat(cast(str, created_at)) + if isinstance(created_at, str) + else created_at + ) + self.updated_at = ( + datetime.fromisoformat(cast(str, created_at)) + if isinstance(updated_at, str) + else updated_at + ) def __eq__(self, other: object) -> bool: if not isinstance(other, Item): @@ -53,6 +63,15 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return hash((self.namespace, self.key)) + def dict(self) -> dict: + return { + "value": self.value, + "key": self.key, + "namespace": list(self.namespace), + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + class GetOp(NamedTuple): """Operation to retrieve an item by namespace and key.""" diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 9125f8bb1..fe2f7f498 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -11,6 +11,7 @@ Iterator, List, Optional, + Sequence, Union, overload, ) @@ -29,12 +30,15 @@ Cron, DisconnectMode, GraphSchema, + Item, Json, + ListNamespaceResponse, MultitaskStrategy, OnCompletionBehavior, 
OnConflictBehavior, Run, RunCreate, + SearchItemsResponse, StreamMode, StreamPart, Subgraphs, @@ -143,6 +147,7 @@ def __init__(self, client: httpx.AsyncClient) -> None: self.threads = ThreadsClient(self.http) self.runs = RunsClient(self.http) self.crons = CronClient(self.http) + self.store = StoreClient(self.http) class HttpClient: @@ -211,9 +216,9 @@ async def patch(self, path: str, *, json: dict) -> Any: raise e return await adecode_json(r) - async def delete(self, path: str) -> None: + async def delete(self, path: str, *, json: Optional[Any] = None) -> None: """Make a DELETE request.""" - r = await self.client.delete(path) + r = await self.client.request("DELETE", path, json=json) try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -1874,6 +1879,205 @@ async def search( return await self.http.post("/runs/crons/search", json=payload) +class StoreClient: + def __init__(self, http: HttpClient) -> None: + self.http = http + + async def put_item( + self, namespace: Sequence[str], /, key: str, value: dict[str, Any] + ) -> None: + """Store or update an item. + + Args: + namespace: A list of strings representing the namespace path. + key: The unique identifier for the item within the namespace. + value: A dictionary containing the item's data. + + Returns: + None + + Example Usage: + + await client.store.put_item( + ["documents", "user123"], + key="item456", + value={"title": "My Document", "content": "Hello World"} + ) + """ + for label in namespace: + if "." in label: + raise ValueError( + f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." + ) + payload = { + "namespace": namespace, + "key": key, + "value": value, + } + await self.http.put("/store/items", json=payload) + + async def get_item(self, namespace: Sequence[str], /, key: str) -> Item: + """Retrieve a single item. + + Args: + key: The unique identifier for the item. + namespace: Optional list of strings representing the namespace path. 
+ + Returns: + Item: The retrieved item. + + Example Usage: + + item = await client.store.get_item( + ["documents", "user123"], + key="item456", + ) + print(item) + + ---------------------------------------------------------------- + + { + 'namespace': ['documents', 'user123'], + 'key': 'item456', + 'value': {'title': 'My Document', 'content': 'Hello World'}, + 'created_at': '2024-07-30T12:00:00Z', + 'updated_at': '2024-07-30T12:00:00Z' + } + """ + for label in namespace: + if "." in label: + raise ValueError( + f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." + ) + return await self.http.get( + "/store/items", params={"namespace": ".".join(namespace), "key": key} + ) + + async def delete_item(self, namespace: Sequence[str], /, key: str) -> None: + """Delete an item. + + Args: + key: The unique identifier for the item. + namespace: Optional list of strings representing the namespace path. + + Returns: + None + + Example Usage: + + await client.store.delete_item( + ["documents", "user123"], + key="item456", + ) + """ + await self.http.delete( + "/store/items", json={"namespace": namespace, "key": key} + ) + + async def search_items( + self, + namespace_prefix: Sequence[str], + /, + filter: Optional[dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> SearchItemsResponse: + """Search for items within a namespace prefix. + + Args: + namespace_prefix: List of strings representing the namespace prefix. + filter: Optional dictionary of key-value pairs to filter results. + limit: Maximum number of items to return (default is 10). + offset: Number of items to skip before returning results (default is 0). + + Returns: + List[Item]: A list of items matching the search criteria. 
+ + Example Usage: + + items = await client.store.search_items( + ["documents"], + filter={"author": "John Doe"}, + limit=5, + offset=0 + ) + print(items) + + ---------------------------------------------------------------- + + { + "items": [ + { + "namespace": ["documents", "user123"], + "key": "item789", + "value": { + "title": "Another Document", + "author": "John Doe" + }, + "created_at": "2024-07-30T12:00:00Z", + "updated_at": "2024-07-30T12:00:00Z" + }, + # ... additional items ... + ] + } + """ + payload = { + "namespace_prefix": namespace_prefix, + "filter": filter, + "limit": limit, + "offset": offset, + } + + return await self.http.post("/store/items/search", json=_provided_vals(payload)) + + async def list_namespaces( + self, + prefix: Optional[List[str]] = None, + suffix: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> ListNamespaceResponse: + """List namespaces with optional match conditions. + + Args: + prefix: Optional list of strings representing the prefix to filter namespaces. + suffix: Optional list of strings representing the suffix to filter namespaces. + max_depth: Optional integer specifying the maximum depth of namespaces to return. + limit: Maximum number of namespaces to return (default is 100). + offset: Number of namespaces to skip before returning results (default is 0). + + Returns: + List[List[str]]: A list of namespaces matching the criteria. + + Example Usage: + + namespaces = await client.store.list_namespaces( + prefix=["documents"], + max_depth=3, + limit=10, + offset=0 + ) + print(namespaces) + + ---------------------------------------------------------------- + + [ + ["documents", "user123", "reports"], + ["documents", "user456", "invoices"], + ... 
+ ] + """ + payload = { + "prefix": prefix, + "suffix": suffix, + "max_depth": max_depth, + "limit": limit, + "offset": offset, + } + return await self.http.post("/store/namespaces", json=_provided_vals(payload)) + + def get_sync_client( *, url: Optional[str] = None, @@ -1913,6 +2117,7 @@ def __init__(self, client: httpx.Client) -> None: self.threads = SyncThreadsClient(self.http) self.runs = SyncRunsClient(self.http) self.crons = SyncCronClient(self.http) + self.store = SyncStoreClient(self.http) class SyncHttpClient: @@ -1981,9 +2186,9 @@ def patch(self, path: str, *, json: dict) -> Any: raise e return decode_json(r) - def delete(self, path: str) -> None: + def delete(self, path: str, *, json: Optional[Any] = None) -> None: """Make a DELETE request.""" - r = self.client.delete(path) + r = self.client.request("DELETE", path, json=json) try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -3628,3 +3833,204 @@ def search( } payload = {k: v for k, v in payload.items() if v is not None} return self.http.post("/runs/crons/search", json=payload) + + +class SyncStoreClient: + def __init__(self, http: SyncHttpClient) -> None: + self.http = http + + def put_item( + self, namespace: Sequence[str], /, key: str, value: dict[str, Any] + ) -> None: + """Store or update an item. + + Args: + namespace: A list of strings representing the namespace path. + key: The unique identifier for the item within the namespace. + value: A dictionary containing the item's data. + + Returns: + None + + Example Usage: + + client.store.put_item( + ["documents", "user123"], + key="item456", + value={"title": "My Document", "content": "Hello World"} + ) + """ + for label in namespace: + if "." in label: + raise ValueError( + f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." 
+ ) + payload = { + "namespace": namespace, + "key": key, + "value": value, + } + self.http.put("/store/items", json=payload) + + def get_item(self, namespace: Sequence[str], /, key: str) -> Item: + """Retrieve a single item. + + Args: + key: The unique identifier for the item. + namespace: Optional list of strings representing the namespace path. + + Returns: + Item: The retrieved item. + + Example Usage: + + item = client.store.get_item( + ["documents", "user123"], + key="item456", + ) + print(item) + + ---------------------------------------------------------------- + + { + 'namespace': ['documents', 'user123'], + 'key': 'item456', + 'value': {'title': 'My Document', 'content': 'Hello World'}, + 'created_at': '2024-07-30T12:00:00Z', + 'updated_at': '2024-07-30T12:00:00Z' + } + """ + for label in namespace: + if "." in label: + raise ValueError( + f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." + ) + + return self.http.get( + "/store/items", params={"key": key, "namespace": ".".join(namespace)} + ) + + def delete_item(self, namespace: Sequence[str], /, key: str) -> None: + """Delete an item. + + Args: + key: The unique identifier for the item. + namespace: Optional list of strings representing the namespace path. + + Returns: + None + + Example Usage: + + client.store.delete_item( + ["documents", "user123"], + key="item456", + ) + """ + self.http.delete("/store/items", json={"key": key, "namespace": namespace}) + + def search_items( + self, + namespace_prefix: Sequence[str], + /, + filter: Optional[dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> SearchItemsResponse: + """Search for items within a namespace prefix. + + Args: + namespace_prefix: List of strings representing the namespace prefix. + filter: Optional dictionary of key-value pairs to filter results. + limit: Maximum number of items to return (default is 10). + offset: Number of items to skip before returning results (default is 0). 
+ + Returns: + List[Item]: A list of items matching the search criteria. + + Example Usage: + + items = client.store.search_items( + ["documents"], + filter={"author": "John Doe"}, + limit=5, + offset=0 + ) + print(items) + + ---------------------------------------------------------------- + + { + "items": [ + { + "namespace": ["documents", "user123"], + "key": "item789", + "value": { + "title": "Another Document", + "author": "John Doe" + }, + "created_at": "2024-07-30T12:00:00Z", + "updated_at": "2024-07-30T12:00:00Z" + }, + # ... additional items ... + ] + } + """ + payload = { + "namespace_prefix": namespace_prefix, + "filter": filter, + "limit": limit, + "offset": offset, + } + return self.http.post("/store/items/search", json=_provided_vals(payload)) + + def list_namespaces( + self, + prefix: Optional[List[str]] = None, + suffix: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> ListNamespaceResponse: + """List namespaces with optional match conditions. + + Args: + prefix: Optional list of strings representing the prefix to filter namespaces. + suffix: Optional list of strings representing the suffix to filter namespaces. + max_depth: Optional integer specifying the maximum depth of namespaces to return. + limit: Maximum number of namespaces to return (default is 100). + offset: Number of namespaces to skip before returning results (default is 0). + + Returns: + List[List[str]]: A list of namespaces matching the criteria. + + Example Usage: + + namespaces = client.store.list_namespaces( + prefix=["documents"], + max_depth=3, + limit=10, + offset=0 + ) + print(namespaces) + + ---------------------------------------------------------------- + + [ + ["documents", "user123", "reports"], + ["documents", "user456", "invoices"], + ... 
+ ] + """ + payload = { + "prefix": prefix, + "suffix": suffix, + "max_depth": max_depth, + "limit": limit, + "offset": offset, + } + return self.http.post("/store/namespaces", json=_provided_vals(payload)) + + +def _provided_vals(d: dict): + return {k: v for k, v in d.items() if v is not None} diff --git a/libs/sdk-py/langgraph_sdk/schema.py b/libs/sdk-py/langgraph_sdk/schema.py index da1db204c..2ea166aef 100644 --- a/libs/sdk-py/langgraph_sdk/schema.py +++ b/libs/sdk-py/langgraph_sdk/schema.py @@ -197,6 +197,30 @@ class RunCreate(TypedDict): multitask_strategy: Optional[MultitaskStrategy] +class Item(TypedDict): + namespace: list[str] + """The namespace of the item.""" + key: str + """The unique identifier of the item within its namespace. + + In general, keys are not globally unique. + """ + value: dict[str, Any] + """The value stored in the item. This is the document itself.""" + created_at: datetime + """The timestamp when the item was created.""" + updated_at: datetime + """The timestamp when the item was last updated.""" + + +class ListNamespaceResponse(TypedDict): + namespaces: list[list[str]] + + +class SearchItemsResponse(TypedDict): + items: list[Item] + + class StreamPart(NamedTuple): event: str data: dict From a2e58109eac0dc050a5eecf5bddc633104b9ffcc Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Sun, 29 Sep 2024 16:57:40 -0700 Subject: [PATCH 26/26] Clarify order in batch method --- libs/checkpoint/langgraph/store/base.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/libs/checkpoint/langgraph/store/base.py b/libs/checkpoint/langgraph/store/base.py index c51eba14a..d1ad914a6 100644 --- a/libs/checkpoint/langgraph/store/base.py +++ b/libs/checkpoint/langgraph/store/base.py @@ -173,11 +173,27 @@ class BaseStore(ABC): @abstractmethod def batch(self, ops: Iterable[Op]) -> list[Result]: - """Execute a batch of operations synchronously.""" + """Execute multiple 
operations synchronously in a single batch. + + Args: + ops: An iterable of operations to execute. + + Returns: + A list of results, where each result corresponds to an operation in the input. + The order of results matches the order of input operations. + """ @abstractmethod async def abatch(self, ops: Iterable[Op]) -> list[Result]: - """Execute a batch of operations asynchronously.""" + """Execute multiple operations asynchronously in a single batch. + + Args: + ops: An iterable of operations to execute. + + Returns: + A list of results, where each result corresponds to an operation in the input. + The order of results matches the order of input operations. + """ def get(self, namespace: tuple[str, ...], key: str) -> Optional[Item]: """Retrieve a single item.