Skip to content

Commit

Permalink
Merge pull request #1418 from langchain-ai/nc/21aug/fix-pending-interrupt-run
Browse files Browse the repository at this point in the history

Fix pending run when interrupt exception is used
  • Loading branch information
nfcampos authored Aug 21, 2024
2 parents 06c2481 + 47ed3d9 commit 46171dd
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 18 deletions.
16 changes: 6 additions & 10 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,7 @@ def stream(
checkpointer=checkpointer,
nodes=self.nodes,
specs=self.channels,
output_keys=output_keys,
) as loop:
# Similarly to Bulk Synchronous Parallel / Pregel model
# computation proceeds in steps, while there are channel updates
Expand All @@ -969,7 +970,6 @@ def stream(
# with channel updates applied only at the transition between steps
while loop.tick(
input_keys=self.input_channels,
output_keys=output_keys,
stream_keys=self.stream_channels_asis,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
Expand Down Expand Up @@ -1088,8 +1088,8 @@ def stream(
"without hitting a stop condition. You can increase the "
"limit by setting the `recursion_limit` config key."
)
# set final channel values as run output
run_manager.on_chain_end(read_channels(loop.channels, output_keys))
# set final channel values as run output
run_manager.on_chain_end(loop.output)
except BaseException as e:
run_manager.on_chain_error(e)
raise
Expand Down Expand Up @@ -1218,6 +1218,7 @@ async def astream(
checkpointer=checkpointer,
nodes=self.nodes,
specs=self.channels,
output_keys=output_keys,
) as loop:
aioloop = asyncio.get_event_loop()
# Similarly to Bulk Synchronous Parallel / Pregel model
Expand All @@ -1227,7 +1228,6 @@ async def astream(
# with channel updates applied only at the transition between steps
while loop.tick(
input_keys=self.input_channels,
output_keys=output_keys,
stream_keys=self.stream_channels_asis,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
Expand Down Expand Up @@ -1349,13 +1349,9 @@ async def astream(
"without hitting a stop condition. You can increase the "
"limit by setting the `recursion_limit` config key."
)

# set final channel values as run output
await run_manager.on_chain_end(
read_channels(loop.channels, output_keys)
)
# set final channel values as run output
await run_manager.on_chain_end(loop.output)
except BaseException as e:
# TODO use on_chain_end if exc is GraphInterrupt
await asyncio.shield(run_manager.on_chain_error(e))
raise

Expand Down
30 changes: 24 additions & 6 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@
BackgroundExecutor,
Submit,
)
from langgraph.pregel.io import map_input, map_output_updates, map_output_values, single
from langgraph.pregel.io import (
map_input,
map_output_updates,
map_output_values,
read_channels,
single,
)
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.read import PregelNode
from langgraph.pregel.types import PregelExecutableTask
Expand All @@ -83,6 +89,7 @@ class PregelLoop:
checkpointer: Optional[BaseCheckpointSaver]
nodes: Mapping[str, PregelNode]
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]]
output_keys: Union[str, Sequence[str]]
is_nested: bool

checkpointer_get_next_version: Callable[[Optional[V]], V]
Expand Down Expand Up @@ -117,6 +124,7 @@ class PregelLoop:
]
tasks: Sequence[PregelExecutableTask]
stream: deque[Tuple[str, Any]]
output: Union[None, dict[str, Any], Any] = None

# public

