community[patch]: Add streaming logic in ChatHuggingFace (#18784)
- Add `_stream` and `_astream` methods
- Connect them to `_generate` and `_agenerate`


- [x] **PR title**: "community: Add streaming logic in ChatHuggingFace"

- [x] **PR message**:
    - **Description:** Adds `_stream` and `_astream` methods and wires them into `_generate` and `_agenerate` (see the usage sketch below).
    - **Issue:** #18782
    - **Dependencies:** none
    - **Twitter handle:** @lunara_x
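
For context, a minimal usage sketch of the new streaming path (not part of the PR itself). It assumes a locally running text-generation server; the endpoint URL and generation parameters are placeholders.

```python
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.messages import HumanMessage

# Hypothetical endpoint URL; any HuggingFaceEndpoint-compatible server works.
llm = HuggingFaceEndpoint(
    endpoint_url="http://localhost:8080/",
    max_new_tokens=256,
)
chat = ChatHuggingFace(llm=llm)

# The public .stream() API drives the new _stream method added below.
for chunk in chat.stream([HumanMessage(content="Tell me a joke.")]):
    print(chunk.content, end="", flush=True)
```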
eunhye1kim authored Apr 17, 2024
1 parent c05c379 commit b34f108
Showing 1 changed file with 59 additions and 5 deletions.

libs/community/langchain_community/chat_models/huggingface.py (+59 -5)
```diff
@@ -1,19 +1,28 @@
 """Hugging Face Chat Wrapper."""
 
-from typing import Any, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Optional
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    LLMResult,
+)
 from langchain_core.pydantic_v1 import root_validator
 
 from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
```
```diff
@@ -26,7 +35,8 @@
 
 
 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLMs as ChatModels.
+    """
+    Wrapper for using Hugging Face LLM's as ChatModels.
 
     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     and `HuggingFaceHub` LLMs.
```
```diff
@@ -44,6 +54,7 @@ class ChatHuggingFace(BaseChatModel):
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
     model_id: Optional[str] = None
+    streaming: bool = False
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
```
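
The new `streaming` field is what `_generate` and `_agenerate` check in the hunk below. A hedged sketch of its effect, reusing `llm` from the earlier sketch: with `streaming=True`, even a blocking `invoke` call is served through `_stream`, so callback handlers receive tokens as they arrive.

```python
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage

# streaming=True makes _generate delegate to _stream via generate_from_stream,
# so the handler below prints each token as it is produced.
chat = ChatHuggingFace(llm=llm, streaming=True)  # `llm` as in the earlier sketch
result = chat.invoke(
    [HumanMessage(content="Write one line about streaming.")],
    config={"callbacks": [StreamingStdOutCallbackHandler()]},
)
```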
```diff
@@ -70,13 +81,50 @@ def validate_llm(cls, values: dict) -> dict:
             )
         return values
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+
+        for data in self.llm.stream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        request = self._to_chat_prompt(messages)
+        async for data in self.llm.astream(request, **kwargs):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
```
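
`generate_from_stream` is imported from `langchain_core` in the first hunk; conceptually it folds the chunk iterator back into a single `ChatResult`. A simplified sketch of that idea (not the actual langchain_core implementation):

```python
from typing import Iterator, Optional

from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


def collapse_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    """Merge streamed chunks into one result (illustrative only)."""
    merged: Optional[ChatGenerationChunk] = None
    for chunk in stream:
        # ChatGenerationChunk supports "+", concatenating message content.
        merged = chunk if merged is None else merged + chunk
    if merged is None:
        raise ValueError("stream produced no chunks")
    return ChatResult(
        generations=[
            ChatGeneration(
                message=merged.message,
                generation_info=merged.generation_info,
            )
        ]
    )
```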
```diff
@@ -90,6 +138,12 @@ async def _agenerate(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
         llm_input = self._to_chat_prompt(messages)
         llm_result = await self.llm._agenerate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
```
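
And the async counterpart, again as a hedged sketch reusing `chat` from the earlier sketches: `.astream()` drives the new `_astream`, while `await chat.ainvoke(...)` with `streaming=True` goes through `agenerate_from_stream`.

```python
import asyncio

from langchain_core.messages import HumanMessage


async def main() -> None:
    # Async token-by-token consumption via the new _astream method.
    async for chunk in chat.astream([HumanMessage(content="Say hi.")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```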
