Skip to content

Commit

Permalink
fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyuf6741 committed Sep 7, 2024
1 parent 6c3fff9 commit 9b38736
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 24 deletions.
24 changes: 9 additions & 15 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,17 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
self._load_writes(value["pending_writes"]),
)

def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List[Any]]]:
"""Get checkpoint writes from the database based on a cache key.
def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
"""Get checkpoint writes from the database based on multiple task IDs.
This method retrieves checkpoint writes from the Postgres database based on the
provided cache key.
provided list of task IDs.
Args:
task_ids (str): The task id is serving as the cache key to retrieve writes.
task_ids (List[str]): A list of task IDs to retrieve checkpoint writes for.
Returns:
List[Any]: A list of retrieved checkpoint writes. Empty list if none found.
Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes.
Examples:
>>> task_ids = ["task1", "task2", "task3"]
Expand All @@ -270,7 +270,7 @@ def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List
... print(f"Task ID: {task_id}")
... for write in writes:
... print(f" {write}")
"""
"""
results = {}
try:
with self._cursor() as cur:
Expand All @@ -279,7 +279,7 @@ def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List
SELECT task_id, channel, type, blob
FROM checkpoint_writes
WHERE task_id = ANY(%s)
ORDER BY idx ASC
ORDER BY task_id, idx ASC
""",
(task_ids,),
binary=True,
Expand All @@ -297,17 +297,11 @@ def get_writes_by_task_ids(self, task_ids: List[str]) -> Optional[Dict[str, List
row['blob']
))
except DatabaseError as e:
# Log the error or handle it as appropriate for your application
# Optionally re-raise the error if you want it to propagate
# raise
raise RuntimeError(
f"Exception occurred while fetching writes from the database: {e}"
)
) from e

for task_id, writes in results.items():
# check writes are loaded correctly
results[task_id] = self._load_writes(writes)
return results
return {task_id: self._load_writes(writes) for task_id, writes in results.items()}

def put(
self,
Expand Down
12 changes: 4 additions & 8 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,29 +422,25 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
"""
return current + 1 if current is not None else 1

def aget_writes_by_task_ids(
self, task_ids: List[str]
) -> Optional[Dict[str, List[Any]]]:
def aget_writes_by_task_ids(self, task_id: List[str]) -> Dict[str, List[Any]]:
"""Get a checkpoint writes from the database based on a cache key.
Args:
task_ids (str): The task id is serving as the cache key to retrieve writes.
Returns:
List[Any]: A list of retrieved checkpoint writes. Empty list if none found.
Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes.
"""
raise NotImplementedError

def get_writes_by_task_ids(
self, task_ids: List[str]
) -> Optional[Dict[str, List[Any]]]:
def get_writes_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[Any]]:
"""Get a checkpoint writes from the database based on a cache key.
Args:
task_ids (str): The task id is serving as the cache key to retrieve writes.
Returns:
List[Any]: A list of retrieved checkpoint writes. Empty list if none found.
Dict[str, List[Any]]: A dictionary where keys are task IDs and values are lists of checkpoint writes.
"""
pass

Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def prepare_single_task(
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)

if proc.cache_policy:
if processes[packet.node].cache_policy:
cache_key = proc.cache_policy.cache_key
task_id = _uuid5_str(
b"",
Expand Down

0 comments on commit 9b38736

Please sign in to comment.