diff --git a/libs/checkpoint/langgraph/store/memory.py b/libs/checkpoint/langgraph/store/memory.py index 8b08079895..2d002de9d3 100644 --- a/libs/checkpoint/langgraph/store/memory.py +++ b/libs/checkpoint/langgraph/store/memory.py @@ -1,5 +1,6 @@ from collections import defaultdict from datetime import datetime, timezone +from typing import Iterable from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result, SearchOp @@ -16,7 +17,7 @@ class InMemoryStore(BaseStore): def __init__(self) -> None: self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) - def batch(self, ops: list[Op]) -> list[Result]: + def batch(self, ops: Iterable[Op]) -> list[Result]: results: list[Result] = [] for op in ops: if isinstance(op, GetOp): @@ -63,5 +64,5 @@ def batch(self, ops: list[Op]) -> list[Result]: results.append(None) return results - async def abatch(self, ops: list[Op]) -> list[Result]: + async def abatch(self, ops: Iterable[Op]) -> list[Result]: return self.batch(ops) diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index 807a5798ee..78b9f9e048 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -1,5 +1,6 @@ import asyncio from datetime import datetime +from typing import Iterable import pytest from pytest_mock import MockerFixture @@ -14,18 +15,18 @@ async def test_async_batch_store(mocker: MockerFixture) -> None: abatch = mocker.stub() class MockStore(AsyncBatchedBaseStore): - def batch(self, ops: list[Op]) -> list[Result]: + def batch(self, ops: Iterable[Op]) -> list[Result]: raise NotImplementedError - async def abatch(self, ops: list[Op]) -> list[Result]: + async def abatch(self, ops: Iterable[Op]) -> list[Result]: assert all(isinstance(op, GetOp) for op in ops) abatch(ops) return [ Item( value={}, scores={}, - id=op.id, - namespace=op.namespace, + id=getattr(op, "id", ""), + namespace=getattr(op, "namespace", ()), created_at=datetime(2024, 9, 24, 17, 29, 10, 128397), updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397), last_accessed_at=datetime(2024, 9, 24, 17, 29, 10, 128397),