Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding postgres cache support for node #1648

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Iterator, List, Optional, Union

from langchain_core.runnables import RunnableConfig
from psycopg import Connection, Cursor, Pipeline
from psycopg import Connection, Cursor, DatabaseError, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import dict_row
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -251,6 +251,54 @@
self._load_writes(value["pending_writes"]),
)

def get_writes_by_cache_key(self, cache_key: str) -> List[Any]:
    """Get checkpoint writes from the database based on a cache key.

    This method retrieves checkpoint writes from the Postgres database based on
    the provided cache key. Writes are stored keyed by task id, so the cache
    key is matched against the ``task_id`` column.

    Args:
        cache_key (str): The cache key to use for retrieving the checkpoints.

    Returns:
        List[Any]: A list of retrieved checkpoint writes. Empty list if none found.

    Raises:
        RuntimeError: If a database error occurs while fetching the writes;
            the original ``DatabaseError`` is chained as the cause.

    Examples:
        >>> cache_key = "some_unique_cache_key"
        >>> checkpoint_writes = memory.get_writes_by_cache_key(cache_key)
        >>> for write in checkpoint_writes:
        ...     print(write)
    """
    results: List[Any] = []
    try:
        with self._cursor() as cur:
            cur.execute(
                """
                SELECT task_id, channel, type, blob
                FROM checkpoint_writes
                WHERE task_id = %s
                ORDER BY idx ASC
                """,
                (cache_key,),
                binary=True,
            )
            for row in cur:
                results.append(
                    (row["task_id"], row["channel"], row["type"], row["blob"])
                )
    except DatabaseError as e:
        # Chain the original exception (ruff B904) so callers can distinguish
        # the database failure from errors raised while handling it.
        raise RuntimeError(
            f"Exception occurred while fetching writes from the database: {e}"
        ) from e
    return self._load_writes(results)

