diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 4b4889615..52b696e51 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -48,12 +48,12 @@ class RemoteGraph(PregelProtocol, Runnable): def __init__( self, graph_id: str, - config: Optional[RunnableConfig] = None, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, client: Optional[LangGraphClient] = None, sync_client: Optional[SyncLangGraphClient] = None, + config: Optional[RunnableConfig] = None, ): """Specify `url`, `api_key`, and/or `headers` to create default sync and async clients. diff --git a/libs/langgraph/tests/test_remote_graph.py b/libs/langgraph/tests/test_remote_graph.py index a6a208b27..46b3ab9a4 100644 --- a/libs/langgraph/tests/test_remote_graph.py +++ b/libs/langgraph/tests/test_remote_graph.py @@ -489,17 +489,39 @@ def test_stream(): stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( - {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), + {"chunk": "data1"}, + {"chunk": "data2"}, + {"chunk": "data3"}, + ] + + mock_sync_client.runs.stream.return_value = [ + StreamPart(event="updates", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="updates", data={"__interrupt__": ()}), + ] + + # default stream_mode is updates + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + {"chunk": "data3"}, + {"chunk": "data4"}, ] - # stream modes includes 'updates' + # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( @@ -510,10 +532,39 @@ def test_stream(): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), - StreamPart(event="updates", data={"chunk": "data4"}), + ("updates", {"chunk": "data3"}), + ("updates", {"chunk": "data4"}), + ] + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), "updates", {"chunk": "data3"}), + ((), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + for stream_part in remote_pregel.stream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), {"chunk": "data3"}), + ((), {"chunk": "data4"}), ] @@ -538,17 +589,41 @@ async def test_astream(): stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( - {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}} + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), + {"chunk": "data1"}, + {"chunk": "data2"}, + {"chunk": "data3"}, + ] + + async_iter = MagicMock() + async_iter.__aiter__.return_value = [ + StreamPart(event="updates", data={"chunk": "data3"}), + StreamPart(event="updates", data={"chunk": "data4"}), + StreamPart(event="updates", data={"__interrupt__": ()}), + ] + mock_async_client.runs.stream.return_value = async_iter + + # default stream_mode is updates + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + {"chunk": "data3"}, + {"chunk": "data4"}, ] - # stream modes includes 'updates' + # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( @@ -559,10 +634,78 @@ async def test_astream(): stream_parts.append(stream_part) assert stream_parts == [ - StreamPart(event="values", data={"chunk": "data1"}), - StreamPart(event="values", data={"chunk": "data2"}), - StreamPart(event="values", data={"chunk": "data3"}), - StreamPart(event="updates", data={"chunk": "data4"}), + ("updates", {"chunk": "data3"}), + ("updates", {"chunk": "data4"}), + ] + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), "updates", {"chunk": "data3"}), + ((), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + ((), {"chunk": "data3"}), + ((), {"chunk": "data4"}), + ] + + async_iter = MagicMock() + async_iter.__aiter__.return_value = [ + StreamPart(event="updates|my|subgraph", data={"chunk": "data3"}), + StreamPart(event="updates|hello|subgraph", data={"chunk": "data4"}), + StreamPart(event="updates|bye|subgraph", data={"__interrupt__": ()}), + ] + mock_async_client.runs.stream.return_value = async_iter + + # subgraphs + list modes + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + stream_mode=["updates"], + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + (("my", "subgraph"), "updates", {"chunk": "data3"}), + (("hello", "subgraph"), "updates", {"chunk": "data4"}), + ] + + # subgraphs + single mode + stream_parts = [] + with pytest.raises(GraphInterrupt): + async for stream_part in remote_pregel.astream( + {"input": "data"}, + config={"configurable": {"thread_id": "thread_1"}}, + subgraphs=True, + ): + stream_parts.append(stream_part) + + assert stream_parts == [ + (("my", "subgraph"), {"chunk": "data3"}), + (("hello", "subgraph"), {"chunk": "data4"}), ]