[langchain_community.llms.xinference]: fix error and support stream method #29192

Closed
wants to merge 1 commit into from
80 changes: 78 additions & 2 deletions libs/community/langchain_community/llms/xinference.py
@@ -1,9 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    TYPE_CHECKING,
    Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

if TYPE_CHECKING:

Check failure on line 17 in libs/community/langchain_community/llms/xinference.py (GitHub Actions / cd libs/community / make lint #3.9 and #3.13), Ruff (I001): langchain_community/llms/xinference.py:1:1: I001 Import block is un-sorted or un-formatted
    from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
    from xinference.model.llm.core import LlamaCppGenerateConfig

@@ -81,7 +92,7 @@

""" # noqa: E501

client: Any
client: Optional[Any] = None
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
@@ -214,3 +225,68 @@
                        token=token, verbose=self.verbose, log_probs=log_probs
                    )
                yield token

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks from the Xinference model."""
        generate_config = kwargs.get("generate_config", {})
        generate_config = {**self.model_kwargs, **generate_config}
        if stop:
            generate_config["stop"] = stop
        for stream_resp in self._create_generate_stream(prompt, generate_config):
            if stream_resp:
                chunk = self._stream_response_to_generation_chunk(stream_resp)
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )
                yield chunk

    def _create_generate_stream(
        self,
        prompt: str,
        generate_config: Optional[Dict[str, Any]] = None,
    ) -> Iterator[Any]:
        """Fetch the model handle and yield its raw streaming responses."""
        model = self.client.get_model(self.model_uid)
        yield from self.create_stream(
            model,
            prompt,
            generate_config,
        )

    @staticmethod
    def _stream_response_to_generation_chunk(
        stream_response: Any,
    ) -> GenerationChunk:
        """Convert a streaming response to a generation chunk."""
        token = ""
        if not isinstance(stream_response, dict):
            return GenerationChunk(text=token)
        choices = stream_response.get("choices", [])
        if not choices or not isinstance(choices[0], dict):
            return GenerationChunk(text=token)
        choice = choices[0]
        token = choice.get("text", "")
        return GenerationChunk(
            text=token,
            generation_info=dict(
                finish_reason=choice.get("finish_reason", None),
                logprobs=choice.get("logprobs", None),
            ),
        )

    @staticmethod
    def create_stream(
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
        prompt: str,
        generate_config: Optional[Dict[str, Any]] = None,
    ) -> Iterator[Any]:
        """Call the model's generate endpoint and return its streaming response."""
        return model.generate(prompt=prompt, generate_config=generate_config)
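
For context, a minimal usage sketch of the streaming path this change enables, assuming a locally running Xinference server. The server URL, model UID, prompt, and the generate_config={"stream": True} setting below are placeholder assumptions for illustration, not part of the PR:

from langchain_community.llms import Xinference

# Placeholder endpoint and model UID; point these at your own Xinference deployment.
llm = Xinference(
    server_url="http://127.0.0.1:9997",
    model_uid="my-llama-model-uid",
)

# With the new _stream method, tokens arrive incrementally through the standard
# LangChain stream() interface; extra kwargs such as generate_config are passed
# through to _stream and on to the model's generate call.
for token in llm.stream(
    "Q: What is the capital of France? A:",
    generate_config={"stream": True},
):
    print(token, end="", flush=True)

Whether "stream": True must be set explicitly depends on how the Xinference client handles the generate call; the sketch sets it defensively.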