diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py index 598a8ebc49c25..2bf5b90948564 100644 --- a/libs/community/langchain_community/chat_models/huggingface.py +++ b/libs/community/langchain_community/chat_models/huggingface.py @@ -1,19 +1,28 @@ """Hugging Face Chat Wrapper.""" - -from typing import Any, List, Optional +from typing import Any, AsyncIterator, Iterator, List, Optional from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ) -from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, + LLMResult, +) from langchain_core.pydantic_v1 import root_validator from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint @@ -26,7 +35,8 @@ class ChatHuggingFace(BaseChatModel): - """Hugging Face LLMs as ChatModels. + """ + Wrapper for using Hugging Face LLM's as ChatModels. Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`, and `HuggingFaceHub` LLMs. @@ -44,6 +54,7 @@ class ChatHuggingFace(BaseChatModel): system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT) tokenizer: Any = None model_id: Optional[str] = None + streaming: bool = False def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -70,6 +81,37 @@ def validate_llm(cls, values: dict) -> dict: ) return values + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + request = self._to_chat_prompt(messages) + + for data in self.llm.stream(request, **kwargs): + delta = data + chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) + if run_manager: + run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + request = self._to_chat_prompt(messages) + async for data in self.llm.astream(request, **kwargs): + delta = data + chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) + if run_manager: + await run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk + def _generate( self, messages: List[BaseMessage], @@ -77,6 +119,12 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + llm_input = self._to_chat_prompt(messages) llm_result = self.llm._generate( prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs @@ -90,6 +138,12 @@ async def _agenerate( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + llm_input = self._to_chat_prompt(messages) llm_result = await self.llm._agenerate( prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs