Merge pull request #1751 from langchain-ai/nc/17sep/stream-messages
Add stream_mode=messages
nfcampos authored Sep 18, 2024
2 parents 22c7805 + 9235178 commit 4ba803b
Showing 8 changed files with 711 additions and 140 deletions.
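As a rough, hedged illustration of what this commit enables (not code from the diff itself): with stream_mode="messages", graph.stream() surfaces chat-model output token by token, each event pairing a message chunk with the metadata of the task that produced it. The graph, node name, and model below are hypothetical stand-ins; only stream_mode="messages" and the (message, metadata) payload shape come from this change.

# Hedged usage sketch. StateGraph, MessagesState, and ChatOpenAI are assumed
# to be available; any streaming-capable chat model would do.
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

model = ChatOpenAI()

def agent(state: MessagesState) -> MessagesState:
    return {"messages": [model.invoke(state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("agent", agent)
builder.add_edge(START, "agent")
builder.add_edge("agent", END)
app = builder.compile()

for message_chunk, metadata in app.stream(
    {"messages": [("user", "hi")]}, stream_mode="messages"
):
    # Each event carries one LLM token (or a full message returned by a node)
    # plus the task metadata, which now includes "langgraph_node" and the new
    # "langgraph_checkpoint_ns" key.
    print(metadata["langgraph_node"], message_chunk.content)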
9 changes: 8 additions & 1 deletion libs/langgraph/langgraph/pregel/__init__.py
@@ -78,6 +78,7 @@
from langgraph.pregel.io import read_channels
from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.messages import StreamMessagesHandler
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.runner import PregelRunner
@@ -1213,7 +1214,11 @@ def output() -> Iterator:
interrupt_after=interrupt_after,
debug=debug,
)

# set up messages stream mode
if "messages" in stream_modes:
run_manager.inheritable_handlers.append(
StreamMessagesHandler(stream.put)
)
with SyncPregelLoop(
input,
stream=StreamProtocol(stream.put, stream_modes),
@@ -1234,6 +1239,8 @@ def output() -> Iterator:
# enable subgraph streaming
if subgraphs:
loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream
# enable concurrent streaming
if subgraphs or "messages" in stream_modes:
# we are careful to have a single waiter live at any one time
# because on exit we increment semaphore count by exactly 1
waiter: Optional[concurrent.futures.Future] = None
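The hunk above is where the new mode is wired in: when "messages" is among the requested stream modes, a StreamMessagesHandler is appended to the run manager's inheritable handlers, so it is propagated to every child run (each node, and every chat model invoked inside a node) and forwards token events onto the same queue the loop already drains for the other modes. A minimal sketch of that wiring, with a plain queue standing in for the loop's internal stream (the queue and names below are hypothetical):

# Sketch only: stream_queue stands in for the internal stream whose .put is
# passed to StreamMessagesHandler in the hunk above.
from queue import Queue

from langgraph.pregel.messages import StreamMessagesHandler

stream_queue: Queue = Queue()
handler = StreamMessagesHandler(stream_queue.put)
# Registered as an inheritable callback handler, it receives LLM callbacks
# from nested runs and puts (namespace, "messages", (message, metadata))
# triples onto the queue.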
31 changes: 16 additions & 15 deletions libs/langgraph/langgraph/pregel/algo.py
@@ -350,12 +350,6 @@ def prepare_single_task(
return
# create task id
triggers = [PUSH]
metadata = {
"langgraph_step": step,
"langgraph_node": packet.node,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
}
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
@@ -367,6 +361,14 @@
PUSH,
str(idx),
)
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
metadata = {
"langgraph_step": step,
"langgraph_node": packet.node,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
assert task_id == task_id_checksum
if for_execution:
Expand All @@ -376,7 +378,6 @@ def prepare_single_task(
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
packet.node,
packet.arg,
@@ -461,12 +462,6 @@ def prepare_single_task(
return

# create task id
metadata = {
"langgraph_step": step,
"langgraph_node": name,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
}
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
@@ -476,15 +471,21 @@
PULL,
*triggers,
)
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
metadata = {
"langgraph_step": step,
"langgraph_node": name,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
assert task_id == task_id_checksum

if for_execution:
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
name,
val,
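The change in this file moves the task metadata dict to after the task id is computed, so it can carry the new "langgraph_checkpoint_ns" key; StreamMessagesHandler later splits that value on NS_SEP to recover the namespace under which a token was produced. A small illustration with hypothetical values (assuming NS_SEP is the "|" separator defined in langgraph.constants):

# Hypothetical values, illustrating the namespace format assumed above.
NS_SEP = "|"  # assumption: the namespace separator from langgraph.constants
parent_ns = "supervisor"
node = "agent"
task_id = "1a2b3c"  # stands in for the uuid5 task id

checkpoint_ns = f"{parent_ns}{NS_SEP}{node}"         # "supervisor|agent"
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"    # "supervisor|agent:1a2b3c"
namespace = tuple(task_checkpoint_ns.split(NS_SEP))  # ("supervisor", "agent:1a2b3c")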
5 changes: 2 additions & 3 deletions libs/langgraph/langgraph/pregel/loop.py
@@ -8,7 +8,6 @@
AsyncContextManager,
Callable,
ContextManager,
Iterable,
Iterator,
List,
Literal,
@@ -112,11 +111,11 @@ class StreamProtocol:

modes: Sequence[Literal["values", "updates", "debug"]]

__call__: Callable[[Iterable[Tuple[str, str, Any]]], None]
__call__: Callable[[Tuple[str, str, Any]], None]

def __init__(
self,
__call__: Callable[[Iterable[Tuple[str, str, Any]]], None],
__call__: Callable[[Tuple[str, str, Any]], None],
modes: Sequence[Literal["values", "updates", "debug"]],
) -> None:
self.__call__ = __call__
167 changes: 167 additions & 0 deletions libs/langgraph/langgraph/pregel/messages.py
@@ -0,0 +1,167 @@
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from uuid import UUID, uuid4

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.tracers._streaming import T, _StreamingCallbackHandler

from langgraph.constants import NS_SEP


class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
def __init__(self, stream: Callable[[Tuple[str, str, Any]], None]):
self.stream = stream
self.metadata: dict[str, tuple[str, dict[str, Any]]] = {}
self.seen = set()

def _emit(
self,
meta: Tuple[str, dict[str, Any]],
message: BaseMessage,
*,
dedupe: bool = False,
):
ident = id(message)
if dedupe and message.id in self.seen:
return
elif ident in self.seen:
return
else:
if message.id is None:
message.id = str(uuid4())
self.seen.add(ident)
self.seen.add(message.id)
self.stream((meta[0], "messages", (message, meta[1])))

def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
return output

def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
return output

def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
if metadata:
self.metadata[run_id] = (
tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)),
metadata,
)

def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[ChatGenerationChunk] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if not isinstance(chunk, ChatGenerationChunk):
return
if meta := self.metadata.get(run_id):
self._emit(meta, chunk.message)

def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)

def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)

def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
if metadata and kwargs.get("name") == metadata.get("langgraph_node"):
self.metadata[run_id] = (
tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)),
metadata,
)

def on_chain_end(
self,
response: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if meta := self.metadata.pop(run_id, None):
if isinstance(response, BaseMessage):
self._emit(meta, response, dedupe=True)
elif isinstance(response, Sequence):
for value in response:
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(response, dict):
for value in response.values():
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(value, Sequence):
for item in value:
if isinstance(item, BaseMessage):
self._emit(meta, item, dedupe=True)
elif hasattr(response, "__dir__") and callable(response.__dir__):
for key in dir(response):
try:
value = getattr(response, key)
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(value, Sequence):
for item in value:
if isinstance(item, BaseMessage):
self._emit(meta, item, dedupe=True)
except AttributeError:
pass

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)
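The handler records a (namespace, metadata) pair per run id in on_chat_model_start and on_chain_start, emits each ChatGenerationChunk's message in on_llm_new_token, and re-emits complete messages returned by a node in on_chain_end; the seen set keeps a message that was already streamed token by token from being emitted a second time. A small sketch of that dedupe behaviour, calling the private _emit directly purely for illustration (the metadata tuple and message are hypothetical):

from langchain_core.messages import AIMessage

from langgraph.pregel.messages import StreamMessagesHandler

events: list = []
handler = StreamMessagesHandler(events.append)
meta = (("agent:1a2b3c",), {"langgraph_node": "agent"})  # hypothetical metadata

msg = AIMessage(content="hi", id="msg-1")
handler._emit(meta, msg)               # emitted as (namespace, "messages", (msg, metadata))
handler._emit(meta, msg, dedupe=True)  # skipped: id "msg-1" is already in handler.seen
assert len(events) == 1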
3 changes: 2 additions & 1 deletion libs/langgraph/langgraph/pregel/types.py
@@ -107,11 +107,12 @@ class StateSnapshot(NamedTuple):

All = Literal["*"]

StreamMode = Literal["values", "updates", "debug"]
StreamMode = Literal["values", "updates", "debug", "messages"]
"""How the stream method should emit outputs.
- 'values': Emit all values of the state for each step.
- 'updates': Emit only the node name(s) and updates
that were returned by the node(s) **after** each step.
- 'debug': Emit debug events for each step.
- 'messages': Emit LLM messages token-by-token.
"""
86 changes: 86 additions & 0 deletions libs/langgraph/tests/fake_chat.py
@@ -0,0 +1,86 @@
import re
from typing import Any, Iterator, List, Optional, cast

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class FakeChatModel(GenericFakeChatModel):
messages: list[BaseMessage]

i: int = 0

def bind_tools(self, functions: list):
return self

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
if self.i >= len(self.messages):
self.i = 0
message = self.messages[self.i]
self.i += 1
if isinstance(message, str):
message_ = AIMessage(content=message)
else:
if hasattr(message, "model_copy"):
message_ = message.model_copy()
else:
message_ = message.copy()
generation = ChatGeneration(message=message_)
return ChatResult(generations=[generation])

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model."""
chat_result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
if not isinstance(chat_result, ChatResult):
raise ValueError(
f"Expected generate to return a ChatResult, "
f"but got {type(chat_result)} instead."
)

message = chat_result.generations[0].message

if not isinstance(message, AIMessage):
raise ValueError(
f"Expected invoke to return an AIMessage, "
f"but got {type(message)} instead."
)

content = message.content

if content:
# Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output.
assert isinstance(content, str)
content_chunks = cast(list[str], re.split(r"(\s)", content))

for token in content_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, id=message.id)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
else:
args = message.__dict__
args.pop("type")
chunk = ChatGenerationChunk(message=AIMessageChunk(**args))
if run_manager:
run_manager.on_llm_new_token("", chunk=chunk)
yield chunk
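This fake model replays a scripted list of messages and, when streamed, splits string content on whitespace so that each word and the whitespace after it arrives as its own AIMessageChunk, giving the tests deterministic token-level events to assert against. A small sketch of how it behaves, assuming the standard BaseChatModel.stream() entry point (the prompt is arbitrary):

from langchain_core.messages import AIMessage

model = FakeChatModel(messages=[AIMessage(content="hello world")])
tokens = [chunk.content for chunk in model.stream("hi")]
# tokens == ["hello", " ", "world"], per the regex split in _stream above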
2 more changed files not shown.