Expand All @@ -129,6 +137,7 @@ def __init__(
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]],
) -> None:
self.stream = deque()
self.input = input
Expand All @@ -137,6 +146,7 @@ def __init__(
self.checkpointer = checkpointer
self.nodes = nodes
self.specs = specs
self.output_keys = output_keys
self.is_nested = CONFIG_KEY_READ in self.config.get("configurable", {})

def mark_tasks_scheduled(self, tasks: Sequence[PregelExecutableTask]) -> None:
Expand Down Expand Up @@ -167,7 +177,6 @@ def tick(
self,
*,
input_keys: Union[str, Sequence[str]],
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
interrupt_after: Sequence[str] = EMPTY_SEQ,
interrupt_before: Sequence[str] = EMPTY_SEQ,
Expand Down Expand Up @@ -196,15 +205,15 @@ def tick(
# produce values output
self.stream.extend(
("values", v)
for v in map_output_values(output_keys, writes, self.channels)
for v in map_output_values(self.output_keys, writes, self.channels)
)
# clear pending writes
self.checkpoint_pending_writes.clear()
# save checkpoint
self._put_checkpoint(
{
"source": "loop",
"writes": single(map_output_updates(output_keys, self.tasks)),
"writes": single(map_output_updates(self.output_keys, self.tasks)),
}
)
# after execution, check if we should interrupt
Expand Down Expand Up @@ -272,7 +281,7 @@ def tick(
if all(task.writes for task in self.tasks):
return self.tick(
input_keys=input_keys,
output_keys=output_keys,
stream_keys=stream_keys,
interrupt_after=interrupt_after,
interrupt_before=interrupt_before,
manager=manager,
Expand Down Expand Up @@ -406,7 +415,12 @@ def _suppress_interrupt(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
if isinstance(exc_value, GraphInterrupt) and not self.is_nested:
suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
if suppress or exc_type is None:
# save final output
self.output = read_channels(self.channels, self.output_keys)
if suppress:
# suppress interrupt
return True


Expand All @@ -420,6 +434,7 @@ def __init__(
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
) -> None:
super().__init__(
input,
Expand All @@ -428,6 +443,7 @@ def __init__(
store=store,
nodes=nodes,
specs=specs,
output_keys=output_keys,
)
self.stack = ExitStack()
if checkpointer:
Expand Down Expand Up @@ -505,6 +521,7 @@ def __init__(
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
) -> None:
super().__init__(
input,
Expand All @@ -513,6 +530,7 @@ def __init__(
store=store,
nodes=nodes,
specs=specs,
output_keys=output_keys,
)
self.store = AsyncBatchedStore(self.store) if self.store else None
self.stack = AsyncExitStack()
Expand Down
91 changes: 91 additions & 0 deletions libs/langgraph/tests/fake_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Any, Optional
from uuid import UUID

from langchain_core.messages.base import BaseMessage
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers import BaseTracer, Run


class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution.
    It replaces run ids with deterministic UUIDs for snapshotting."""

    def __init__(self) -> None:
        """Set up empty run storage, the id-translation table, and the
        deterministic UUID stream used for replacements."""
        super().__init__()
        self.runs: list[Run] = []
        self.uuids_map: dict[UUID, UUID] = {}
        self.uuids_generator = (
            UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
        )

    def _replace_uuid(self, uuid: UUID) -> UUID:
        """Map *uuid* to a stable deterministic UUID, minting one on first sight."""
        mapped = self.uuids_map.get(uuid)
        if mapped is None:
            mapped = next(self.uuids_generator)
            self.uuids_map[uuid] = mapped
        return mapped

    def _replace_message_id(self, maybe_message: Any) -> Any:
        """Recursively rewrite message ids (in place) with deterministic
        UUID strings, descending into generations, dicts, and lists."""
        if isinstance(maybe_message, BaseMessage):
            maybe_message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, ChatGeneration):
            maybe_message.message.id = str(next(self.uuids_generator))
        if isinstance(maybe_message, LLMResult):
            for gen_list in maybe_message.generations:
                for idx in range(len(gen_list)):
                    gen_list[idx] = self._replace_message_id(gen_list[idx])
        if isinstance(maybe_message, dict):
            for key in maybe_message:
                maybe_message[key] = self._replace_message_id(maybe_message[key])
        if isinstance(maybe_message, list):
            for idx, item in enumerate(maybe_message):
                maybe_message[idx] = self._replace_message_id(item)

        return maybe_message

    def _copy_run(self, run: Run) -> Run:
        """Return a copy of *run* (and, recursively, its children) with every
        id — run id, parent id, trace id, dotted order, message ids —
        rewritten deterministically."""
        new_dotted_order = None
        if run.dotted_order:
            # Each dotted-order segment is "<timestamp>Z<run-uuid>"; rewrite
            # only the uuid part and keep the timestamp untouched.
            segments = []
            for segment in run.dotted_order.split("."):
                timestamp, run_id = segment.split("Z")
                segments.append(f"{timestamp}Z{self._replace_uuid(UUID(run_id))}")
            new_dotted_order = ".".join(segments)
        parent_id = None
        if run.parent_run_id:
            # Parents are copied before children, so the mapping already exists.
            parent_id = self.uuids_map[run.parent_run_id]
        return run.copy(
            update={
                "id": self._replace_uuid(run.id),
                "parent_run_id": parent_id,
                "child_runs": [self._copy_run(child) for child in run.child_runs],
                "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
                "dotted_order": new_dotted_order,
                "inputs": self._replace_message_id(run.inputs),
                "outputs": self._replace_message_id(run.outputs),
            }
        )

    def _persist_run(self, run: Run) -> None:
        """Persist a run."""

        self.runs.append(self._copy_run(run))

    def flattened_runs(self) -> list[Run]:
        """Return every recorded run — roots and all descendants — as one list."""
        pending = list(self.runs)
        flat: list[Run] = []
        while pending:
            current = pending.pop()
            flat.append(current)
            if current.child_runs:
                pending.extend(current.child_runs)
        return flat

    @property
    def run_ids(self) -> list[Optional[UUID]]:
        """The original (pre-replacement) ids of all runs, in flattened order."""
        inverse = {new: old for old, new in self.uuids_map.items()}
        return [inverse.get(r.id) for r in self.flattened_runs()]
12 changes: 11 additions & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from langgraph.pregel.types import PregelTask
from langgraph.store.memory import MemoryStore
from tests.any_str import AnyStr, ExceptionLike
from tests.fake_tracer import FakeTracer
from tests.memory_assert import (
MemorySaverAssertCheckpointMetadata,
MemorySaverAssertImmutable,
Expand Down Expand Up @@ -6159,11 +6160,20 @@ def tool_two_node(s: State) -> State:
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert tool_two.invoke({"my_key": "value", "market": "DE"}) == {
tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}

assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
Expand Down
12 changes: 11 additions & 1 deletion libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from langgraph.pregel.types import PregelTask
from langgraph.store.memory import MemoryStore
from tests.any_str import AnyStr, ExceptionLike
from tests.fake_tracer import FakeTracer
from tests.memory_assert import (
MemorySaverAssertCheckpointMetadata,
MemorySaverAssertImmutable,
Expand Down Expand Up @@ -225,11 +226,20 @@ async def tool_two_node(s: State) -> State:
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}) == {
tracer = FakeTracer()
assert await tool_two.ainvoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}

assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
Expand Down

0 comments on commit 46171dd

Please sign in to comment.