Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding postgres cache support for node #1648

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import threading
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union

from langchain_core.runnables import RunnableConfig
from psycopg import Connection, Cursor, Pipeline
from psycopg import Connection, Cursor, DatabaseError, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import dict_row
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -251,6 +251,57 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
self._load_writes(value["pending_writes"]),
)

def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
    """Fetch checkpoint writes for multiple task IDs in a single query.

    Args:
        task_ids (List[str]): Task IDs whose checkpoint writes should be loaded.

    Returns:
        Dict[str, List[Any]]: Mapping of task ID to its deserialized writes,
        ordered by ``idx``. Task IDs with no stored writes are absent from
        the result.

    Raises:
        RuntimeError: If the underlying database query fails.

    Examples:
        >>> task_ids = ["task1", "task2", "task3"]
        >>> checkpoint_writes = memory.get_writes_by_task_ids(task_ids)
        >>> for task_id, writes in checkpoint_writes.items():
        ...     print(f"Task ID: {task_id}")
        ...     for write in writes:
        ...         print(f"  {write}")
    """
    # Avoid a pointless database round-trip when there is nothing to look up.
    if not task_ids:
        return {}

    results: Dict[str, List[Any]] = {}
    try:
        with self._cursor() as cur:
            cur.execute(
                """
                SELECT task_id, channel, type, blob
                FROM checkpoint_writes
                WHERE task_id = ANY(%s)
                ORDER BY task_id, idx ASC
                """,
                (task_ids,),
                binary=True,
            )

            # Group rows by task_id; rows arrive ordered by (task_id, idx),
            # so each per-task list is already in write order.
            for row in cur:
                results.setdefault(row["task_id"], []).append(
                    (row["task_id"], row["channel"], row["type"], row["blob"])
                )
    except DatabaseError as e:
        raise RuntimeError(
            f"Exception occurred while fetching writes from the database: {e}"
        ) from e

    # Deserialize the raw (task_id, channel, type, blob) tuples per task.
    return {
        task_id: self._load_writes(writes) for task_id, writes in results.items()
    }

def put(
self,
config: RunnableConfig,
Expand Down
21 changes: 16 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hashlib import md5
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

from langchain_core.runnables import RunnableConfig
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -177,14 +177,25 @@ def _dump_blobs(
]

def _load_writes(
self, writes: list[tuple[bytes, bytes, bytes, bytes]]
self,
# We want to support both bytes and strings here
writes: list[
tuple[
Union[str, bytes],
Union[str, bytes],
Union[str, bytes],
Union[str, bytes],
]
],
) -> list[tuple[str, str, Any]]:
return (
[
(
tid.decode(),
channel.decode(),
self.serde.loads_typed((t.decode(), v)),
tid.decode() if isinstance(tid, bytes) else tid,
channel.decode() if isinstance(channel, bytes) else channel,
self.serde.loads_typed(
(t.decode() if isinstance(t, bytes) else t, v)
),
)
for tid, channel, t, v in writes
]
Expand Down
22 changes: 22 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,28 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
"""
return current + 1 if current is not None else 1

def aget_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
    """Asynchronously get checkpoint writes for the given task IDs.

    No base implementation is provided; async-capable checkpointers
    should override this method.

    Args:
        task_ids (List[str]): Task IDs whose checkpoint writes should be loaded.

    Returns:
        Dict[str, List[Any]]: Mapping of task ID to its list of checkpoint writes.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
    """Get checkpoint writes for the given task IDs.

    The base implementation performs no lookup and returns an empty mapping,
    so checkpointers without cached-write support keep working unchanged;
    subclasses with a backing store should override this method.

    Args:
        task_ids (List[str]): Task IDs whose checkpoint writes should be loaded.

    Returns:
        Dict[str, List[Any]]: Mapping of task ID to its list of checkpoint
        writes; empty here because the base class stores nothing.
    """
    # Return an empty dict (instead of the implicit None from `pass`) so the
    # declared return type holds and callers can iterate the result safely.
    return {}


class EmptyChannelError(Exception):
"""Raised when attempting to get the value of a channel that hasn't been updated
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
is_writable_managed_value,
)
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.types import All, RetryPolicy
from langgraph.pregel.types import All, CachePolicy, RetryPolicy
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.fields import get_field_default
Expand All @@ -68,6 +68,7 @@ class StateNodeSpec(NamedTuple):
metadata: dict[str, Any]
input: Type[Any]
retry_policy: Optional[RetryPolicy]
cache_policy: Optional[CachePolicy]


class StateGraph(Graph):
Expand Down Expand Up @@ -202,6 +203,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> None:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand Down Expand Up @@ -249,6 +251,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> None:
"""Adds a new node to the state graph.

Expand Down Expand Up @@ -343,6 +346,7 @@ def add_node(
metadata,
input=input or self.schema,
retry_policy=retry,
cache_policy=cache,
)

def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None:
Expand Down Expand Up @@ -571,6 +575,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
],
metadata=node.metadata,
retry_policy=node.retry_policy,
cache_policy=node.cache_policy,
bound=node.runnable,
)

