Skip to content

Commit

Permalink
[InferenceClient] Add support for adapter_id (text-generation) and …
Browse files Browse the repository at this point in the history
…`response_format` (chat-completion) (huggingface#2383)

* types

* Add adapter_id arg to text_generation

* Add adapter_id to text-generation and response_format to chat_completion

* update example

* add test

* fix quality

* remove dummy

* lint

* b

* lint
  • Loading branch information
Wauplin authored Jul 16, 2024
1 parent 6ddaf44 commit 36396f1
Show file tree
Hide file tree
Showing 12 changed files with 326 additions and 49 deletions.
10 changes: 8 additions & 2 deletions docs/source/en/package_reference/inference_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,20 @@ This part of the lib is still under development and will be improved in future r

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionDefinition

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType

[[autodoc]] huggingface_hub.ChatCompletionInputMessage

[[autodoc]] huggingface_hub.ChatCompletionInputTool
[[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk

[[autodoc]] huggingface_hub.ChatCompletionInputToolCall
[[autodoc]] huggingface_hub.ChatCompletionInputTool

[[autodoc]] huggingface_hub.ChatCompletionInputToolTypeClass

[[autodoc]] huggingface_hub.ChatCompletionInputURL

[[autodoc]] huggingface_hub.ChatCompletionOutput

[[autodoc]] huggingface_hub.ChatCompletionOutputComplete
Expand Down
10 changes: 8 additions & 2 deletions docs/source/ko/package_reference/inference_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ rendered properly in your Markdown viewer.

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionDefinition

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType

[[autodoc]] huggingface_hub.ChatCompletionInputMessage

[[autodoc]] huggingface_hub.ChatCompletionInputTool
[[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk

[[autodoc]] huggingface_hub.ChatCompletionInputToolCall
[[autodoc]] huggingface_hub.ChatCompletionInputTool

[[autodoc]] huggingface_hub.ChatCompletionInputToolTypeClass

[[autodoc]] huggingface_hub.ChatCompletionInputURL

[[autodoc]] huggingface_hub.ChatCompletionOutput

[[autodoc]] huggingface_hub.ChatCompletionOutputComplete
Expand Down
10 changes: 8 additions & 2 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,13 @@
"AutomaticSpeechRecognitionParameters",
"ChatCompletionInput",
"ChatCompletionInputFunctionDefinition",
"ChatCompletionInputFunctionName",
"ChatCompletionInputGrammarType",
"ChatCompletionInputMessage",
"ChatCompletionInputMessageChunk",
"ChatCompletionInputTool",
"ChatCompletionInputToolCall",
"ChatCompletionInputToolTypeClass",
"ChatCompletionInputURL",
"ChatCompletionOutput",
"ChatCompletionOutputComplete",
"ChatCompletionOutputFunctionDefinition",
Expand Down Expand Up @@ -775,10 +778,13 @@ def __dir__():
AutomaticSpeechRecognitionParameters, # noqa: F401
ChatCompletionInput, # noqa: F401
ChatCompletionInputFunctionDefinition, # noqa: F401
ChatCompletionInputFunctionName, # noqa: F401
ChatCompletionInputGrammarType, # noqa: F401
ChatCompletionInputMessage, # noqa: F401
ChatCompletionInputMessageChunk, # noqa: F401
ChatCompletionInputTool, # noqa: F401
ChatCompletionInputToolCall, # noqa: F401
ChatCompletionInputToolTypeClass, # noqa: F401
ChatCompletionInputURL, # noqa: F401
ChatCompletionOutput, # noqa: F401
ChatCompletionOutputComplete, # noqa: F401
ChatCompletionOutputFunctionDefinition, # noqa: F401
Expand Down
73 changes: 65 additions & 8 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
AudioClassificationOutputElement,
AudioToAudioOutputElement,
AutomaticSpeechRecognitionOutput,
ChatCompletionInputGrammarType,
ChatCompletionInputTool,
ChatCompletionInputToolTypeClass,
ChatCompletionOutput,
Expand All @@ -103,7 +104,6 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputToolTypeEnum
from huggingface_hub.inference._types import (
ConversationalOutput, # soon to be removed
)
Expand Down Expand Up @@ -465,10 +465,11 @@ def chat_completion( # type: ignore
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
Expand All @@ -488,10 +489,11 @@ def chat_completion( # type: ignore
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
Expand All @@ -511,10 +513,11 @@ def chat_completion(
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
Expand All @@ -534,10 +537,11 @@ def chat_completion(
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, ChatCompletionInputToolTypeEnum]] = None,
tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
Expand Down Expand Up @@ -584,6 +588,8 @@ def chat_completion(
presence_penalty (`float`, *optional*):
Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
text so far, increasing the model's likelihood to talk about new topics.
response_format ([`ChatCompletionInputGrammarType`], *optional*):
Grammar constraints. Can be either a JSONSchema or a regex.
seed (Optional[`int`], *optional*):
Seed for reproducible control flow. Defaults to None.
stop (Optional[`str`], *optional*):
Expand All @@ -601,7 +607,7 @@ def chat_completion(
top_p (`float`, *optional*):
Fraction of the most likely next words to sample from.
Must be between 0 and 1. Defaults to 1.0.
tool_choice ([`ChatCompletionInputToolTypeClass`] or [`ChatCompletionInputToolTypeEnum`], *optional*):
tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*):
The tool to use for the completion. Defaults to "auto".
tool_prompt (`str`, *optional*):
A prompt to be appended before the tools.
Expand All @@ -624,7 +630,6 @@ def chat_completion(
Example:
```py
# Chat example
>>> from huggingface_hub import InferenceClient
>>> messages = [{"role": "user", "content": "What is the capital of France?"}]
>>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
Expand Down Expand Up @@ -654,7 +659,13 @@ def chat_completion(
total_tokens=25
)
)
```
Example (stream=True):
```py
>>> from huggingface_hub import InferenceClient
>>> messages = [{"role": "user", "content": "What is the capital of France?"}]
>>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
>>> for token in client.chat_completion(messages, max_tokens=10, stream=True):
... print(token)
ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504)
Expand Down Expand Up @@ -770,6 +781,37 @@ def chat_completion(
description=None
)
```
Example using response_format:
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
>>> messages = [
... {
... "role": "user",
... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
... },
... ]
>>> response_format = {
... "type": "json",
... "value": {
... "properties": {
... "location": {"type": "string"},
... "activity": {"type": "string"},
... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5},
... "animals": {"type": "array", "items": {"type": "string"}},
... },
... "required": ["location", "activity", "animals_seen", "animals"],
... },
... }
>>> response = client.chat_completion(
... messages=messages,
... response_format=response_format,
... max_tokens=500,
)
>>> response.choices[0].message.content
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
```
"""
# Determine model
# `self.xxx` takes precedence over the method argument only in `chat_completion`
Expand Down Expand Up @@ -804,6 +846,7 @@ def chat_completion(
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
Expand Down Expand Up @@ -855,6 +898,11 @@ def chat_completion(
"Tools are not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided tool parameters will be ignored."
)
if response_format is not None:
warnings.warn(
"Response format is not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided response format will be ignored."
)

# generate response
text_generation_output = self.text_generation(
Expand All @@ -873,7 +921,6 @@ def chat_completion(
return ChatCompletionOutput(
id="dummy",
model="dummy",
object="dummy",
system_fingerprint="dummy",
usage=None, # type: ignore # set to `None` as we don't want to provide false information
created=int(time.time()),
Expand Down Expand Up @@ -1742,6 +1789,7 @@ def text_generation( # type: ignore
stream: Literal[False] = ...,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1770,6 +1818,7 @@ def text_generation( # type: ignore
stream: Literal[False] = ...,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1798,6 +1847,7 @@ def text_generation( # type: ignore
stream: Literal[True] = ...,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1826,6 +1876,7 @@ def text_generation( # type: ignore
stream: Literal[True] = ...,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1854,6 +1905,7 @@ def text_generation(
stream: bool = ...,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1881,6 +1933,7 @@ def text_generation(
stream: bool = False,
model: Optional[str] = None,
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
adapter_id: Optional[str] = None,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
do_sample: Optional[bool] = False, # Manual default value
Expand Down Expand Up @@ -1932,6 +1985,8 @@ def text_generation(
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
adapter_id (`str`, *optional*):
Lora adapter id.
best_of (`int`, *optional*):
Generate best_of sequences and return the one if the highest token logprobs.
decoder_input_details (`bool`, *optional*):
Expand Down Expand Up @@ -2100,6 +2155,7 @@ def text_generation(

# Build payload
parameters = {
"adapter_id": adapter_id,
"best_of": best_of,
"decoder_input_details": decoder_input_details,
"details": details,
Expand Down Expand Up @@ -2170,6 +2226,7 @@ def text_generation(
details=details,
stream=stream,
model=model,
adapter_id=adapter_id,
best_of=best_of,
decoder_input_details=decoder_input_details,
do_sample=do_sample,
Expand Down
2 changes: 0 additions & 2 deletions src/huggingface_hub/inference/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ def _format_chat_completion_stream_output_from_text_generation(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
object="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
Expand All @@ -335,7 +334,6 @@ def _format_chat_completion_stream_output_from_text_generation(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
object="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
Expand Down
Loading

0 comments on commit 36396f1

Please sign in to comment.