From eba767c848874ed336c56549b97ff52966d757e7 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 17:07:25 -0700 Subject: [PATCH 1/6] Adding cache support for postgres --- .../langgraph/checkpoint/postgres/__init__.py | 52 +++- .../langgraph/checkpoint/postgres/base.py | 20 +- .../langgraph/checkpoint/base/__init__.py | 24 ++ libs/langgraph/langgraph/graph/state.py | 7 +- libs/langgraph/langgraph/pregel/algo.py | 54 ++-- libs/langgraph/langgraph/pregel/loop.py | 11 + libs/langgraph/langgraph/pregel/read.py | 5 + libs/langgraph/langgraph/pregel/types.py | 14 +- .../tests/__snapshots__/test_cache.ambr | 27 ++ libs/langgraph/tests/test_cache.py | 279 ++++++++++++++++++ 10 files changed, 465 insertions(+), 28 deletions(-) create mode 100644 libs/langgraph/tests/__snapshots__/test_cache.ambr create mode 100644 libs/langgraph/tests/test_cache.py diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index 1c064415a..d918d6ba1 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -3,7 +3,7 @@ from typing import Any, Iterator, List, Optional, Union from langchain_core.runnables import RunnableConfig -from psycopg import Connection, Cursor, Pipeline +from psycopg import Connection, Cursor, DatabaseError, Pipeline from psycopg.errors import UndefinedTable from psycopg.rows import dict_row from psycopg.types.json import Jsonb @@ -251,6 +251,56 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: self._load_writes(value["pending_writes"]), ) + def get_writes_by_cache_key(self, cache_key: str) -> List[Any]: + """Get checkpoint tuples from the database based on a cache key. + + This method retrieves checkpoint tuples from the Postgres database based on the + provided cache key. + + Args: + cache_key (str): The cache key to use for retrieving the checkpoints. + + Returns: + List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + + Examples: + >>> cache_key = "some_unique_cache_key" + >>> checkpoint_tuples = memory.get_writes_by_cache_key(cache_key) + >>> for tuple in checkpoint_tuples: + ... print(tuple) + CheckpointTuple(...) + CheckpointTuple(...) + """ + results = [] + try: + with self._cursor() as cur: + cur.execute( + """ + SELECT task_id, channel, type, blob + FROM checkpoint_writes + WHERE task_id = %s + ORDER BY idx ASC + """, + (cache_key,), + binary=True, + ) + + for row in cur: + results.append(( + row['task_id'], + row['channel'], + row['type'], + row['blob'] + )) + except DatabaseError as e: + # Log the error or handle it as appropriate for your application + # Optionally re-raise the error if you want it to propagate + # raise + raise RuntimeError( + f"Exception occurred while fetching writes from the database: {e}" + ) + return self._load_writes(results) + def put( self, config: RunnableConfig, diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index 6cfbc5108..a311c52b4 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -1,5 +1,5 @@ from hashlib import md5 -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union from langchain_core.runnables import RunnableConfig from psycopg.types.json import Jsonb @@ -177,14 +177,24 @@ def _dump_blobs( ] def _load_writes( - self, writes: list[tuple[bytes, bytes, bytes, bytes]] + self, + writes: list[ + tuple[ + Union[str, bytes], + Union[str, bytes], + Union[str, bytes], + Union[str, bytes], + ] + ], ) -> list[tuple[str, str, Any]]: return ( [ ( - tid.decode(), - channel.decode(), - self.serde.loads_typed((t.decode(), v)), + tid.decode() if isinstance(tid, bytes) else tid, + channel.decode() if isinstance(channel, bytes) else channel, + self.serde.loads_typed( + (t.decode() if isinstance(t, bytes) else t, v) + ), ) for tid, channel, t, v in writes ] diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index 17caf46d8..56eaf1e80 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -422,6 +422,30 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: """ return current + 1 if current is not None else 1 + def aget_writes_by_cache_key( + self, cache_key: str + ) -> Optional[List[CheckpointTuple]]: + """Get a checkpoint tuple from the database based on a cache key. + + Args: + cache_key (str): The cache key to use for retrieving the checkpoint. + + Returns: + List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + """ + raise NotImplementedError + + def get_writes_by_cache_key(self, cache_key: str) -> Optional[List[Any]]: + """Get a checkpoint tuple from the database based on a cache key. + + Args: + cache_key (str): The cache key to use for retrieving the checkpoint. + + Returns: + List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + """ + pass + class EmptyChannelError(Exception): """Raised when attempting to get the value of a channel that hasn't been updated diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index bfa9792ee..5b37f5a3a 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -42,7 +42,7 @@ is_writable_managed_value, ) from langgraph.pregel.read import ChannelRead, PregelNode -from langgraph.pregel.types import All, RetryPolicy +from langgraph.pregel.types import All, CachePolicy, RetryPolicy from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore from langgraph.utils.fields import get_field_default @@ -68,6 +68,7 @@ class StateNodeSpec(NamedTuple): metadata: dict[str, Any] input: Type[Any] retry_policy: Optional[RetryPolicy] + cache_policy: Optional[CachePolicy] class StateGraph(Graph): @@ -202,6 +203,7 @@ def add_node( metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, + cache: Optional[CachePolicy] = None, ) -> None: """Adds a new node to the state graph. Will take the name of the function/runnable as the node name. @@ -249,6 +251,7 @@ def add_node( metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, + cache: Optional[CachePolicy] = None, ) -> None: """Adds a new node to the state graph. @@ -343,6 +346,7 @@ def add_node( metadata, input=input or self.schema, retry_policy=retry, + cache_policy=cache, ) def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None: @@ -571,6 +575,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any: ], metadata=node.metadata, retry_policy=node.retry_policy, + cache_policy=node.cache_policy, bound=node.runnable, ) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 033fab8e0..35a9dfdce 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -334,7 +334,6 @@ def prepare_single_task( checkpoint_id = UUID(checkpoint["id"]).bytes configurable = config.get("configurable", {}) parent_ns = configurable.get("checkpoint_ns", "") - if task_path[0] == TASKS: idx = int(task_path[1]) packet = checkpoint["pending_sends"][idx] @@ -357,14 +356,23 @@ def prepare_single_task( checkpoint_ns = ( f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node ) - task_id = _uuid5_str( - checkpoint_id, - checkpoint_ns, - str(step), - packet.node, - TASKS, - str(idx), - ) + + if proc.cache_policy: + cache_key = proc.cache_policy.cache_key + task_id = _uuid5_str( + b"", + cache_key, + ) + else: + task_id = _uuid5_str( + checkpoint_id, + checkpoint_ns, + str(step), + packet.node, + TASKS, + str(idx), + ) + if task_id_checksum is not None: assert task_id == task_id_checksum if for_execution: @@ -423,7 +431,7 @@ def prepare_single_task( ), triggers, proc.retry_policy, - None, + proc.cache_policy, task_id, task_path, ) @@ -465,14 +473,22 @@ def prepare_single_task( "langgraph_path": task_path, } checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name - task_id = _uuid5_str( - checkpoint_id, - checkpoint_ns, - str(step), - name, - SUBSCRIPTIONS, - *triggers, - ) + + if proc.cache_policy: + cache_key = proc.cache_policy.cache_key + task_id = _uuid5_str( + b"", + cache_key, + ) + else: + task_id = _uuid5_str( + checkpoint_id, + checkpoint_ns, + str(step), + name, + SUBSCRIPTIONS, + *triggers, + ) if task_id_checksum is not None: assert task_id == task_id_checksum @@ -531,7 +547,7 @@ def prepare_single_task( ), triggers, proc.retry_policy, - None, + proc.cache_policy, task_id, task_path, ) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index afd29ec1b..54e048b5b 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -334,6 +334,17 @@ def tick( self.status = "done" return False + if self.checkpointer: + for task in self.tasks.values(): + # if there are cached writes, apply them + cached_writes = self.checkpointer.get_writes_by_cache_key(task.id) + if cached_writes and not task.writes: + # Extract only the last two items from each write tuple + task.writes.extend( + [(channel, value) for _, channel, value in cached_writes] + ) + self._output_writes(task.id, task.writes, cached=True) + # if there are pending writes from a previous loop, apply them if self.skip_done_tasks and self.checkpoint_pending_writes: for tid, k, v in self.checkpoint_pending_writes: diff --git a/libs/langgraph/langgraph/pregel/read.py b/libs/langgraph/langgraph/pregel/read.py index 4d9944661..ecdfd5c93 100644 --- a/libs/langgraph/langgraph/pregel/read.py +++ b/libs/langgraph/langgraph/pregel/read.py @@ -23,6 +23,7 @@ from langgraph.constants import CONFIG_KEY_READ from langgraph.pregel.retry import RetryPolicy +from langgraph.pregel.types import CachePolicy from langgraph.pregel.write import ChannelWrite from langgraph.utils.config import merge_configs from langgraph.utils.runnable import RunnableCallable, RunnableSeq @@ -122,6 +123,8 @@ class PregelNode(Runnable): config: RunnableConfig + cache_policy: Optional[CachePolicy] + def __init__( self, *, @@ -134,6 +137,7 @@ def __init__( bound: Optional[Runnable[Any, Any]] = None, retry_policy: Optional[RetryPolicy] = None, config: Optional[RunnableConfig] = None, + cache_policy: Optional[CachePolicy] = None, ) -> None: self.channels = channels self.triggers = list(triggers) @@ -144,6 +148,7 @@ def __init__( self.config = merge_configs( config, {"tags": tags or [], "metadata": metadata or {}} ) + self.cache_policy = cache_policy def copy(self, update: dict[str, Any]) -> PregelNode: attrs = {**self.__dict__, **update} diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index 80d0d8a23..d2b3b9914 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Type, Union from langchain_core.runnables import Runnable, RunnableConfig @@ -60,7 +60,17 @@ class RetryPolicy(NamedTuple): class CachePolicy(NamedTuple): """Configuration for caching nodes.""" - pass + cache_key: Optional[Callable[[Any, Optional[RunnableConfig]], str]] + """ + A function that takes in the input and config, and returns a string key + under which the output should be cached. + """ + # TODO: implement cache_ttl + # cache_ttl: Optional[float] = None + # """ + # Time-to-live for the cached value, in seconds. If not provided, the value will be cached indefinitely. + # We'd probably want to store this in a bucket way intead of a TTL timeline. + # """ class PregelTask(NamedTuple): diff --git a/libs/langgraph/tests/__snapshots__/test_cache.ambr b/libs/langgraph/tests/__snapshots__/test_cache.ambr new file mode 100644 index 000000000..bb94363e5 --- /dev/null +++ b/libs/langgraph/tests/__snapshots__/test_cache.ambr @@ -0,0 +1,27 @@ +# serializer version: 1 +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres].1 + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- diff --git a/libs/langgraph/tests/test_cache.py b/libs/langgraph/tests/test_cache.py new file mode 100644 index 000000000..8c0f56bd5 --- /dev/null +++ b/libs/langgraph/tests/test_cache.py @@ -0,0 +1,279 @@ +import hashlib +import json +import operator +import time + +from typing import ( + Annotated, + Any, + Literal, + Optional, + TypedDict, + Union, +) + +from langgraph.checkpoint.base import BaseCheckpointSaver +import pytest +from langchain_core.runnables import ( + RunnableConfig, +) + +from syrupy import SnapshotAssertion +from langgraph.graph.state import StateGraph +from langgraph.pregel.types import CachePolicy + + +def custom_cache_key(input: Any, config: Optional[RunnableConfig] = None) -> str: + """ + Generate a cache key based on the input and config. + + Args: + input (Any): The input to the node. + config (Optional[RunnableConfig]): The configuration for the node. + + Returns: + str: A string key under which the output should be cached. + """ + # Convert input to a JSON-serializable format + if isinstance(input, dict): + input_str = json.dumps(input, sort_keys=True) + elif isinstance(input, (str, int, float, bool)): + input_str = str(input) + else: + input_str = str(hash(input)) + + # Extract relevant parts from the config + config_str = "" + if config: + relevant_config = { + "tags": config.get("tags", []), + "metadata": config.get("metadata", {}), + } + config_str = json.dumps(relevant_config, sort_keys=True) + print() + # Combine input and config strings + combined_str = f"{input_str}|{config_str}" + print("combined_str", combined_str) + # Generate a hash of the combined string + return hashlib.md5(combined_str.encode("utf-8")).hexdigest() + + +@pytest.mark.parametrize("checkpointer_name", ["postgres"]) +def test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache( + snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer: BaseCheckpointSaver = request.getfixturevalue( + f"checkpointer_{checkpointer_name}" + ) + + def sorted_add( + x: list[str], y: Union[list[str], list[tuple[str, str]]] + ) -> list[str]: + if isinstance(y[0], tuple): + for rem, _ in y: + x.remove(rem) + y = [t[1] for t in y] + return sorted(operator.add(x, y)) + + class State(TypedDict, total=False): + query: str + answer: str + docs: Annotated[list[str], sorted_add] + + workflow = StateGraph(State) + call_count = 0 + + def rewrite_query(data: State) -> State: + return {"query": f'query: {data["query"]}'} + + def analyzer_one(data: State) -> State: + return {"query": f'analyzed: {data["query"]}'} + + def retriever_one(data: State) -> State: + nonlocal call_count + call_count += 1 + print("increasing count", call_count) + return {"docs": ["doc1", "doc2"]} + + def retriever_two(data: State) -> State: + time.sleep(0.1) + return {"docs": ["doc3", "doc4"]} + + def qa(data: State) -> State: + print("qaaaa", data["docs"]) + return {"answer": ",".join(data["docs"])} + + def rewrite_query_then(data: State) -> Literal["retriever_two"]: + return "retriever_two" + + config = RunnableConfig(configurable={"thread_id": "1"}) + + workflow.add_node("rewrite_query", rewrite_query) + workflow.add_node("analyzer_one", analyzer_one) + workflow.add_node( + "retriever_one", + retriever_one, + cache=CachePolicy(custom_cache_key({"user_id": "a user"}, config)), + ) + workflow.add_node("retriever_two", retriever_two) + workflow.add_node("qa", qa) + + workflow.set_entry_point("rewrite_query") + workflow.add_edge("rewrite_query", "analyzer_one") + workflow.add_edge("analyzer_one", "retriever_one") + workflow.add_conditional_edges("rewrite_query", rewrite_query_then) + workflow.add_edge(["retriever_one", "retriever_two"], "qa") + workflow.set_finish_point("qa") + + app_with_checkpointer = workflow.compile( + checkpointer=checkpointer, + ) + + interrupt_results = [ + c + for c in app_with_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + assert interrupt_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + ] + assert call_count == 1 + config = RunnableConfig(configurable={"thread_id": "2"}) + + stream_results = [ + c + for c in app_with_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + + assert stream_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, + ] + + # Should not increase count because of cache + assert call_count == 1 + + config = RunnableConfig(configurable={"thread_id": "3"}) + + stream_results = [ + c + for c in app_with_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + + assert stream_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, + ] + + # Should not increase count because of cache + assert call_count == 1 + + # Cache is not used when checkpointer is not provided + app_without_checkpointer = workflow.compile() + interrupt_results = [ + c + for c in app_without_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + assert interrupt_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + ] + assert call_count == 2 + + # Test a new workflow with the same cache key + new_workflow = StateGraph(State) + config = RunnableConfig(configurable={"thread_id": "4"}) + new_workflow.add_node("rewrite_query", rewrite_query) + new_workflow.add_node("analyzer_one", analyzer_one) + new_workflow.add_node( + "retriever_one", + retriever_one, + cache=CachePolicy(custom_cache_key({"user_id": "a user"}, config)), + ) + new_workflow.add_node("retriever_two", retriever_two) + new_workflow.add_node("qa", qa) + + new_workflow.set_entry_point("rewrite_query") + new_workflow.add_edge("rewrite_query", "analyzer_one") + new_workflow.add_edge("analyzer_one", "retriever_one") + new_workflow.add_conditional_edges("rewrite_query", rewrite_query_then) + new_workflow.add_edge(["retriever_one", "retriever_two"], "qa") + new_workflow.set_finish_point("qa") + + app_with_checkpointer = new_workflow.compile( + checkpointer=checkpointer, + ) + + interrupt_results = [ + c + for c in app_with_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + assert interrupt_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + ] + assert call_count == 2 + + # Test a new workflow with a different cache key + another_new_workflow = StateGraph(State) + config = RunnableConfig(configurable={"thread_id": "5"}) + another_new_workflow.add_node("rewrite_query", rewrite_query) + another_new_workflow.add_node("analyzer_one", analyzer_one) + another_new_workflow.add_node( + "retriever_one", + retriever_one, + cache=CachePolicy(custom_cache_key({"user_id": "a different user"}, config)), + ) + another_new_workflow.add_node("retriever_two", retriever_two) + another_new_workflow.add_node("qa", qa) + + another_new_workflow.set_entry_point("rewrite_query") + another_new_workflow.add_edge("rewrite_query", "analyzer_one") + another_new_workflow.add_edge("analyzer_one", "retriever_one") + another_new_workflow.add_conditional_edges("rewrite_query", rewrite_query_then) + another_new_workflow.add_edge(["retriever_one", "retriever_two"], "qa") + another_new_workflow.set_finish_point("qa") + + app_with_checkpointer = another_new_workflow.compile( + checkpointer=checkpointer, + ) + + interrupt_results = [ + c + for c in app_with_checkpointer.stream( + {"query": "what is weather in sf"}, config + ) + ] + assert interrupt_results == [ + {"rewrite_query": {"query": "query: what is weather in sf"}}, + {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, + {"retriever_two": {"docs": ["doc3", "doc4"]}}, + {"retriever_one": {"docs": ["doc1", "doc2"]}}, + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + ] + assert call_count == 3 From e52aca9cee7a7695b5a83eeb4ae42b9bad6d3c12 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 17:16:33 -0700 Subject: [PATCH 2/6] remove print statement --- .../langgraph/checkpoint/postgres/__init__.py | 14 ++++++-------- .../langgraph/checkpoint/base/__init__.py | 10 +++++----- libs/langgraph/tests/test_cache.py | 4 ---- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index d918d6ba1..1ecfd1d12 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -252,24 +252,22 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: ) def get_writes_by_cache_key(self, cache_key: str) -> List[Any]: - """Get checkpoint tuples from the database based on a cache key. + """Get checkpoint writes from the database based on a cache key. - This method retrieves checkpoint tuples from the Postgres database based on the + This method retrieves checkpoint writes from the Postgres database based on the provided cache key. Args: cache_key (str): The cache key to use for retrieving the checkpoints. Returns: - List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + List[Any]: A list of retrieved checkpoint writes. Empty list if none found. Examples: >>> cache_key = "some_unique_cache_key" - >>> checkpoint_tuples = memory.get_writes_by_cache_key(cache_key) - >>> for tuple in checkpoint_tuples: - ... print(tuple) - CheckpointTuple(...) - CheckpointTuple(...) + >>> checkpoint_writes = memory.get_writes_by_cache_key(cache_key) + >>> for write in checkpoint_writes: + ... print(write) """ results = [] try: diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index 56eaf1e80..db71a7d61 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -424,25 +424,25 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: def aget_writes_by_cache_key( self, cache_key: str - ) -> Optional[List[CheckpointTuple]]: - """Get a checkpoint tuple from the database based on a cache key. + ) -> Optional[List[Any]]: + """Get a checkpoint writes from the database based on a cache key. Args: cache_key (str): The cache key to use for retrieving the checkpoint. Returns: - List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + List[Any]: A list of retrieved checkpoint writes. Empty list if none found. """ raise NotImplementedError def get_writes_by_cache_key(self, cache_key: str) -> Optional[List[Any]]: - """Get a checkpoint tuple from the database based on a cache key. + """Get a checkpoint writes from the database based on a cache key. Args: cache_key (str): The cache key to use for retrieving the checkpoint. Returns: - List[CheckpointTuple]: A list of retrieved checkpoint tuples. Empty list if none found. + List[Any]: A list of retrieved checkpoint writes. Empty list if none found. """ pass diff --git a/libs/langgraph/tests/test_cache.py b/libs/langgraph/tests/test_cache.py index 8c0f56bd5..926f0f3fa 100644 --- a/libs/langgraph/tests/test_cache.py +++ b/libs/langgraph/tests/test_cache.py @@ -50,10 +50,8 @@ def custom_cache_key(input: Any, config: Optional[RunnableConfig] = None) -> str "metadata": config.get("metadata", {}), } config_str = json.dumps(relevant_config, sort_keys=True) - print() # Combine input and config strings combined_str = f"{input_str}|{config_str}" - print("combined_str", combined_str) # Generate a hash of the combined string return hashlib.md5(combined_str.encode("utf-8")).hexdigest() @@ -92,7 +90,6 @@ def analyzer_one(data: State) -> State: def retriever_one(data: State) -> State: nonlocal call_count call_count += 1 - print("increasing count", call_count) return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: @@ -100,7 +97,6 @@ def retriever_two(data: State) -> State: return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: - print("qaaaa", data["docs"]) return {"answer": ",".join(data["docs"])} def rewrite_query_then(data: State) -> Literal["retriever_two"]: From 5a475616b2fe100e4fe150d61ff9161c81455fc1 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 17:23:35 -0700 Subject: [PATCH 3/6] always use node name as part of the cache key --- libs/langgraph/langgraph/pregel/algo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 35a9dfdce..39d558e1b 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -361,6 +361,7 @@ def prepare_single_task( cache_key = proc.cache_policy.cache_key task_id = _uuid5_str( b"", + packet.node, cache_key, ) else: @@ -478,6 +479,7 @@ def prepare_single_task( cache_key = proc.cache_policy.cache_key task_id = _uuid5_str( b"", + name, cache_key, ) else: From 689ecc49b4a2e70aec0c19e6b05d0767cf04c045 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 18:17:18 -0700 Subject: [PATCH 4/6] update the query to fetch multiple tasks writes together --- .../langgraph/checkpoint/postgres/__init__.py | 36 ++++++++++++------- .../langgraph/checkpoint/postgres/base.py | 1 + .../langgraph/checkpoint/base/__init__.py | 10 +++--- libs/langgraph/langgraph/pregel/loop.py | 21 +++++++---- .../tests/__snapshots__/test_cache.ambr | 26 ++++++++++++++ libs/langgraph/tests/test_cache.py | 7 ++++ 6 files changed, 76 insertions(+), 25 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index 1ecfd1d12..ba195149f 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -1,6 +1,6 @@ import threading from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union from langchain_core.runnables import RunnableConfig from psycopg import Connection, Cursor, DatabaseError, Pipeline @@ -251,40 +251,46 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: self._load_writes(value["pending_writes"]), ) - def get_writes_by_cache_key(self, cache_key: str) -> List[Any]: + def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List[Any]]]: """Get checkpoint writes from the database based on a cache key. This method retrieves checkpoint writes from the Postgres database based on the provided cache key. Args: - cache_key (str): The cache key to use for retrieving the checkpoints. + task_ids (str): The task id is serving as the cache key to retrieve writes. Returns: List[Any]: A list of retrieved checkpoint writes. Empty list if none found. Examples: - >>> cache_key = "some_unique_cache_key" - >>> checkpoint_writes = memory.get_writes_by_cache_key(cache_key) - >>> for write in checkpoint_writes: - ... print(write) - """ - results = [] + >>> task_ids = ["task1", "task2", "task3"] + >>> checkpoint_writes = memory.get_writes_by_task_ids(task_ids) + >>> for task_id, writes in checkpoint_writes.items(): + ... print(f"Task ID: {task_id}") + ... for write in writes: + ... print(f" {write}") + """ + results = {} try: with self._cursor() as cur: cur.execute( """ SELECT task_id, channel, type, blob FROM checkpoint_writes - WHERE task_id = %s + WHERE task_id = ANY(%s) ORDER BY idx ASC """, - (cache_key,), + (task_ids,), binary=True, ) for row in cur: - results.append(( + task_id = row['task_id'] + if task_id not in results: + results[task_id] = [] + # Appending all the writes to the mapping task_id + results[task_id].append(( row['task_id'], row['channel'], row['type'], @@ -297,7 +303,11 @@ def get_writes_by_cache_key(self, cache_key: str) -> List[Any]: raise RuntimeError( f"Exception occurred while fetching writes from the database: {e}" ) - return self._load_writes(results) + + for task_id, writes in results.items(): + # check writes are loaded correctly + results[task_id] = self._load_writes(writes) + return results def put( self, diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index a311c52b4..38dfc7df1 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -178,6 +178,7 @@ def _dump_blobs( def _load_writes( self, + # We want to support both bytes and strings here writes: list[ tuple[ Union[str, bytes], diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index db71a7d61..ad688431f 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -422,24 +422,22 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: """ return current + 1 if current is not None else 1 - def aget_writes_by_cache_key( - self, cache_key: str - ) -> Optional[List[Any]]: + def get_writes_by_task_ids(self, task_id: List[str]) -> Optional[Dict[str, List[Any]]]: """Get a checkpoint writes from the database based on a cache key. Args: - cache_key (str): The cache key to use for retrieving the checkpoint. + task_ids (str): The task id is serving as the cache key to retrieve writes. Returns: List[Any]: A list of retrieved checkpoint writes. Empty list if none found. """ raise NotImplementedError - def get_writes_by_cache_key(self, cache_key: str) -> Optional[List[Any]]: + def get_writes_by_task_ids(self, task_id: List[str]) -> Optional[Dict[str, List[Any]]]: """Get a checkpoint writes from the database based on a cache key. Args: - cache_key (str): The cache key to use for retrieving the checkpoint. + task_ids (str): The task id is serving as the cache key to retrieve writes. Returns: List[Any]: A list of retrieved checkpoint writes. Empty list if none found. diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 54e048b5b..d8a06877b 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -335,15 +335,24 @@ def tick( return False if self.checkpointer: - for task in self.tasks.values(): - # if there are cached writes, apply them - cached_writes = self.checkpointer.get_writes_by_cache_key(task.id) - if cached_writes and not task.writes: - # Extract only the last two items from each write tuple + # only fetching tasks with cache policy enabled + tasks_with_cache_policy = [ + task for task in self.tasks.values() if task.cache_policy + ] + + # if there are cached writes, apply them + task_cached_writes = self.checkpointer.get_writes_by_task_ids( + [task.id for task in tasks_with_cache_policy] + ) + for task_id, cached_writes in task_cached_writes.items(): + # if there are cached writes and task doesn't have any writes yet, apply them + if cached_writes and not self.tasks[task_id].writes: + task = self.tasks[task_id] + # Extract only the channel and value from cached_writes task.writes.extend( [(channel, value) for _, channel, value in cached_writes] ) - self._output_writes(task.id, task.writes, cached=True) + self._output_writes(task_id, task.writes, cached=True) # if there are pending writes from a previous loop, apply them if self.skip_done_tasks and self.checkpoint_pending_writes: diff --git a/libs/langgraph/tests/__snapshots__/test_cache.ambr b/libs/langgraph/tests/__snapshots__/test_cache.ambr index bb94363e5..d96ad206e 100644 --- a/libs/langgraph/tests/__snapshots__/test_cache.ambr +++ b/libs/langgraph/tests/__snapshots__/test_cache.ambr @@ -25,3 +25,29 @@ ''' # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres].2 + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres].3 + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- diff --git a/libs/langgraph/tests/test_cache.py b/libs/langgraph/tests/test_cache.py index 926f0f3fa..4b7340e03 100644 --- a/libs/langgraph/tests/test_cache.py +++ b/libs/langgraph/tests/test_cache.py @@ -124,6 +124,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: app_with_checkpointer = workflow.compile( checkpointer=checkpointer, ) + assert app_with_checkpointer.get_graph().draw_mermaid(with_styles=False) == snapshot interrupt_results = [ c @@ -181,6 +182,10 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: # Cache is not used when checkpointer is not provided app_without_checkpointer = workflow.compile() + assert ( + app_without_checkpointer.get_graph().draw_mermaid(with_styles=False) == snapshot + ) + interrupt_results = [ c for c in app_without_checkpointer.stream( @@ -219,6 +224,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: app_with_checkpointer = new_workflow.compile( checkpointer=checkpointer, ) + assert app_with_checkpointer.get_graph().draw_mermaid(with_styles=False) == snapshot interrupt_results = [ c @@ -258,6 +264,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: app_with_checkpointer = another_new_workflow.compile( checkpointer=checkpointer, ) + assert app_with_checkpointer.get_graph().draw_mermaid(with_styles=False) == snapshot interrupt_results = [ c From 6c3fff9375263f30619439810a263145f5001250 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 18:29:46 -0700 Subject: [PATCH 5/6] fix errors --- libs/checkpoint/langgraph/checkpoint/base/__init__.py | 8 ++++++-- libs/langgraph/langgraph/pregel/types.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index ad688431f..dfd5d12d0 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -422,7 +422,9 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: """ return current + 1 if current is not None else 1 - def get_writes_by_task_ids(self, task_id: List[str]) -> Optional[Dict[str, List[Any]]]: + def aget_writes_by_task_ids( + self, task_ids: List[str] + ) -> Optional[Dict[str, List[Any]]]: """Get a checkpoint writes from the database based on a cache key. Args: @@ -433,7 +435,9 @@ def get_writes_by_task_ids(self, task_id: List[str]) -> Optional[Dict[str, List[ """ raise NotImplementedError - def get_writes_by_task_ids(self, task_id: List[str]) -> Optional[Dict[str, List[Any]]]: + def get_writes_by_task_ids( + self, task_ids: List[str] + ) -> Optional[Dict[str, List[Any]]]: """Get a checkpoint writes from the database based on a cache key. Args: diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index d2b3b9914..4578d3927 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -69,7 +69,7 @@ class CachePolicy(NamedTuple): # cache_ttl: Optional[float] = None # """ # Time-to-live for the cached value, in seconds. If not provided, the value will be cached indefinitely. - # We'd probably want to store this in a bucket way intead of a TTL timeline. + # We'd probably want to store this in a bucket way instead of a TTL timeline. # """ From cacf54ffa5afe5e045357c65f8d0884db1b0a605 Mon Sep 17 00:00:00 2001 From: yufeng-zhou Date: Fri, 6 Sep 2024 18:43:25 -0700 Subject: [PATCH 6/6] fix linting --- .../langgraph/checkpoint/postgres/__init__.py | 37 +++++++--------- .../langgraph/checkpoint/base/__init__.py | 12 ++---- libs/langgraph/langgraph/pregel/algo.py | 2 +- libs/langgraph/langgraph/pregel/loop.py | 19 +++++---- libs/langgraph/langgraph/pregel/types.py | 2 +- libs/langgraph/tests/test_cache.py | 42 +++++++------------ 6 files changed, 46 insertions(+), 68 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index ba195149f..a7928fb47 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -251,17 +251,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: self._load_writes(value["pending_writes"]), ) - def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List[Any]]]: - """Get checkpoint writes from the database based on a cache key. + def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]: + """Get checkpoint writes from the database based on multiple task IDs. This method retrieves checkpoint writes from the Postgres database based on the - provided cache key. + provided list of task IDs. Args: - task_ids (str): The task id is serving as the cache key to retrieve writes. + task_ids (List[str]): A list of task IDs to retrieve checkpoint writes for. Returns: - List[Any]: A list of retrieved checkpoint writes. Empty list if none found. + Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes. Examples: >>> task_ids = ["task1", "task2", "task3"] @@ -270,7 +270,7 @@ def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List ... print(f"Task ID: {task_id}") ... for write in writes: ... print(f" {write}") - """ + """ results = {} try: with self._cursor() as cur: @@ -279,35 +279,28 @@ def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List SELECT task_id, channel, type, blob FROM checkpoint_writes WHERE task_id = ANY(%s) - ORDER BY idx ASC + ORDER BY task_id, idx ASC """, (task_ids,), binary=True, ) for row in cur: - task_id = row['task_id'] + task_id = row["task_id"] if task_id not in results: results[task_id] = [] # Appending all the writes to the mapping task_id - results[task_id].append(( - row['task_id'], - row['channel'], - row['type'], - row['blob'] - )) + results[task_id].append( + (row["task_id"], row["channel"], row["type"], row["blob"]) + ) except DatabaseError as e: - # Log the error or handle it as appropriate for your application - # Optionally re-raise the error if you want it to propagate - # raise raise RuntimeError( f"Exception occurred while fetching writes from the database: {e}" - ) + ) from e - for task_id, writes in results.items(): - # check writes are loaded correctly - results[task_id] = self._load_writes(writes) - return results + return { + task_id: self._load_writes(writes) for task_id, writes in results.items() + } def put( self, diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index dfd5d12d0..a4286ed85 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -422,29 +422,25 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: """ return current + 1 if current is not None else 1 - def aget_writes_by_task_ids( - self, task_ids: List[str] - ) -> Optional[Dict[str, List[Any]]]: + def aget_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]: """Get a checkpoint writes from the database based on a cache key. Args: task_ids (str): The task id is serving as the cache key to retrieve writes. Returns: - List[Any]: A list of retrieved checkpoint writes. Empty list if none found. + Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes. """ raise NotImplementedError - def get_writes_by_task_ids( - self, task_ids: List[str] - ) -> Optional[Dict[str, List[Any]]]: + def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]: """Get a checkpoint writes from the database based on a cache key. Args: task_ids (str): The task id is serving as the cache key to retrieve writes. Returns: - List[Any]: A list of retrieved checkpoint writes. Empty list if none found. + Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes. """ pass diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 39d558e1b..25e162c19 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -356,6 +356,7 @@ def prepare_single_task( checkpoint_ns = ( f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node ) + proc = processes[packet.node] if proc.cache_policy: cache_key = proc.cache_policy.cache_key @@ -377,7 +378,6 @@ def prepare_single_task( if task_id_checksum is not None: assert task_id == task_id_checksum if for_execution: - proc = processes[packet.node] if node := proc.node: managed.replace_runtime_placeholders(step, packet.arg) writes = deque() diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index d8a06877b..3f498720a 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -344,15 +344,16 @@ def tick( task_cached_writes = self.checkpointer.get_writes_by_task_ids( [task.id for task in tasks_with_cache_policy] ) - for task_id, cached_writes in task_cached_writes.items(): - # if there are cached writes and task doesn't have any writes yet, apply them - if cached_writes and not self.tasks[task_id].writes: - task = self.tasks[task_id] - # Extract only the channel and value from cached_writes - task.writes.extend( - [(channel, value) for _, channel, value in cached_writes] - ) - self._output_writes(task_id, task.writes, cached=True) + if task_cached_writes: + for task_id, cached_writes in task_cached_writes.items(): + # if there are cached writes and task doesn't have any writes yet, apply them + if cached_writes and not self.tasks[task_id].writes: + task = self.tasks[task_id] + # Extract only the channel and value from cached_writes + task.writes.extend( + [(channel, value) for _, channel, value in cached_writes] + ) + self._output_writes(task_id, task.writes, cached=True) # if there are pending writes from a previous loop, apply them if self.skip_done_tasks and self.checkpoint_pending_writes: diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index 4578d3927..ef2933020 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Type, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union from langchain_core.runnables import Runnable, RunnableConfig diff --git a/libs/langgraph/tests/test_cache.py b/libs/langgraph/tests/test_cache.py index 4b7340e03..af6ad3a6c 100644 --- a/libs/langgraph/tests/test_cache.py +++ b/libs/langgraph/tests/test_cache.py @@ -2,7 +2,6 @@ import json import operator import time - from typing import ( Annotated, Any, @@ -12,13 +11,13 @@ Union, ) -from langgraph.checkpoint.base import BaseCheckpointSaver import pytest from langchain_core.runnables import ( RunnableConfig, ) - from syrupy import SnapshotAssertion + +from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.graph.state import StateGraph from langgraph.pregel.types import CachePolicy @@ -137,30 +136,11 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, - {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra - ] - assert call_count == 1 - config = RunnableConfig(configurable={"thread_id": "2"}) - - stream_results = [ - c - for c in app_with_checkpointer.stream( - {"query": "what is weather in sf"}, config - ) - ] - - assert stream_results == [ - {"rewrite_query": {"query": "query: what is weather in sf"}}, - {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, - {"retriever_two": {"docs": ["doc3", "doc4"]}}, - {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] - - # Should not increase count because of cache + # first time calling retriever_one() assert call_count == 1 - - config = RunnableConfig(configurable={"thread_id": "3"}) + config = RunnableConfig(configurable={"thread_id": "2"}) stream_results = [ c @@ -169,6 +149,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: ) ] + # retriever_one should be cached assert stream_results == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, @@ -177,7 +158,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] - # Should not increase count because of cache + # retriever_one call_count should not increase count because of cache assert call_count == 1 # Cache is not used when checkpointer is not provided @@ -192,6 +173,8 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"query": "what is weather in sf"}, config ) ] + + # retriever_one should not be cached assert interrupt_results == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, @@ -199,6 +182,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra ] + # retriever_one should not be cached, and call_count should increase assert call_count == 2 # Test a new workflow with the same cache key @@ -232,13 +216,15 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"query": "what is weather in sf"}, config ) ] + # retriever_one should be cached as long as the cache key is the same assert interrupt_results == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}}, - {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] + # retriever_one call_count should not increase count because of cache assert call_count == 2 # Test a new workflow with a different cache key @@ -272,11 +258,13 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]: {"query": "what is weather in sf"}, config ) ] + # retriever_one should not be cached because the user_id is different assert interrupt_results == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, - {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra + {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] + # retriever_one call_count should increase count because of cache miss assert call_count == 3