Expand Down
58 changes: 38 additions & 20 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ def prepare_single_task(
checkpoint_id = UUID(checkpoint["id"]).bytes
configurable = config.get("configurable", {})
parent_ns = configurable.get("checkpoint_ns", "")

if task_path[0] == TASKS:
idx = int(task_path[1])
packet = checkpoint["pending_sends"][idx]
Expand All @@ -357,18 +356,28 @@ def prepare_single_task(
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
TASKS,
str(idx),
)
proc = processes[packet.node]

if proc.cache_policy:
cache_key = proc.cache_policy.cache_key
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
TASKS,
str(idx),
)

if task_id_checksum is not None:
assert task_id == task_id_checksum
if for_execution:
proc = processes[packet.node]
if node := proc.node:
managed.replace_runtime_placeholders(step, packet.arg)
writes = deque()
Expand Down Expand Up @@ -423,7 +432,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down Expand Up @@ -465,14 +474,23 @@ def prepare_single_task(
"langgraph_path": task_path,
}
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
SUBSCRIPTIONS,
*triggers,
)

if proc.cache_policy:
cache_key = proc.cache_policy.cache_key
task_id = _uuid5_str(
b"",
name,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
SUBSCRIPTIONS,
*triggers,
)
if task_id_checksum is not None:
assert task_id == task_id_checksum

Expand Down Expand Up @@ -531,7 +549,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down
21 changes: 21 additions & 0 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,27 @@ def tick(
self.status = "done"
return False

if self.checkpointer:
# only fetching tasks with cache policy enabled
tasks_with_cache_policy = [
task for task in self.tasks.values() if task.cache_policy
]

# if there are cached writes, apply them
task_cached_writes = self.checkpointer.get_writes_by_task_ids(
[task.id for task in tasks_with_cache_policy]
)
if task_cached_writes:
for task_id, cached_writes in task_cached_writes.items():
# if there are cached writes and task doesn't have any writes yet, apply them
if cached_writes and not self.tasks[task_id].writes:
task = self.tasks[task_id]
# Extract only the channel and value from cached_writes
task.writes.extend(
[(channel, value) for _, channel, value in cached_writes]
)
self._output_writes(task_id, task.writes, cached=True)

# if there are pending writes from a previous loop, apply them
if self.skip_done_tasks and self.checkpoint_pending_writes:
for tid, k, v in self.checkpoint_pending_writes:
Expand Down
5 changes: 5 additions & 0 deletions libs/langgraph/langgraph/pregel/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from langgraph.constants import CONFIG_KEY_READ
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.types import CachePolicy
from langgraph.pregel.write import ChannelWrite
from langgraph.utils.config import merge_configs
from langgraph.utils.runnable import RunnableCallable, RunnableSeq
Expand Down Expand Up @@ -122,6 +123,8 @@ class PregelNode(Runnable):

config: RunnableConfig

cache_policy: Optional[CachePolicy]

def __init__(
self,
*,
Expand All @@ -134,6 +137,7 @@ def __init__(
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
config: Optional[RunnableConfig] = None,
cache_policy: Optional[CachePolicy] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
Expand All @@ -144,6 +148,7 @@ def __init__(
self.config = merge_configs(
config, {"tags": tags or [], "metadata": metadata or {}}
)
self.cache_policy = cache_policy

def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
Expand Down
12 changes: 11 additions & 1 deletion libs/langgraph/langgraph/pregel/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,17 @@ class RetryPolicy(NamedTuple):
class CachePolicy(NamedTuple):
    """Configuration for caching nodes."""

    # Default of None keeps CachePolicy() constructible, matching the
    # defaults-on-Optional-fields convention of the other NamedTuples here.
    cache_key: Optional[Callable[[Any, Optional[RunnableConfig]], str]] = None
    """
    A function that takes in the input and config, and returns a string key
    under which the output should be cached.
    """
    # NOTE(review): algo.py currently feeds `cache_policy.cache_key` directly
    # into the task-id hash without calling it — confirm whether the key
    # should be the callable's result instead.
    # TODO: implement cache_ttl
    # cache_ttl: Optional[float] = None
    # """
    # Time-to-live for the cached value, in seconds. If not provided, the value will be cached indefinitely.
    # We'd probably want to store this in a bucket way instead of a TTL timeline.
    # """


class PregelTask(NamedTuple):
Expand Down
Loading
Loading