feat: support for tools in HuggingFaceAPIChatGenerator (#8661)
File: the `HuggingFaceAPIChatGenerator` component

```diff
@@ -5,30 +5,25 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
-from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
+from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
 from haystack.utils.url_validation import is_valid_http_url

-with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
-    from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
+with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
+    from huggingface_hub import (
+        ChatCompletionInputTool,
+        ChatCompletionOutput,
+        ChatCompletionStreamOutput,
+        InferenceClient,
+    )

 logger = logging.getLogger(__name__)


-def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
-    """
-    Convert a message to the format expected by Hugging Face APIs.
-
-    :returns: A dictionary with the following keys:
-      - `role`
-      - `content`
-    """
-    return {"role": message.role.value, "content": message.text or ""}
-
-
 @component
 class HuggingFaceAPIChatGenerator:
     """
```
```diff
@@ -107,6 +102,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         """
         Initialize the HuggingFaceAPIChatGenerator instance.
```
```diff
@@ -121,14 +117,22 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
             - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
               `TEXT_GENERATION_INFERENCE`.
-        :param token: The Hugging Face token to use as HTTP bearer authorization.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
             Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
         :param generation_kwargs:
             A dictionary with keyword arguments to customize text generation.
             Some examples: `max_tokens`, `temperature`, `top_p`.
             For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
-        :param stop_words: An optional list of strings representing the stop words.
-        :param streaming_callback: An optional callable for handling streaming responses.
+        :param stop_words:
+            An optional list of strings representing the stop words.
+        :param streaming_callback:
+            An optional callable for handling streaming responses.
+        :param tools:
+            A list of tools for which the model can prepare calls.
+            The chosen model should support tool/function calling, according to the model card.
+            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
+            unexpected behavior.
         """

         huggingface_hub_import.check()
```
```diff
@@ -159,6 +163,11 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             msg = f"Unknown api_type {api_type}"
             raise ValueError(msg)

+        if tools:
+            if streaming_callback is not None:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
+
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
         generation_kwargs["stop"] = generation_kwargs.get("stop", [])
```

Review thread (on the `streaming_callback is not None` check):
Reviewer: Should we do this across the board for all CG - just something to write down and not forget.
Author: No. Most Generators support Tools + Streaming, I think. Theoretically, HF could also support this, but it is highly undocumented and does not behave consistently.
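To illustrate the new guard, here is a minimal sketch of what it rules out. This is not code from the PR; the tool definition and endpoint URL are invented for illustration, and the TGI `api_type` is used so that no Hub model check is needed:

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses.tool import Tool

def echo(text: str) -> str:
    # Placeholder tool function, for illustration only.
    return text

echo_tool = Tool(
    name="echo",
    description="Echo the input text.",
    parameters={"type": "object", "properties": {"text": {"type": "string"}}},
    function=echo,
)

try:
    HuggingFaceAPIChatGenerator(
        api_type="text_generation_inference",
        api_params={"url": "http://localhost:8080"},  # illustrative endpoint
        tools=[echo_tool],
        streaming_callback=print,  # tools + streaming: rejected at init time
    )
except ValueError as err:
    print(err)  # "Using tools and streaming at the same time is not supported. Please choose one."
```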
```diff
@@ -171,6 +180,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self.tools = tools

     def to_dict(self) -> Dict[str, Any]:
         """
```
```diff
@@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
             A dictionary containing the serialized component.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             api_type=str(self.api_type),
             api_params=self.api_params,
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             tools=serialized_tools,
         )

     @classmethod
```
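Because `tools` now takes part in `to_dict`/`from_dict`, a serialization round trip should preserve them. A minimal sketch, reusing the illustrative `echo_tool` from above and assuming its function is importable (the exact serialized dict shape depends on `Tool.to_dict`, so only the round trip itself is asserted):

```python
generator = HuggingFaceAPIChatGenerator(
    api_type="text_generation_inference",
    api_params={"url": "http://localhost:8080"},  # illustrative endpoint
    tools=[echo_tool],
)

data = generator.to_dict()                              # tools serialized via Tool.to_dict()
restored = HuggingFaceAPIChatGenerator.from_dict(data)  # rebuilt via deserialize_tools_inplace
assert restored.tools is not None and restored.tools[0].name == "echo"
```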
```diff
@@ -195,32 +207,53 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
         Deserialize this component from a dictionary.
         """
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_tools_inplace(data["init_parameters"], key="tools")
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
             data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)

     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
         """
         Invoke the text generation inference based on the provided messages and generation parameters.

-        :param messages: A list of ChatMessage objects representing the input messages.
-        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """

         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]

+        tools = tools or self.tools
+        if tools:
+            if self.streaming_callback:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
+
         if self.streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs)

-        return self._run_non_streaming(formatted_messages, generation_kwargs)
+        hf_tools = None
+        if tools:
+            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
+
+        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

     def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
```
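For clarity, `Tool.tool_spec` bundles the tool's name, description, and JSON-schema parameters, so the `hf_tools` list built above wraps each spec in the OpenAI-compatible envelope. A sketch with the illustrative `echo_tool` from earlier:

```python
# What run() would pass to chat_completion for the illustrative echo_tool:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in [echo_tool]]
# [{'type': 'function',
#   'function': {'name': 'echo',
#                'description': 'Echo the input text.',
#                'parameters': {'type': 'object',
#                               'properties': {'text': {'type': 'string'}}}}}]
```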
```diff
@@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict

         generated_text = ""

-        for chunk in api_output:  # pylint: disable=not-an-iterable
-            text = chunk.choices[0].delta.content
+        for chunk in api_output:
+            # n is unused, so the API always returns only one choice
+            # the argument is probably allowed for compatibility with OpenAI
+            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+            choice = chunk.choices[0]
+
+            text = choice.delta.content
             if text:
                 generated_text += text
-            finish_reason = chunk.choices[0].finish_reason
+
+            finish_reason = choice.finish_reason

             meta = {}
             if finish_reason:
```
```diff
@@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
                 stream_chunk = StreamingChunk(text, meta)
                 self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)

-        message = ChatMessage.from_assistant(generated_text)
-        message.meta.update(
+        meta.update(
             {
                 "model": self._client.model,
                 "finish_reason": finish_reason,
                 "index": 0,
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
             }
         )

+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+
         return {"replies": [message]}

     def _run_non_streaming(
-        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
     ) -> Dict[str, List[ChatMessage]]:
-        chat_messages: List[ChatMessage] = []
-
-        api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
-        for choice in api_chat_output.choices:
-            message = ChatMessage.from_assistant(choice.message.content)
-            message.meta.update(
-                {
-                    "model": self._client.model,
-                    "finish_reason": choice.finish_reason,
-                    "index": choice.index,
-                    "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
-                }
-            )
-            chat_messages.append(message)
-
-        return {"replies": chat_messages}
+        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        # n is unused, so the API always returns only one choice
+        # the argument is probably allowed for compatibility with OpenAI
+        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta = {
+            "model": self._client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+            "usage": {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            },
+        }
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
```

Review thread (on the `usage` fields):
Reviewer: Ugh, I've been burned before on these usage structures assuming they are always there. Please check the HF pydantic classes to make sure they are not optional.
Author: I have now taken a more cautious approach. Thank you!
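On the reviewer's point: the commit shown here still reads `usage.prompt_tokens` unconditionally. One cautious pattern, sketched only to illustrate the direction the author mentions (the actual follow-up commit may differ), is to fall back to zeros when the API omits usage data:

```python
# Hypothetical defensive variant of the meta construction above.
usage = {"prompt_tokens": 0, "completion_tokens": 0}
if api_chat_output.usage is not None:
    usage = {
        "prompt_tokens": api_chat_output.usage.prompt_tokens,
        "completion_tokens": api_chat_output.usage.completion_tokens,
    }
```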
File: the `haystack.dataclasses.tool` module
```diff
@@ -4,7 +4,7 @@

 import inspect
 from dataclasses import asdict, dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional

 from pydantic import create_model
```
```diff
@@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
                 del property_schema[key]


+def _check_duplicate_tool_names(tools: List[Tool]) -> None:
+    """
+    Check for duplicate tool names.
+
+    :param tools: The list of tools to check.
+    :raises ValueError: If duplicate tool names are found.
+    """
+    tool_names = [tool.name for tool in tools]
+    duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+    if duplicate_tool_names:
+        raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
+
+
 def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
     """
     Deserialize Tools in a dictionary inplace.
```

Author note (on `_check_duplicate_tool_names`): This is a common check we do, so I moved it to the `tool` module.

Review thread (on the docstring):
Reviewer: It should say "...and raises a ValueError if they are found."
Author: Fixed!
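The helper's behavior in one short sketch (the tool definitions are invented for illustration):

```python
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names

def noop() -> None:
    return None

schema = {"type": "object", "properties": {}}
t1 = Tool(name="search", description="first", parameters=schema, function=noop)
t2 = Tool(name="search", description="second", parameters=schema, function=noop)

_check_duplicate_tool_names([t1])  # passes silently
try:
    _check_duplicate_tool_names([t1, t2])
except ValueError as err:
    print(err)  # Duplicate tool names found: {'search'}
```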
File: the `haystack.utils.hf` module
```diff
@@ -8,15 +8,15 @@
 from typing import Any, Callable, Dict, List, Optional, Union

 from haystack import logging
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice

 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import:
     import torch

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError
```
```diff
@@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
     )


+def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
+    """
+    Convert a message to the format expected by Hugging Face.
+    """
+    text_contents = message.texts
+    tool_calls = message.tool_calls
+    tool_call_results = message.tool_call_results
+
+    if not text_contents and not tool_calls and not tool_call_results:
+        raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
+    if len(text_contents) + len(tool_call_results) > 1:
+        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
+
+    # HF always expects a content field, even if it is empty
+    hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""}
+
+    if tool_call_results:
+        result = tool_call_results[0]
+        hf_msg["content"] = result.result
+        if tc_id := result.origin.id:
+            hf_msg["tool_call_id"] = tc_id
+        # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field
+        return hf_msg
+
+    if text_contents:
+        hf_msg["content"] = text_contents[0]
+    if tool_calls:
+        hf_tool_calls = []
+        for tc in tool_calls:
+            hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
+            if tc.id is not None:
+                hf_tool_call["id"] = tc.id
+            hf_tool_calls.append(hf_tool_call)
+        hf_msg["tool_calls"] = hf_tool_calls
+
+    return hf_msg


 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
```

Author note (on `convert_message_to_hf_format`): We should use this function in the HF Local Chat Generator too (I recently discovered that it does not work properly without the conversion), so I moved it to utils. (I will open an issue + PR to update the HF Local Chat Generator.)
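To make the mapping concrete, here is a sketch of the expected output for the three message shapes the function handles. The `ChatMessage` constructors are the standard Haystack ones; the exact dictionaries are inferred from the function body above:

```python
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.utils.hf import convert_message_to_hf_format

# 1. Plain text message
print(convert_message_to_hf_format(ChatMessage.from_user("What is the weather?")))
# {'role': 'user', 'content': 'What is the weather?'}

# 2. Assistant message carrying a tool call
call = ToolCall(tool_name="weather", arguments={"city": "Berlin"}, id="call_1")
print(convert_message_to_hf_format(ChatMessage.from_assistant(tool_calls=[call])))
# {'role': 'assistant', 'content': '',
#  'tool_calls': [{'type': 'function',
#                  'function': {'name': 'weather', 'arguments': {'city': 'Berlin'}},
#                  'id': 'call_1'}]}

# 3. Tool result message (the error field is ignored, as noted in the code)
print(convert_message_to_hf_format(ChatMessage.from_tool(tool_result="Sunny", origin=call)))
# {'role': 'tool', 'content': 'Sunny', 'tool_call_id': 'call_1'}
```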
File: release note

```diff
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for Tools in the Hugging Face API Chat Generator.
```
Author note (on the `huggingface_hub>=0.27.0` requirement): This is only a suggestion for users. The `ChatCompletionInputTool` type was introduced recently, so it is better to install the latest version.
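Putting it all together, a minimal end-to-end sketch of the feature this PR adds. The model choice and tool are illustrative, not from the PR; a tool-calling-capable model and a valid `HF_API_TOKEN` environment variable are assumed:

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.tool import Tool
from haystack.utils import Secret

def get_weather(city: str) -> str:
    return f"Sunny in {city}"  # stub implementation

weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},  # illustrative model
    token=Secret.from_env_var("HF_API_TOKEN"),
    tools=[weather_tool],
)

result = generator.run(messages=[ChatMessage.from_user("What is the weather in Berlin?")])
reply = result["replies"][0]
if reply.tool_calls:  # the model prepared a tool call instead of plain text
    print(reply.tool_calls[0].tool_name, reply.tool_calls[0].arguments)
```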