Merge pull request #1751 from langchain-ai/nc/17sep/stream-messages
Add stream_mode=messages
Showing 8 changed files with 711 additions and 140 deletions.
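The new stream mode surfaces LLM tokens and message outputs from graph nodes as (message_chunk, metadata) pairs. A minimal usage sketch follows; the model, graph, and node names are illustrative assumptions, not part of this diff:

# Hypothetical usage sketch: the model and graph below are assumed, not part of this PR.
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

model = ChatOpenAI(model="gpt-4o-mini")

def call_model(state: MessagesState):
    return {"messages": [model.invoke(state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("call_model", call_model)
builder.add_edge(START, "call_model")
builder.add_edge("call_model", END)
graph = builder.compile()

# With stream_mode="messages", each event is a (message_chunk, metadata) pair,
# where metadata carries fields such as "langgraph_node".
for message_chunk, metadata in graph.stream(
    {"messages": [("human", "hi")]}, stream_mode="messages"
):
    print(metadata.get("langgraph_node"), message_chunk.content)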
@@ -0,0 +1,167 @@
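New file: StreamMessagesHandler, the callback handler backing the new stream mode. It attributes chat model tokens and node message outputs to graph nodes and forwards them to a stream callable.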
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
)
from uuid import UUID, uuid4

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.tracers._streaming import T, _StreamingCallbackHandler

from langgraph.constants import NS_SEP


class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
    """Forwards chat model tokens and node message outputs to a stream callable
    as (checkpoint_ns_parts, "messages", (message, metadata)) tuples."""

    def __init__(self, stream: Callable[[Tuple[str, str, Any]], None]):
        self.stream = stream
        # run_id -> (checkpoint namespace parts, run metadata)
        self.metadata: dict[str, tuple[str, dict[str, Any]]] = {}
        self.seen = set()

    def _emit(
        self,
        meta: Tuple[str, dict[str, Any]],
        message: BaseMessage,
        *,
        dedupe: bool = False,
    ):
        ident = id(message)
        if dedupe and message.id in self.seen:
            return
        elif ident in self.seen:
            return
        else:
            # Assign an id if missing, remember the message, then emit it.
            if message.id is None:
                message.id = str(uuid4())
            self.seen.add(ident)
            self.seen.add(message.id)
            self.stream((meta[0], "messages", (message, meta[1])))

    def tap_output_aiter(
        self, run_id: UUID, output: AsyncIterator[T]
    ) -> AsyncIterator[T]:
        return output

    def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
        return output

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # Remember which namespace/metadata this chat model run belongs to.
        if metadata:
            self.metadata[run_id] = (
                tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)),
                metadata,
            )

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[ChatGenerationChunk] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if not isinstance(chunk, ChatGenerationChunk):
            return
        if meta := self.metadata.get(run_id):
            self._emit(meta, chunk.message)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        self.metadata.pop(run_id, None)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        self.metadata.pop(run_id, None)

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # Only track chain runs that correspond to a graph node.
        if metadata and kwargs.get("name") == metadata.get("langgraph_node"):
            self.metadata[run_id] = (
                tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)),
                metadata,
            )

    def on_chain_end(
        self,
        response: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # Emit any messages found in the node output: a bare message, a sequence,
        # a dict of values, or the attributes of an arbitrary object.
        if meta := self.metadata.pop(run_id, None):
            if isinstance(response, BaseMessage):
                self._emit(meta, response, dedupe=True)
            elif isinstance(response, Sequence):
                for value in response:
                    if isinstance(value, BaseMessage):
                        self._emit(meta, value, dedupe=True)
            elif isinstance(response, dict):
                for value in response.values():
                    if isinstance(value, BaseMessage):
                        self._emit(meta, value, dedupe=True)
                    elif isinstance(value, Sequence):
                        for item in value:
                            if isinstance(item, BaseMessage):
                                self._emit(meta, item, dedupe=True)
            elif hasattr(response, "__dir__") and callable(response.__dir__):
                for key in dir(response):
                    try:
                        value = getattr(response, key)
                        if isinstance(value, BaseMessage):
                            self._emit(meta, value, dedupe=True)
                        elif isinstance(value, Sequence):
                            for item in value:
                                if isinstance(item, BaseMessage):
                                    self._emit(meta, item, dedupe=True)
                    except AttributeError:
                        pass

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        self.metadata.pop(run_id, None)
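A brief sketch of how the handler's stream callable might be consumed; the queue wiring below is illustrative and is not how langgraph registers the handler internally:

# Illustrative sketch (not from this diff): consuming the handler's emitted tuples.
import queue

events: queue.SimpleQueue = queue.SimpleQueue()
handler = StreamMessagesHandler(events.put)  # the stream callable receives each emitted tuple

# The handler can be attached like any LangChain callback, e.g. via the run config:
#   graph.invoke(inputs, config={"callbacks": [handler]})
# Every emitted item has the shape (checkpoint_ns_parts, "messages", (message, metadata)).
while not events.empty():
    ns_parts, mode, (message, metadata) = events.get()
    print(ns_parts, metadata.get("langgraph_node"), message.content)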
@@ -0,0 +1,86 @@
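New file: FakeChatModel, a test helper that replays a fixed list of canned messages and streams them token by token.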
import re
from typing import Any, Iterator, List, Optional, cast

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class FakeChatModel(GenericFakeChatModel):
    messages: list[BaseMessage]

    i: int = 0

    def bind_tools(self, functions: list):
        return self

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top level call: return the next canned message, wrapping around."""
        if self.i >= len(self.messages):
            self.i = 0
        message = self.messages[self.i]
        self.i += 1
        if isinstance(message, str):
            message_ = AIMessage(content=message)
        else:
            # Copy so callers can mutate the returned message safely
            # (pydantic v2 uses model_copy, v1 uses copy).
            if hasattr(message, "model_copy"):
                message_ = message.model_copy()
            else:
                message_ = message.copy()
        generation = ChatGeneration(message=message_)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model."""
        chat_result = self._generate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        if not isinstance(chat_result, ChatResult):
            raise ValueError(
                f"Expected generate to return a ChatResult, "
                f"but got {type(chat_result)} instead."
            )

        message = chat_result.generations[0].message

        if not isinstance(message, AIMessage):
            raise ValueError(
                f"Expected invoke to return an AIMessage, "
                f"but got {type(message)} instead."
            )

        content = message.content

        if content:
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            assert isinstance(content, str)
            content_chunks = cast(list[str], re.split(r"(\s)", content))

            for token in content_chunks:
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=token, id=message.id)
                )
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk
        else:
            # No text content (e.g. tool calls only): emit the message as a single chunk.
            args = message.__dict__
            args.pop("type")
            chunk = ChatGenerationChunk(message=AIMessageChunk(**args))
            if run_manager:
                run_manager.on_llm_new_token("", chunk=chunk)
            yield chunk
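A short sketch of how this test helper behaves when streamed; the call below is assumed usage, not part of the diff:

# Illustrative usage of the test helper above (not part of the diff).
from langchain_core.messages import AIMessage

model = FakeChatModel(messages=[AIMessage(content="hello world", id="ai-1")])

# _stream splits the canned message on whitespace, so three chunks are produced:
# "hello", " ", and "world", all sharing the id "ai-1".
for chunk in model.stream("ignored input"):
    print(repr(chunk.content))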