Commit

Merge pull request #2092 from langchain-ai/an/11oct/remote-graph-interrupt

Update `stream()` and `astream()` methods in `RemoteGraph` to process `updates` event types
nfcampos authored Oct 23, 2024
2 parents 6f236b5 + 037a95f commit 08a1ed3
Showing 4 changed files with 408 additions and 112 deletions.
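
A minimal usage sketch (not part of the commit) of the behavior described above: streaming from a RemoteGraph now always requests the 'updates' mode under the hood, so interrupts reported by the remote graph are raised locally as GraphInterrupt, and 'error' events are raised as RemoteException. The graph name, URL, thread ID, and input below are hypothetical placeholders.

from langgraph.errors import GraphInterrupt
from langgraph.pregel.remote import RemoteException, RemoteGraph

remote = RemoteGraph("my-graph", url="http://localhost:2024")  # hypothetical deployment
config = {"configurable": {"thread_id": "thread-1"}}  # hypothetical thread

try:
    # A single stream mode yields the chunk data directly (see stream() below).
    for chunk in remote.stream({"messages": []}, config=config, stream_mode="values"):
        print(chunk)
except GraphInterrupt as interrupt:
    # Raised when an 'updates' event from the remote graph contains an interrupt.
    print("interrupted:", interrupt)
except RemoteException as exc:
    # Raised when the remote graph emits an 'error' event.
    print("remote error:", exc)
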
144 changes: 100 additions & 44 deletions libs/langgraph/langgraph/pregel/remote.py
@@ -19,33 +19,41 @@
from langchain_core.runnables.graph import (
Node as DrawableNode,
)
from langchain_core.runnables.schema import StandardStreamEvent, StreamEvent
from langgraph_sdk.client import (
LangGraphClient,
SyncLangGraphClient,
get_client,
get_sync_client,
)
from langgraph_sdk.schema import Checkpoint, ThreadState
from langgraph_sdk.schema import StreamMode as StreamModeSDK
from typing_extensions import Self

from langgraph.checkpoint.base import CheckpointMetadata
from langgraph.constants import INTERRUPT
from langgraph.errors import GraphInterrupt
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode
from langgraph.types import Interrupt
from langgraph.utils.config import merge_configs


class RemoteException(Exception):
"""Exception raised when an error occurs in the remote graph."""

pass


class RemoteGraph(PregelProtocol, Runnable):
def __init__(
self,
graph_id: str,
config: Optional[RunnableConfig] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
client: Optional[LangGraphClient] = None,
sync_client: Optional[SyncLangGraphClient] = None,
config: Optional[RunnableConfig] = None,
):
"""Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.
@@ -348,6 +356,37 @@ async def aupdate_state(
)
return self._get_config(response["checkpoint"])

def _get_stream_modes(
self,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]],
default: StreamMode = "updates",
) -> tuple[list[StreamModeSDK], bool, bool]:
"""Return a tuple of the final list of stream modes sent to the
remote graph and a boolean flag indicating if stream mode 'updates'
was present in the original list of stream modes.
'updates' mode is added to the list of stream modes so that interrupts
can be detected in the remote graph.
"""
updated_stream_modes: list[StreamMode] = []
req_updates = False
req_single = True
# coerce to list, or add default stream mode
if stream_mode:
if isinstance(stream_mode, str):
updated_stream_modes.append(stream_mode)
else:
req_single = False
updated_stream_modes.extend(stream_mode)
else:
updated_stream_modes.append(default)
# add 'updates' mode if not present
if "updates" in updated_stream_modes:
req_updates = True
else:
updated_stream_modes.append("updates")
return (updated_stream_modes, req_updates, req_single)

def stream(
self,
input: Union[dict[str, Any], Any],
@@ -360,18 +399,40 @@ def stream(
) -> Iterator[Union[dict[str, Any], Any]]:
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

for chunk in self.sync_client.runs.stream(
thread_id=sanitized_config["configurable"]["thread_id"],
thread_id=cast(str, sanitized_config["configurable"]["thread_id"]),
assistant_id=self.graph_id,
input=input,
config=sanitized_config,
stream_mode=stream_mode, # type: ignore
interrupt_before=interrupt_before, # type: ignore
interrupt_after=interrupt_after, # type: ignore
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs,
if_not_exists="create",
):
yield chunk
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
if not req_updates:
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if subgraphs:
if "|" in chunk.event:
mode, ns_ = chunk.event.split("|", 1)
ns = tuple(ns_.split("|"))
else:
mode, ns = chunk.event, ()
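# Illustrative parsing (not part of the commit): with stream_subgraphs=True an
# event like "updates|agent|tools" gives mode == "updates" and
# ns == ("agent", "tools"), while a top-level event like "values" gives
# mode == "values" and ns == ().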
if req_single:
yield ns, chunk.data
else:
yield ns, mode, chunk.data
elif req_single:
yield chunk.data
else:
yield chunk
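# Shapes yielded by this loop (illustrative summary, not part of the commit):
#   single stream mode, subgraphs=False  -> chunk.data
#   list of stream modes, subgraphs=False -> the raw SDK chunk (event + data)
#   single stream mode, subgraphs=True   -> (namespace, chunk.data)
#   list of stream modes, subgraphs=True  -> (namespace, mode, chunk.data)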

async def astream(
self,
@@ -385,47 +446,40 @@ async def astream(
) -> AsyncIterator[Union[dict[str, Any], Any]]:
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode)

async for chunk in self.client.runs.stream(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
input=input,
config=sanitized_config,
stream_mode=stream_mode if stream_mode else "values", # type: ignore
interrupt_before=interrupt_before, # type: ignore
interrupt_after=interrupt_after, # type: ignore
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs,
if_not_exists="create",
):
yield chunk

async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)

# manually add 'events' to stream modes list
stream_mode: list[str] = kwargs.get("stream_mode", [])
if "events" not in stream_mode:
stream_mode.append("events")

async for chunk in self.client.runs.stream(
thread_id=sanitized_config["configurable"]["thread_id"],
assistant_id=self.graph_id,
input=input,
config=sanitized_config,
stream_mode=stream_mode, # type: ignore
interrupt_before=kwargs.get("interrupt_before"),
interrupt_after=kwargs.get("interrupt_after"),
stream_subgraphs=kwargs.get("subgraphs", False),
):
yield StandardStreamEvent(
event=chunk.event,
data=chunk.data,
)
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
if not req_updates:
continue
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
if subgraphs:
if "|" in chunk.event:
mode, ns_ = chunk.event.split("|", 1)
ns = tuple(ns_.split("|"))
else:
mode, ns = chunk.event, ()
if req_single:
yield ns, chunk.data
else:
yield ns, mode, chunk.data
elif req_single:
yield chunk.data
else:
yield chunk

def invoke(
self,
@@ -443,8 +497,9 @@ def invoke(
assistant_id=self.graph_id,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before, # type: ignore
interrupt_after=interrupt_after, # type: ignore
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
if_not_exists="create",
)

async def ainvoke(
@@ -463,6 +518,7 @@ async def ainvoke(
assistant_id=self.graph_id,
input=input,
config=sanitized_config,
interrupt_before=interrupt_before, # type: ignore
interrupt_after=interrupt_after, # type: ignore
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
if_not_exists="create",
)
