From 2f819a6a9be97636f8413d24880ea65c509b8f88 Mon Sep 17 00:00:00 2001 From: Andrew Nguonly Date: Fri, 11 Oct 2024 18:40:40 -0700 Subject: [PATCH 1/9] Update stream() and astream() to process 'updates' event types. --- libs/langgraph/langgraph/pregel/remote.py | 55 ++++++++++- ..._remote_pregel.py => test_remote_graph.py} | 91 +++++++++++++++---- 2 files changed, 127 insertions(+), 19 deletions(-) rename libs/langgraph/tests/{test_remote_pregel.py => test_remote_graph.py} (85%) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 2f1c0c41e..4b7e87421 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -26,10 +26,12 @@ get_client, get_sync_client, ) -from langgraph_sdk.schema import Checkpoint, ThreadState +from langgraph_sdk.schema import Checkpoint, StreamPart, ThreadState from typing_extensions import Self from langgraph.checkpoint.base import CheckpointMetadata +from langgraph.constants import INTERRUPT +from langgraph.errors import GraphInterrupt from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode from langgraph.types import Interrupt @@ -345,6 +347,35 @@ async def aupdate_state( ) return self._get_config(response["checkpoint"]) + def _get_stream_modes( + self, + stream_mode: Optional[Union[StreamMode, list[StreamMode]]], + ) -> tuple[list[StreamMode], bool]: + """Return a tuple of the final list of stream modes sent to the + remote graph and a boolean flag indicating if stream mode 'updates' + was present in the original list of stream modes. + + 'updates' mode is added to the list of stream modes so that interrupts + can be detected in the remote graph. + """ + updated_stream_modes = [] + updates_mode = False + + if stream_mode: + if isinstance(stream_mode, str): + updated_stream_modes.append(stream_mode) + else: + updated_stream_modes.extend(stream_mode) + + if "updates" in updated_stream_modes: + updates_mode = True + else: + updated_stream_modes.append("updates") + else: + updated_stream_modes.extend(["values", "updates"]) + + return (updated_stream_modes, updates_mode) + def stream( self, input: Union[dict[str, Any], Any], @@ -357,17 +388,26 @@ def stream( ) -> Iterator[Union[dict[str, Any], Any]]: merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) + updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) for chunk in self.sync_client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=stream_mode, # type: ignore + stream_mode=updated_stream_modes, # type: ignore interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): + if chunk.event == INTERRUPT: + raise GraphInterrupt() + + # Don't emit 'updates' events if the original list of stream modes + # didn't include it. + if chunk.event == "updates" and not include_updates: + continue + yield chunk async def astream( @@ -382,17 +422,26 @@ async def astream( ) -> AsyncIterator[Union[dict[str, Any], Any]]: merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) + updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) async for chunk in self.client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=stream_mode if stream_mode else "values", # type: ignore + stream_mode=updated_stream_modes, # type: ignore interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): + if chunk.event == INTERRUPT: + raise GraphInterrupt() + + # Don't emit 'updates' events if the original list of stream modes + # didn't include it. + if chunk.event == "updates" and not include_updates: + continue + yield chunk async def astream_events( diff --git a/libs/langgraph/tests/test_remote_pregel.py b/libs/langgraph/tests/test_remote_graph.py similarity index 85% rename from libs/langgraph/tests/test_remote_pregel.py rename to libs/langgraph/tests/test_remote_graph.py index f2ca53583..dd07d27bb 100644 --- a/libs/langgraph/tests/test_remote_pregel.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -7,7 +7,9 @@ from langchain_core.runnables.graph import ( Node as DrawableNode, ) +from langgraph_sdk.schema import StreamPart +from langgraph.errors import GraphInterrupt from langgraph.pregel.remote import RemoteGraph from langgraph.pregel.types import StateSnapshot @@ -473,17 +475,46 @@ def test_stream(): # set up test mock_sync_client = MagicMock() mock_sync_client.runs.stream.return_value = [ - {"chunk": "data1"}, - {"chunk": "data2"}, - {"chunk": "data3"}, + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="__interrupt__", data={}), ] # call method / assertions remote_pregel = RemoteGraph(sync_client=mock_sync_client, graph_id="test_graph_id") - config = {"configurable": {"thread_id": "thread_1"}} - result = list(remote_pregel.stream({"input": "data"}, config)) - assert result == [{"chunk": "data1"}, {"chunk": "data2"}, {"chunk": "data3"}] + # stream modes doesn't include 'updates' + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + ] + + # stream modes includes 'updates' + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + ] @pytest.mark.anyio @@ -492,20 +523,47 @@ async def test_astream(): mock_async_client = MagicMock() async_iter = MagicMock() async_iter.__aiter__.return_value = [ - {"chunk": "data1"}, - {"chunk": "data2"}, - {"chunk": "data3"}, + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="__interrupt__", data={}), ] mock_async_client.runs.stream.return_value = async_iter # call method / assertions remote_pregel = RemoteGraph(client=mock_async_client, graph_id="test_graph_id") - config = {"configurable": {"thread_id": "thread_1"}} - chunks = [] - async for chunk in remote_pregel.astream({"input": "data"}, config): - chunks.append(chunk) - assert chunks == [{"chunk": "data1"}, {"chunk": "data2"}, {"chunk": "data3"}] + # stream modes doesn't include 'updates' + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + ] + + # stream modes includes 'updates' + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + StreamPart(event="values", data={"chunk": "data1"}), + StreamPart(event="values", data={"chunk": "data2"}), + StreamPart(event="values", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + ] def test_invoke(): @@ -572,7 +630,7 @@ async def test_langgraph_cloud_integration(): "messages": [ { "role": "human", - "content": "Hello world!", + "content": "What's the weather in SF?", } ] } @@ -580,7 +638,8 @@ async def test_langgraph_cloud_integration(): # test invoke response = app.invoke( input, - config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}}, + config={"configurable": {"thread_id": "39a6104a-34e7-4f83-929c-d9eb163003c9"}}, + interrupt_before=["agent"], ) print("response:", response["messages"][-1].content) From a277b86fcb32b1d80fd5fb1f115a0d37854b61fc Mon Sep 17 00:00:00 2001 From: Andrew Nguonly Date: Fri, 11 Oct 2024 18:53:45 -0700 Subject: [PATCH 2/9] Fix unit test. --- libs/langgraph/langgraph/pregel/remote.py | 28 ++++++++++++----------- libs/langgraph/tests/test_remote_graph.py | 4 ++-- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 4b7e87421..b0df12bb7 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -400,13 +400,14 @@ def stream( interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): - if chunk.event == INTERRUPT: - raise GraphInterrupt() + if chunk.event == "updates": + if INTERRUPT in chunk.data: + raise GraphInterrupt() - # Don't emit 'updates' events if the original list of stream modes - # didn't include it. - if chunk.event == "updates" and not include_updates: - continue + # Don't emit 'updates' events if the original list of stream + # modes didn't include it. + if not include_updates: + continue yield chunk @@ -434,13 +435,14 @@ async def astream( interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): - if chunk.event == INTERRUPT: - raise GraphInterrupt() - - # Don't emit 'updates' events if the original list of stream modes - # didn't include it. - if chunk.event == "updates" and not include_updates: - continue + if chunk.event == "updates": + if INTERRUPT in chunk.data: + raise GraphInterrupt() + + # Don't emit 'updates' events if the original list of stream + # modes didn't include it. + if not include_updates: + continue yield chunk diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index dd07d27bb..a6a208b27 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -479,7 +479,7 @@ def test_stream(): StreamPart(event="values", data={"chunk": "data2"}), StreamPart(event="values", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), - StreamPart(event="__interrupt__", data={}), + StreamPart(event="updates", data={"__interrupt__": ()}), ] # call method / assertions @@ -527,7 +527,7 @@ async def test_astream(): StreamPart(event="values", data={"chunk": "data2"}), StreamPart(event="values", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), - StreamPart(event="__interrupt__", data={}), + StreamPart(event="updates", data={"__interrupt__": ()}), ] mock_async_client.runs.stream.return_value = async_iter From 19ccb0c6af5bfed81551fd1ca014acf7c36d2ec8 Mon Sep 17 00:00:00 2001 From: Andrew Nguonly Date: Fri, 11 Oct 2024 19:06:52 -0700 Subject: [PATCH 3/9] Update astream_events() to process interrupt. --- libs/langgraph/langgraph/pregel/remote.py | 30 ++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index b0df12bb7..28d9abd83 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -26,7 +26,8 @@ get_client, get_sync_client, ) -from langgraph_sdk.schema import Checkpoint, StreamPart, ThreadState +from langgraph_sdk.schema import Checkpoint, ThreadState +from langgraph_sdk.schema import StreamMode as StreamModeSDK from typing_extensions import Self from langgraph.checkpoint.base import CheckpointMetadata @@ -350,7 +351,7 @@ async def aupdate_state( def _get_stream_modes( self, stream_mode: Optional[Union[StreamMode, list[StreamMode]]], - ) -> tuple[list[StreamMode], bool]: + ) -> tuple[list[StreamModeSDK], bool]: """Return a tuple of the final list of stream modes sent to the remote graph and a boolean flag indicating if stream mode 'updates' was present in the original list of stream modes. @@ -395,7 +396,7 @@ def stream( assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=updated_stream_modes, # type: ignore + stream_mode=updated_stream_modes, interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, @@ -403,9 +404,6 @@ def stream( if chunk.event == "updates": if INTERRUPT in chunk.data: raise GraphInterrupt() - - # Don't emit 'updates' events if the original list of stream - # modes didn't include it. if not include_updates: continue @@ -430,7 +428,7 @@ async def astream( assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=updated_stream_modes, # type: ignore + stream_mode=updated_stream_modes, interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, @@ -438,9 +436,6 @@ async def astream( if chunk.event == "updates": if INTERRUPT in chunk.data: raise GraphInterrupt() - - # Don't emit 'updates' events if the original list of stream - # modes didn't include it. if not include_updates: continue @@ -456,20 +451,27 @@ async def astream_events( sanitized_config = self._sanitize_config(merged_config) # manually add 'events' to stream modes list - stream_mode: list[str] = kwargs.get("stream_mode", []) - if "events" not in stream_mode: - stream_mode.append("events") + stream_mode: list[StreamMode] = kwargs.get("stream_mode", []) + updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) + if "events" not in updated_stream_modes: + updated_stream_modes.append("events") async for chunk in self.client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=stream_mode, # type: ignore + stream_mode=updated_stream_modes, interrupt_before=kwargs.get("interrupt_before"), interrupt_after=kwargs.get("interrupt_after"), stream_subgraphs=kwargs.get("subgraphs", False), ): + if chunk.event == "updates": + if INTERRUPT in chunk.data: + raise GraphInterrupt() + if not include_updates: + continue + yield StandardStreamEvent( event=chunk.event, data=chunk.data, From 58cf0c6a6e1420efb7d4c360f1f6b7631e9bf617 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 22 Oct 2024 10:46:37 -0700 Subject: [PATCH 4/9] chore: Switch s3 client utils from httpx client to curl client --- libs/langgraph/langgraph/pregel/remote.py | 64 +++++++------ libs/sdk-py/langgraph_sdk/client.py | 108 +++++++++++----------- 2 files changed, 90 insertions(+), 82 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 28d9abd83..509f048d3 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -351,7 +351,8 @@ async def aupdate_state( def _get_stream_modes( self, stream_mode: Optional[Union[StreamMode, list[StreamMode]]], - ) -> tuple[list[StreamModeSDK], bool]: + default: StreamMode = "updates", + ) -> tuple[list[StreamModeSDK], bool, bool]: """Return a tuple of the final list of stream modes sent to the remote graph and a boolean flag indicating if stream mode 'updates' was present in the original list of stream modes. @@ -359,23 +360,24 @@ def _get_stream_modes( 'updates' mode is added to the list of stream modes so that interrupts can be detected in the remote graph. """ - updated_stream_modes = [] - updates_mode = False - + updated_stream_modes: list[StreamMode] = [] + req_updates = False + req_single = True + # coerce to list, or add default stream mode if stream_mode: if isinstance(stream_mode, str): updated_stream_modes.append(stream_mode) else: + req_single = False updated_stream_modes.extend(stream_mode) - - if "updates" in updated_stream_modes: - updates_mode = True - else: - updated_stream_modes.append("updates") else: - updated_stream_modes.extend(["values", "updates"]) - - return (updated_stream_modes, updates_mode) + updated_stream_modes.append(default) + # add 'updates' mode if not present + if "updates" in updated_stream_modes: + req_updates = True + else: + updated_stream_modes.append("updates") + return (updated_stream_modes, req_updates, req_single) def stream( self, @@ -389,25 +391,28 @@ def stream( ) -> Iterator[Union[dict[str, Any], Any]]: merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) + stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) + # TODO if req_subgraphs transform chunk to match Pregel for chunk in self.sync_client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=updated_stream_modes, + stream_mode=stream_modes, interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): if chunk.event == "updates": - if INTERRUPT in chunk.data: - raise GraphInterrupt() - if not include_updates: + if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: + raise GraphInterrupt(chunk.data[INTERRUPT]) + if not req_updates: continue - - yield chunk + if req_single: + yield chunk.data + else: + yield chunk async def astream( self, @@ -421,25 +426,27 @@ async def astream( ) -> AsyncIterator[Union[dict[str, Any], Any]]: merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) + stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) async for chunk in self.client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], assistant_id=self.graph_id, input=input, config=sanitized_config, - stream_mode=updated_stream_modes, + stream_mode=stream_modes, interrupt_before=interrupt_before, # type: ignore interrupt_after=interrupt_after, # type: ignore stream_subgraphs=subgraphs, ): if chunk.event == "updates": - if INTERRUPT in chunk.data: - raise GraphInterrupt() - if not include_updates: + if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: + raise GraphInterrupt(chunk.data[INTERRUPT]) + if not req_updates: continue - - yield chunk + if req_single: + yield chunk.data + else: + yield chunk async def astream_events( self, @@ -451,10 +458,11 @@ async def astream_events( sanitized_config = self._sanitize_config(merged_config) # manually add 'events' to stream modes list - stream_mode: list[StreamMode] = kwargs.get("stream_mode", []) - updated_stream_modes, include_updates = self._get_stream_modes(stream_mode) + stream_mode: Union[StreamMode, list[StreamMode]] = kwargs.get("stream_mode", []) + updated_stream_modes, include_updates, _ = self._get_stream_modes(stream_mode) if "events" not in updated_stream_modes: updated_stream_modes.append("events") + # TODO bundle main stream events back into StreamEvent async for chunk in self.client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index f89a0f6a5..8800d77cf 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -1169,15 +1169,15 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, @@ -1191,13 +1191,13 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, @@ -1210,15 +1210,15 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, @@ -1313,12 +1313,12 @@ async def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, after_seconds: Optional[int] = None, @@ -1331,14 +1331,14 @@ async def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, after_seconds: Optional[int] = None, @@ -1350,14 +1350,14 @@ async def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, on_completion: Optional[OnCompletionBehavior] = None, @@ -1495,8 +1495,8 @@ async def wait( config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, @@ -1512,8 +1512,8 @@ async def wait( input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, @@ -1530,8 +1530,8 @@ async def wait( config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, @@ -3232,15 +3232,15 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, @@ -3254,13 +3254,13 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, @@ -3273,15 +3273,15 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, - feedback_keys: Optional[list[str]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, + feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, @@ -3376,12 +3376,12 @@ def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, after_seconds: Optional[int] = None, @@ -3394,14 +3394,14 @@ def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, after_seconds: Optional[int] = None, @@ -3413,14 +3413,14 @@ def create( assistant_id: str, *, input: Optional[dict] = None, - stream_mode: Union[StreamMode, list[StreamMode]] = "values", + stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, on_completion: Optional[OnCompletionBehavior] = None, @@ -3558,8 +3558,8 @@ def wait( config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, @@ -3575,8 +3575,8 @@ def wait( input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, @@ -3593,8 +3593,8 @@ def wait( config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, - interrupt_before: Optional[Union[All, list[str]]] = None, - interrupt_after: Optional[Union[All, list[str]]] = None, + interrupt_before: Optional[Union[All, Sequence[str]]] = None, + interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, From dca200d6c4689f6fb8eb8b4b9a04e110dbb6a9a3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 13:23:42 -0700 Subject: [PATCH 5/9] Finish --- libs/langgraph/langgraph/pregel/remote.py | 97 ++++++++++------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 509f048d3..59150b7b1 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -19,7 +19,6 @@ from langchain_core.runnables.graph import ( Node as DrawableNode, ) -from langchain_core.runnables.schema import StandardStreamEvent, StreamEvent from langgraph_sdk.client import ( LangGraphClient, SyncLangGraphClient, @@ -39,6 +38,12 @@ from langgraph.utils.config import merge_configs +class RemoteException(Exception): + """Exception raised when an error occurs in the remote graph.""" + + pass + + class RemoteGraph(PregelProtocol, Runnable): def __init__( self, @@ -326,7 +331,7 @@ def update_state( response: dict = self.sync_client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], - values=values, # type: ignore + values=values, as_node=as_node, checkpoint=self._get_checkpoint(merged_config), ) @@ -342,7 +347,7 @@ async def aupdate_state( response: dict = await self.client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], - values=values, # type: ignore + values=values, as_node=as_node, checkpoint=self._get_checkpoint(merged_config), ) @@ -392,7 +397,6 @@ def stream( merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) - # TODO if req_subgraphs transform chunk to match Pregel for chunk in self.sync_client.runs.stream( thread_id=sanitized_config["configurable"]["thread_id"], @@ -400,16 +404,28 @@ def stream( input=input, config=sanitized_config, stream_mode=stream_modes, - interrupt_before=interrupt_before, # type: ignore - interrupt_after=interrupt_after, # type: ignore + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, stream_subgraphs=subgraphs, ): - if chunk.event == "updates": + if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) if not req_updates: continue - if req_single: + elif chunk.event.startswith("error"): + raise RemoteException(chunk.data) + if subgraphs: + if "|" in chunk.event: + mode, ns_ = chunk.event.split("|", 1) + ns = tuple(ns_.split("|")) + else: + mode, ns = chunk.event, () + if req_single: + yield ns, chunk.data + else: + yield ns, mode, chunk.data + elif req_single: yield chunk.data else: yield chunk @@ -434,57 +450,32 @@ async def astream( input=input, config=sanitized_config, stream_mode=stream_modes, - interrupt_before=interrupt_before, # type: ignore - interrupt_after=interrupt_after, # type: ignore + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, stream_subgraphs=subgraphs, ): - if chunk.event == "updates": + if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) if not req_updates: continue - if req_single: + elif chunk.event.startswith("error"): + raise RemoteException(chunk.data) + if subgraphs: + if "|" in chunk.event: + mode, ns_ = chunk.event.split("|", 1) + ns = tuple(ns_.split("|")) + else: + mode, ns = chunk.event, () + if req_single: + yield ns, chunk.data + else: + yield ns, mode, chunk.data + elif req_single: yield chunk.data else: yield chunk - async def astream_events( - self, - input: Any, - config: Optional[RunnableConfig] = None, - **kwargs: Any, - ) -> AsyncIterator[StreamEvent]: - merged_config = merge_configs(self.config, config) - sanitized_config = self._sanitize_config(merged_config) - - # manually add 'events' to stream modes list - stream_mode: Union[StreamMode, list[StreamMode]] = kwargs.get("stream_mode", []) - updated_stream_modes, include_updates, _ = self._get_stream_modes(stream_mode) - if "events" not in updated_stream_modes: - updated_stream_modes.append("events") - # TODO bundle main stream events back into StreamEvent - - async for chunk in self.client.runs.stream( - thread_id=sanitized_config["configurable"]["thread_id"], - assistant_id=self.graph_id, - input=input, - config=sanitized_config, - stream_mode=updated_stream_modes, - interrupt_before=kwargs.get("interrupt_before"), - interrupt_after=kwargs.get("interrupt_after"), - stream_subgraphs=kwargs.get("subgraphs", False), - ): - if chunk.event == "updates": - if INTERRUPT in chunk.data: - raise GraphInterrupt() - if not include_updates: - continue - - yield StandardStreamEvent( - event=chunk.event, - data=chunk.data, - ) - def invoke( self, input: Union[dict[str, Any], Any], @@ -501,8 +492,8 @@ def invoke( assistant_id=self.graph_id, input=input, config=sanitized_config, - interrupt_before=interrupt_before, # type: ignore - interrupt_after=interrupt_after, # type: ignore + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, ) async def ainvoke( @@ -521,6 +512,6 @@ async def ainvoke( assistant_id=self.graph_id, input=input, config=sanitized_config, - interrupt_before=interrupt_before, # type: ignore - interrupt_after=interrupt_after, # type: ignore + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, ) From 1121806ba4bfeb516fef09c4fecc9e756e4181b1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 13:23:49 -0700 Subject: [PATCH 6/9] Add if_not_exists --- libs/sdk-py/langgraph_sdk/client.py | 16 ++++++++++++++++ libs/sdk-py/langgraph_sdk/schema.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 8800d77cf..d9a00b129 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -39,6 +39,7 @@ Cron, DisconnectMode, GraphSchema, + IfNotExists, Item, Json, ListNamespaceResponse, @@ -1181,6 +1182,7 @@ def stream( on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: ... @@ -1223,6 +1225,7 @@ def stream( on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: """Create a run and stream the results. @@ -1248,6 +1251,8 @@ def stream( webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -1293,6 +1298,7 @@ def stream( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, @@ -1341,6 +1347,7 @@ async def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -1360,6 +1367,7 @@ async def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, on_completion: Optional[OnCompletionBehavior] = None, after_seconds: Optional[int] = None, ) -> Run: @@ -1383,6 +1391,8 @@ async def create( Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -1466,6 +1476,7 @@ async def create( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_completion": on_completion, "after_seconds": after_seconds, } @@ -1500,6 +1511,7 @@ async def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... @@ -1536,6 +1548,7 @@ async def wait( on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: """Create a run, wait until it finishes and return the final state. @@ -1558,6 +1571,8 @@ async def wait( Must be one of 'delete' or 'keep'. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -1619,6 +1634,7 @@ async def wait( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, diff --git a/libs/sdk-py/langgraph_sdk/schema.py b/libs/sdk-py/langgraph_sdk/schema.py index 19014a76c..dd217a852 100644 --- a/libs/sdk-py/langgraph_sdk/schema.py +++ b/libs/sdk-py/langgraph_sdk/schema.py @@ -69,6 +69,13 @@ All = Literal["*"] """Represents a wildcard or 'all' selector.""" +IfNotExists = Literal["create", "reject"] +""" +Specifies behavior if the thread doesn't exist: +- "create": Create a new thread if it doesn't exist. +- "reject": Reject the operation if the thread doesn't exist. +""" + class Config(TypedDict, total=False): """Configuration options for a call.""" From f8a0b7a4648b0de618183fa967fced2aa1bf6026 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 13:24:19 -0700 Subject: [PATCH 7/9] Use if_not_exists --- libs/langgraph/langgraph/pregel/remote.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 59150b7b1..9d6572334 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -407,6 +407,7 @@ def stream( interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs, + if_not_exists="create", ): if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: @@ -453,6 +454,7 @@ async def astream( interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs, + if_not_exists="create", ): if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: @@ -494,6 +496,7 @@ def invoke( config=sanitized_config, interrupt_before=interrupt_before, interrupt_after=interrupt_after, + if_not_exists="create", ) async def ainvoke( @@ -514,4 +517,5 @@ async def ainvoke( config=sanitized_config, interrupt_before=interrupt_before, interrupt_after=interrupt_after, + if_not_exists="create", ) From e294720ec56e5f98f3496bdf103d86156a777ac3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 13:29:59 -0700 Subject: [PATCH 8/9] Lint --- libs/langgraph/langgraph/pregel/remote.py | 2 +- libs/sdk-py/langgraph_sdk/client.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 9d6572334..4b4889615 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -399,7 +399,7 @@ def stream( stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) for chunk in self.sync_client.runs.stream( - thread_id=sanitized_config["configurable"]["thread_id"], + thread_id=cast(str, sanitized_config["configurable"]["thread_id"]), assistant_id=self.graph_id, input=input, config=sanitized_config, diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index d9a00b129..239f4cf2a 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -3260,6 +3260,7 @@ def stream( on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: ... @@ -3302,6 +3303,7 @@ def stream( on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: """Create a run and stream the results. @@ -3327,6 +3329,8 @@ def stream( webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -3372,6 +3376,7 @@ def stream( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, @@ -3420,6 +3425,7 @@ def create( interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... @@ -3440,6 +3446,7 @@ def create( webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, on_completion: Optional[OnCompletionBehavior] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: """Create a background run. @@ -3462,6 +3469,8 @@ def create( Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -3545,6 +3554,7 @@ def create( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_completion": on_completion, "after_seconds": after_seconds, } @@ -3579,6 +3589,7 @@ def wait( webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... @@ -3615,6 +3626,7 @@ def wait( on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, multitask_strategy: Optional[MultitaskStrategy] = None, + if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: """Create a run, wait until it finishes and return the final state. @@ -3637,6 +3649,8 @@ def wait( Must be one of 'delete' or 'keep'. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. + if_not_exists: How to handle missing thread. Defaults to 'reject'. + Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. @@ -3698,6 +3712,7 @@ def wait( "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, + "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, From 037a95ff60f363cdbed99d57b6140b98b23d44a9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 23 Oct 2024 13:43:48 -0700 Subject: [PATCH 9/9] Update tests --- libs/langgraph/langgraph/pregel/remote.py | 2 +- libs/langgraph/tests/test_remote_graph.py | 179 +++++++++++++++++++--- 2 files changed, 162 insertions(+), 19 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 4b4889615..52b696e51 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -48,12 +48,12 @@ class RemoteGraph(PregelProtocol, Runnable): def __init__( self, graph_id: str, - config: Optional[RunnableConfig] = None, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, client: Optional[LangGraphClient] = None, sync_client: Optional[SyncLangGraphClient] = None, + config: Optional[RunnableConfig] = None, ): """Specify `url`, `api_key`, and/or `headers` to create default sync and async clients. diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index a6a208b27..46b3ab9a4 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -489,17 +489,39 @@ def test_stream(): stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( - {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), + {"chunk": "data1"}, + {"chunk": "data2"}, + {"chunk": "data3"}, + ] + + mock_sync_client.runs.stream.return_value = [ + StreamPart(event="updates", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="updates", data={"__interrupt__": ()}), + ] + + # default stream_mode is updates + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + {"chunk": "data3"}, + {"chunk": "data4"}, ] - # stream modes includes 'updates' + # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( @@ -510,10 +532,39 @@ def test_stream(): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), - StreamPart(event="updates", data={"chunk": "data4"}), + ("updates", {"chunk": "data3"}), + ("updates", {"chunk": "data4"}), + ] + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), "updates", {"chunk": "data3"}), + ((), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), {"chunk": "data3"}), + ((), {"chunk": "data4"}), ] @@ -538,17 +589,41 @@ async def test_astream(): stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( - {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), + {"chunk": "data1"}, + {"chunk": "data2"}, + {"chunk": "data3"}, + ] + + async_iter = MagicMock() + async_iter.__aiter__.return_value = [ + StreamPart(event="updates", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="updates", data={"__interrupt__": ()}), + ] + mock_async_client.runs.stream.return_value = async_iter + + # default stream_mode is updates + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + {"chunk": "data3"}, + {"chunk": "data4"}, ] - # stream modes includes 'updates' + # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( @@ -559,10 +634,78 @@ async def test_astream(): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), - StreamPart(event="updates", data={"chunk": "data4"}), + ("updates", {"chunk": "data3"}), + ("updates", {"chunk": "data4"}), + ] + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), "updates", {"chunk": "data3"}), + ((), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), {"chunk": "data3"}), + ((), {"chunk": "data4"}), + ] + + async_iter = MagicMock() + async_iter.__aiter__.return_value = [ + StreamPart(event="updates|my|subgraph", data={"chunk": "data3"}), + StreamPart(event="updates|hello|subgraph", data={"chunk": "data4"}), + StreamPart(event="updates|bye|subgraph", data={"__interrupt__": ()}), + ] + mock_async_client.runs.stream.return_value = async_iter + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + (("my", "subgraph"), "updates", {"chunk": "data3"}), + (("hello", "subgraph"), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + (("my", "subgraph"), {"chunk": "data3"}), + (("hello", "subgraph"), {"chunk": "data4"}), ]