From 23e85e8f88f8ae70cf9b1f315ebc57e719ef88f7 Mon Sep 17 00:00:00 2001
From: yufeng-zhou
Date: Fri, 6 Sep 2024 19:02:11 -0700
Subject: [PATCH] fix test cases & linting

---
 .../langgraph/checkpoint/base/__init__.py |  2 +-
 libs/langgraph/langgraph/pregel/loop.py   | 19 ++++-----
 libs/langgraph/langgraph/pregel/types.py  |  2 +-
 libs/langgraph/tests/test_cache.py        | 40 +++++++------------
 4 files changed, 27 insertions(+), 36 deletions(-)

diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py
index a522f8c7d7..a4286ed856 100644
--- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py
+++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py
@@ -422,7 +422,7 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
         """
         return current + 1 if current is not None else 1
 
-    def aget_writes_by_task_ids(self, task_id: List[str]) -> Dict[str, List[Any]]:
+    def aget_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
         """Get a checkpoint writes from the database based on a cache key.
 
         Args:
diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py
index d8a06877b2..3f498720a2 100644
--- a/libs/langgraph/langgraph/pregel/loop.py
+++ b/libs/langgraph/langgraph/pregel/loop.py
@@ -344,15 +344,16 @@ def tick(
             task_cached_writes = self.checkpointer.get_writes_by_task_ids(
                 [task.id for task in tasks_with_cache_policy]
             )
-            for task_id, cached_writes in task_cached_writes.items():
-                # if there are cached writes and task doesn't have any writes yet, apply them
-                if cached_writes and not self.tasks[task_id].writes:
-                    task = self.tasks[task_id]
-                    # Extract only the channel and value from cached_writes
-                    task.writes.extend(
-                        [(channel, value) for _, channel, value in cached_writes]
-                    )
-                    self._output_writes(task_id, task.writes, cached=True)
+            if task_cached_writes:
+                for task_id, cached_writes in task_cached_writes.items():
+                    # if there are cached writes and the task doesn't have any writes yet, apply them
+                    if cached_writes and not self.tasks[task_id].writes:
+                        task = self.tasks[task_id]
+                        # Extract only the channel and value from cached_writes
+                        task.writes.extend(
+                            [(channel, value) for _, channel, value in cached_writes]
+                        )
+                        self._output_writes(task_id, task.writes, cached=True)
 
         # if there are pending writes from a previous loop, apply them
         if self.skip_done_tasks and self.checkpoint_pending_writes:
diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py
index 4578d39272..ef29330206 100644
--- a/libs/langgraph/langgraph/pregel/types.py
+++ b/libs/langgraph/langgraph/pregel/types.py
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Type, Union
+from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union
 
 from langchain_core.runnables import Runnable, RunnableConfig
 
diff --git a/libs/langgraph/tests/test_cache.py b/libs/langgraph/tests/test_cache.py
index 4b7340e03c..00edc61619 100644
--- a/libs/langgraph/tests/test_cache.py
+++ b/libs/langgraph/tests/test_cache.py
@@ -2,7 +2,6 @@
 import json
 import operator
 import time
-
 from typing import (
     Annotated,
     Any,
@@ -12,13 +11,13 @@
     Union,
 )
 
-from langgraph.checkpoint.base import BaseCheckpointSaver
 import pytest
 from langchain_core.runnables import (
     RunnableConfig,
 )
-
 from syrupy import SnapshotAssertion
+
+from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import StateGraph
 from langgraph.pregel.types import CachePolicy
 
@@ -137,7 +136,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
         {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
         {"retriever_two": {"docs": ["doc3", "doc4"]}},
         {"retriever_one": {"docs": ["doc1", "doc2"]}},
-        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},  # This item is extra
+        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
     ]
     assert call_count == 1
     config = RunnableConfig(configurable={"thread_id": "2"})
@@ -149,6 +148,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
         )
     ]
 
+    # retriever_one should be cached
     assert stream_results == [
         {"rewrite_query": {"query": "query: what is weather in sf"}},
         {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
@@ -157,27 +157,10 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
         {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
     ]
 
-    # Should not increase count because of cache
+    # first time calling retriever_one()
     assert call_count == 1
 
-    config = RunnableConfig(configurable={"thread_id": "3"})
-
-    stream_results = [
-        c
-        for c in app_with_checkpointer.stream(
-            {"query": "what is weather in sf"}, config
-        )
-    ]
-
-    assert stream_results == [
-        {"rewrite_query": {"query": "query: what is weather in sf"}},
-        {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
-        {"retriever_two": {"docs": ["doc3", "doc4"]}},
-        {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}},
-        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
-    ]
-
-    # Should not increase count because of cache
+    # retriever_one call_count should not increase because of the cache
     assert call_count == 1
 
     # Cache is not used when checkpointer is not provided
@@ -192,6 +175,8 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
             {"query": "what is weather in sf"}, config
         )
     ]
+
+    # retriever_one should not be cached
     assert interrupt_results == [
         {"rewrite_query": {"query": "query: what is weather in sf"}},
         {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
@@ -199,6 +184,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
         {"retriever_two": {"docs": ["doc3", "doc4"]}},
         {"retriever_one": {"docs": ["doc1", "doc2"]}},
         {"qa": {"answer": "doc1,doc2,doc3,doc4"}},  # This item is extra
     ]
+    # retriever_one should not be cached, and call_count should increase
     assert call_count == 2
     # Test a new workflow with the same cache key
@@ -232,13 +218,15 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
            {"query": "what is weather in sf"}, config
         )
     ]
+    # retriever_one should be cached as long as the cache key is the same
     assert interrupt_results == [
         {"rewrite_query": {"query": "query: what is weather in sf"}},
         {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
         {"retriever_two": {"docs": ["doc3", "doc4"]}},
         {"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}},
-        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},  # This item is extra
+        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
     ]
+    # retriever_one call_count should not increase because of the cache
     assert call_count == 2
 
     # Test a new workflow with a different cache key
@@ -272,11 +260,13 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
            {"query": "what is weather in sf"}, config
         )
     ]
+    # retriever_one should not be cached because the user_id is different
    assert interrupt_results == [
         {"rewrite_query": {"query": "query: what is weather in sf"}},
         {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
         {"retriever_two": {"docs": ["doc3", "doc4"]}},
         {"retriever_one": {"docs": ["doc1", "doc2"]}},
-        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},  # This item is extra
+        {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
     ]
+    # retriever_one call_count should increase because of the cache miss
     assert call_count == 3