def put(
self,
config: RunnableConfig,
Expand Down
20 changes: 15 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hashlib import md5
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

from langchain_core.runnables import RunnableConfig
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -177,14 +177,24 @@ def _dump_blobs(
]

def _load_writes(
self, writes: list[tuple[bytes, bytes, bytes, bytes]]
self,
writes: list[
tuple[
Union[str, bytes],
Union[str, bytes],
Union[str, bytes],
Union[str, bytes],
]
],
) -> list[tuple[str, str, Any]]:
return (
[
(
tid.decode(),
channel.decode(),
self.serde.loads_typed((t.decode(), v)),
tid.decode() if isinstance(tid, bytes) else tid,
channel.decode() if isinstance(channel, bytes) else channel,
self.serde.loads_typed(
(t.decode() if isinstance(t, bytes) else t, v)
),
)
for tid, channel, t, v in writes
]
Expand Down
24 changes: 24 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,30 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
"""
return current + 1 if current is not None else 1

async def aget_writes_by_cache_key(
    self, cache_key: str
) -> Optional[List[Any]]:
    """Asynchronously get checkpoint writes from the database for a cache key.

    The ``a`` prefix marks this as the async counterpart of
    ``get_writes_by_cache_key``, so it is declared ``async def``. The base
    implementation has no cache storage and always raises.

    Args:
        cache_key (str): The cache key to use for retrieving the checkpoint
            writes.

    Returns:
        Optional[List[Any]]: A list of retrieved checkpoint writes. Empty list
        if none found.

    Raises:
        NotImplementedError: Always, unless overridden by an async-capable
            checkpointer implementation.
    """
    raise NotImplementedError

def get_writes_by_cache_key(self, cache_key: str) -> Optional[List[Any]]:
    """Get checkpoint writes from the database based on a cache key.

    The base implementation has no cache storage, so it returns an empty
    list — the documented "none found" result. (A bare ``pass`` here would
    return ``None``, contradicting this docstring; both values are falsy, so
    callers that test ``if writes:`` behave the same either way.)
    Checkpointers with cache support override this method.

    Args:
        cache_key (str): The cache key to use for retrieving the checkpoint
            writes.

    Returns:
        Optional[List[Any]]: A list of retrieved checkpoint writes. Empty list
        if none found.
    """
    return []


class EmptyChannelError(Exception):
"""Raised when attempting to get the value of a channel that hasn't been updated
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
is_writable_managed_value,
)
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.types import All, RetryPolicy
from langgraph.pregel.types import All, CachePolicy, RetryPolicy
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.fields import get_field_default
Expand All @@ -68,6 +68,7 @@ class StateNodeSpec(NamedTuple):
metadata: dict[str, Any]
input: Type[Any]
retry_policy: Optional[RetryPolicy]
cache_policy: Optional[CachePolicy]


class StateGraph(Graph):
Expand Down Expand Up @@ -202,6 +203,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> None:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand Down Expand Up @@ -249,6 +251,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> None:
"""Adds a new node to the state graph.

Expand Down Expand Up @@ -343,6 +346,7 @@ def add_node(
metadata,
input=input or self.schema,
retry_policy=retry,
cache_policy=cache,
)

def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None:
Expand Down Expand Up @@ -571,6 +575,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
],
metadata=node.metadata,
retry_policy=node.retry_policy,
cache_policy=node.cache_policy,
bound=node.runnable,
)

Expand Down
56 changes: 37 additions & 19 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ def prepare_single_task(
checkpoint_id = UUID(checkpoint["id"]).bytes
configurable = config.get("configurable", {})
parent_ns = configurable.get("checkpoint_ns", "")

if task_path[0] == TASKS:
idx = int(task_path[1])
packet = checkpoint["pending_sends"][idx]
Expand All @@ -357,14 +356,24 @@ def prepare_single_task(
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
TASKS,
str(idx),
)

if proc.cache_policy:
cache_key = proc.cache_policy.cache_key
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
TASKS,
str(idx),
)

if task_id_checksum is not None:
assert task_id == task_id_checksum
if for_execution:
Expand Down Expand Up @@ -423,7 +432,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down Expand Up @@ -465,14 +474,23 @@ def prepare_single_task(
"langgraph_path": task_path,
}
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
SUBSCRIPTIONS,
*triggers,
)

if proc.cache_policy:
cache_key = proc.cache_policy.cache_key
task_id = _uuid5_str(
b"",
name,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
SUBSCRIPTIONS,
*triggers,
)
if task_id_checksum is not None:
assert task_id == task_id_checksum

Expand Down Expand Up @@ -531,7 +549,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down
11 changes: 11 additions & 0 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,17 @@ def tick(
self.status = "done"
return False

if self.checkpointer:
for task in self.tasks.values():
# if there are cached writes, apply them
cached_writes = self.checkpointer.get_writes_by_cache_key(task.id)
if cached_writes and not task.writes:
zhouyuf6741 marked this conversation as resolved.
Show resolved Hide resolved
# Extract only the last two items from each write tuple
task.writes.extend(
[(channel, value) for _, channel, value in cached_writes]
)
self._output_writes(task.id, task.writes, cached=True)

# if there are pending writes from a previous loop, apply them
if self.skip_done_tasks and self.checkpoint_pending_writes:
for tid, k, v in self.checkpoint_pending_writes:
Expand Down
5 changes: 5 additions & 0 deletions libs/langgraph/langgraph/pregel/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from langgraph.constants import CONFIG_KEY_READ
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.types import CachePolicy
from langgraph.pregel.write import ChannelWrite
from langgraph.utils.config import merge_configs
from langgraph.utils.runnable import RunnableCallable, RunnableSeq
Expand Down Expand Up @@ -122,6 +123,8 @@ class PregelNode(Runnable):

config: RunnableConfig

cache_policy: Optional[CachePolicy]

def __init__(
self,
*,
Expand All @@ -134,6 +137,7 @@ def __init__(
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
config: Optional[RunnableConfig] = None,
cache_policy: Optional[CachePolicy] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
Expand All @@ -144,6 +148,7 @@ def __init__(
self.config = merge_configs(
config, {"tags": tags or [], "metadata": metadata or {}}
)
self.cache_policy = cache_policy

def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
Expand Down
14 changes: 12 additions & 2 deletions libs/langgraph/langgraph/pregel/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union
from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Type, Union

from langchain_core.runnables import Runnable, RunnableConfig

Expand Down Expand Up @@ -60,7 +60,17 @@
class CachePolicy(NamedTuple):
    """Configuration for caching nodes.

    A node with a cache policy gets a stable task id derived from its cache
    key, allowing previously persisted writes to be reused instead of
    re-executing the node.
    """

    # Annotation is quoted so RunnableConfig is resolved lazily; the default
    # of None keeps CachePolicy() constructible (no caching configured).
    cache_key: "Optional[Callable[[Any, Optional[RunnableConfig]], str]]" = None
    """
    A function that takes in the input and config, and returns a string key
    under which the output should be cached. None means no cache key function
    is configured.
    """
    # TODO: implement cache_ttl
    # cache_ttl: Optional[float] = None
    # """
    # Time-to-live for the cached value, in seconds. If not provided, the
    # value will be cached indefinitely. We'd probably want to store this in
    # a bucket way instead of a TTL timeline.
    # """


class PregelTask(NamedTuple):
Expand Down
27 changes: 27 additions & 0 deletions libs/langgraph/tests/__snapshots__/test_cache.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# serializer version: 1
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch_with_cache[postgres].1
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
Loading
Loading