Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Oct 23, 2024
1 parent e294720 commit 037a95f
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 19 deletions.
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ class RemoteGraph(PregelProtocol, Runnable):
def __init__(
self,
graph_id: str,
config: Optional[RunnableConfig] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
client: Optional[LangGraphClient] = None,
sync_client: Optional[SyncLangGraphClient] = None,
config: Optional[RunnableConfig] = None,
):
"""Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.
Expand Down
179 changes: 161 additions & 18 deletions libs/langgraph/tests/test_remote_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,17 +489,39 @@ def test_stream():
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode="values",
):
stream_parts.append(stream_part)

assert stream_parts == [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
{"chunk": "data1"},
{"chunk": "data2"},
{"chunk": "data3"},
]

mock_sync_client.runs.stream.return_value = [
StreamPart(event="updates", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]

# default stream_mode is updates
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
):
stream_parts.append(stream_part)

assert stream_parts == [
{"chunk": "data3"},
{"chunk": "data4"},
]

# stream modes includes 'updates'
# list stream_mode includes mode names
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
Expand All @@ -510,10 +532,39 @@ def test_stream():
stream_parts.append(stream_part)

assert stream_parts == [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
("updates", {"chunk": "data3"}),
("updates", {"chunk": "data4"}),
]

# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
((), "updates", {"chunk": "data3"}),
((), "updates", {"chunk": "data4"}),
]

# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
((), {"chunk": "data3"}),
((), {"chunk": "data4"}),
]


Expand All @@ -538,17 +589,41 @@ async def test_astream():
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode="values",
):
stream_parts.append(stream_part)

assert stream_parts == [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
{"chunk": "data1"},
{"chunk": "data2"},
{"chunk": "data3"},
]

async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="updates", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]
mock_async_client.runs.stream.return_value = async_iter

# default stream_mode is updates
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
):
stream_parts.append(stream_part)

assert stream_parts == [
{"chunk": "data3"},
{"chunk": "data4"},
]

# stream modes includes 'updates'
# list stream_mode includes mode names
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
Expand All @@ -559,10 +634,78 @@ async def test_astream():
stream_parts.append(stream_part)

assert stream_parts == [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
("updates", {"chunk": "data3"}),
("updates", {"chunk": "data4"}),
]

# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
((), "updates", {"chunk": "data3"}),
((), "updates", {"chunk": "data4"}),
]

# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
((), {"chunk": "data3"}),
((), {"chunk": "data4"}),
]

async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="updates|my|subgraph", data={"chunk": "data3"}),
StreamPart(event="updates|hello|subgraph", data={"chunk": "data4"}),
StreamPart(event="updates|bye|subgraph", data={"__interrupt__": ()}),
]
mock_async_client.runs.stream.return_value = async_iter

# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
(("my", "subgraph"), "updates", {"chunk": "data3"}),
(("hello", "subgraph"), "updates", {"chunk": "data4"}),
]

# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)

assert stream_parts == [
(("my", "subgraph"), {"chunk": "data3"}),
(("hello", "subgraph"), {"chunk": "data4"}),
]


Expand Down

0 comments on commit 037a95f

Please sign in to comment.