Skip to content

Commit

Permalink
fix test cases & linting
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyuf6741 committed Sep 7, 2024
1 parent 89986d4 commit 23e85e8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 36 deletions.
2 changes: 1 addition & 1 deletion libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
"""
return current + 1 if current is not None else 1

def aget_writes_by_task_ids(self, task_id: List[str]) -> Dict[str, List[Any]]:
def aget_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
"""Get a checkpoint writes from the database based on a cache key.
Args:
Expand Down
19 changes: 10 additions & 9 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,16 @@ def tick(
task_cached_writes = self.checkpointer.get_writes_by_task_ids(
[task.id for task in tasks_with_cache_policy]
)
for task_id, cached_writes in task_cached_writes.items():
# if there are cached writes and task doesn't have any writes yet, apply them
if cached_writes and not self.tasks[task_id].writes:
task = self.tasks[task_id]
# Extract only the channel and value from cached_writes
task.writes.extend(
[(channel, value) for _, channel, value in cached_writes]
)
self._output_writes(task_id, task.writes, cached=True)
if task_cached_writes:
for task_id, cached_writes in task_cached_writes.items():
# if there are cached writes and task doesn't have any writes yet, apply them
if cached_writes and not self.tasks[task_id].writes:
task = self.tasks[task_id]
# Extract only the channel and value from cached_writes
task.writes.extend(
[(channel, value) for _, channel, value in cached_writes]
)
self._output_writes(task_id, task.writes, cached=True)

# if there are pending writes from a previous loop, apply them
if self.skip_done_tasks and self.checkpoint_pending_writes:
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/pregel/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Type, Union
from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union

from langchain_core.runnables import Runnable, RunnableConfig

Expand Down
40 changes: 15 additions & 25 deletions libs/langgraph/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import operator
import time

from typing import (
Annotated,
Any,
Expand All @@ -12,13 +11,13 @@
Union,
)

from langgraph.checkpoint.base import BaseCheckpointSaver
import pytest
from langchain_core.runnables import (
RunnableConfig,
)

from syrupy import SnapshotAssertion

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import StateGraph
from langgraph.pregel.types import CachePolicy

Expand Down Expand Up @@ -137,7 +136,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
assert call_count == 1
config = RunnableConfig(configurable={"thread_id": "2"})
Expand All @@ -149,6 +148,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
)
]

# retriever_one should be cached
assert stream_results == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
Expand All @@ -157,27 +157,10 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]

# Should not increase count because of cache
# first time calling retriever_one()
assert call_count == 1

config = RunnableConfig(configurable={"thread_id": "3"})

stream_results = [
c
for c in app_with_checkpointer.stream(
{"query": "what is weather in sf"}, config
)
]

assert stream_results == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]

# Should not increase count because of cache
# call_count for retriever_one should not increase because of the cache
assert call_count == 1

# Cache is not used when checkpointer is not provided
Expand All @@ -192,13 +175,16 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"query": "what is weather in sf"}, config
)
]

# retriever_one should not be cached
assert interrupt_results == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra
]
# retriever_one should not be cached, and call_count should increase
assert call_count == 2

# Test a new workflow with the same cache key
Expand Down Expand Up @@ -232,13 +218,15 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"query": "what is weather in sf"}, config
)
]
# retriever_one should be cached as long as the cache key is the same
assert interrupt_results == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}, "__metadata__": {"cached": True}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
# call_count for retriever_one should not increase because of the cache
assert call_count == 2

# Test a new workflow with a different cache key
Expand Down Expand Up @@ -272,11 +260,13 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"query": "what is weather in sf"}, config
)
]
# retriever_one should not be cached because the user_id is different
assert interrupt_results == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}}, # This item is extra
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
# call_count for retriever_one should increase because of the cache miss
assert call_count == 3

0 comments on commit 23e85e8

Please sign in to comment.