Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jan 16, 2025
1 parent f9c25bb commit 145220f
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 63 deletions.
1 change: 1 addition & 0 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def agen_wrapper(
output_channels=END,
stream_channels=END,
stream_mode=stream_mode,
stream_eager=True,
checkpointer=checkpointer,
store=store,
)
Expand Down
20 changes: 18 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ class Pregel(PregelProtocol):
stream_mode: StreamMode = "values"
"""Mode to stream output, defaults to 'values'."""

stream_eager: bool = False
"""Whether to force emitting stream events eagerly, automatically turned on
for stream_mode "messages" and "custom"."""

output_channels: Union[str, Sequence[str]]

stream_channels: Optional[Union[str, Sequence[str]]] = None
Expand Down Expand Up @@ -242,6 +246,7 @@ def __init__(
channels: Optional[dict[str, Union[BaseChannel, ManagedValueSpec]]],
auto_validate: bool = True,
stream_mode: StreamMode = "values",
stream_eager: bool = False,
output_channels: Union[str, Sequence[str]],
stream_channels: Optional[Union[str, Sequence[str]]] = None,
interrupt_after_nodes: Union[All, Sequence[str]] = (),
Expand All @@ -259,6 +264,7 @@ def __init__(
self.nodes = nodes
self.channels = channels or {}
self.stream_mode = stream_mode
self.stream_eager = stream_eager
self.output_channels = output_channels
self.stream_channels = stream_channels
self.interrupt_after_nodes = interrupt_after_nodes
Expand Down Expand Up @@ -1655,7 +1661,12 @@ def output() -> Iterator:
if subgraphs:
loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
# enable concurrent streaming
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
if (
self.stream_eager
or subgraphs
or "messages" in stream_modes
or "custom" in stream_modes
):
# we are careful to have a single waiter live at any one time
# because on exit we increment semaphore count by exactly 1
waiter: Optional[concurrent.futures.Future] = None
Expand Down Expand Up @@ -1886,7 +1897,12 @@ def output() -> Iterator:
stream_put, stream_modes
)
# enable concurrent streaming
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
if (
self.stream_eager
or subgraphs
or "messages" in stream_modes
or "custom" in stream_modes
):

def get_waiter() -> asyncio.Task[None]:
return aioloop.create_task(stream.wait())
Expand Down
65 changes: 14 additions & 51 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import enum
import json
import logging
Expand Down Expand Up @@ -2168,10 +2167,10 @@ class State(TypedDict, total=False):

@workflow.add_node
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -2308,10 +2307,10 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -2741,11 +2740,11 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyzer_one(data: State) -> State:
time.sleep(0.1)
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -2831,10 +2830,10 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -2904,10 +2903,10 @@ class State(TypedDict, total=False):
query: str

def rewrite(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyze(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

class ChooseAnalyzer:
def __call__(self, data: State) -> str:
Expand All @@ -2930,10 +2929,10 @@ class State(TypedDict, total=False):
query: str

def rewrite(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def analyze(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def choose_analyzer(data: State) -> str:
return "analyzer"
Expand Down Expand Up @@ -2966,13 +2965,13 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]

def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -5528,39 +5527,3 @@ def graph(inputs: dict) -> list:
delta = arrival_times[1] - arrival_times[0]
# Delta cannot be less than 10 ms if it is streaming as results are generated.
assert delta > time_delay


async def test_async_streaming_with_functional_api() -> None:
    """Verify the functional API streams task updates incrementally.

    Results must arrive while the graph is still running rather than all at
    once after it completes: the gap between the two updates emitted by the
    two `slow` tasks must exceed the per-task delay.
    """
    delay_s = 0.01

    @task()
    async def slow() -> dict:
        # Simulate 10 ms of work, then report a wall-clock timestamp.
        await asyncio.sleep(delay_s)
        return {"tic": time.time()}

    @entrypoint()
    async def graph(inputs: dict) -> list:
        first = await slow()
        second = await slow()
        return [first, second]

    arrival_times: list[float] = []

    async for chunk in graph.astream({}):
        # Only the updates emitted by the `slow` task are of interest.
        if "slow" in chunk:
            arrival_times.append(time.time())

    assert len(arrival_times) == 2
    # If output is truly streamed, the second update cannot land sooner than
    # one task delay after the first.
    assert arrival_times[1] - arrival_times[0] > delay_s
56 changes: 46 additions & 10 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4294,10 +4294,10 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -4384,10 +4384,10 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -4801,11 +4801,11 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

async def analyzer_one(data: State) -> State:
await asyncio.sleep(0.1)
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -4895,10 +4895,10 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -4979,13 +4979,13 @@ class State(TypedDict, total=False):
docs: Annotated[list[str], sorted_add]

async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
return {"query": f"query: {data['query']}"}

async def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]

async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
return {"query": f"analyzed: {data['query']}"}

async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
Expand Down Expand Up @@ -6875,3 +6875,39 @@ def invoke_sub_agent(state: AgentState):
"invoke_sub_agent": {"input": True},
},
]


async def test_async_streaming_with_functional_api() -> None:
    """Verify the functional API streams task updates incrementally.

    Results must arrive while the graph is still running rather than all at
    once after it completes: the gap between the two updates emitted by the
    two `slow` tasks must exceed the per-task delay.
    """
    delay_s = 0.01

    @task()
    async def slow() -> dict:
        # Simulate 10 ms of work, then report the event loop's clock reading.
        await asyncio.sleep(delay_s)
        return {"tic": asyncio.get_running_loop().time()}

    @entrypoint()
    async def graph(inputs: dict) -> list:
        first = await slow()
        second = await slow()
        return [first, second]

    arrival_times: list[float] = []

    async for chunk in graph.astream({}):
        # Only the updates emitted by the `slow` task are of interest.
        if "slow" in chunk:
            arrival_times.append(asyncio.get_running_loop().time())

    assert len(arrival_times) == 2
    # If output is truly streamed, the second update cannot land sooner than
    # one task delay after the first.
    assert arrival_times[1] - arrival_times[0] > delay_s

0 comments on commit 145220f

Please sign in to comment.