Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for tools in HuggingFaceAPIChatGenerator #8661

Merged
merged 6 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only a suggestion for users.
The ChatCompletionInputTool type was introduced recently, so it is better to install the latest version.

from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)
Expand Down
152 changes: 107 additions & 45 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,25 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import (
ChatCompletionInputTool,
ChatCompletionOutput,
ChatCompletionStreamOutput,
InferenceClient,
)


logger = logging.getLogger(__name__)


def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
"""
Convert a message to the format expected by Hugging Face APIs.

:returns: A dictionary with the following keys:
- `role`
- `content`
"""
return {"role": message.role.value, "content": message.text or ""}


@component
class HuggingFaceAPIChatGenerator:
"""
Expand Down Expand Up @@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.
Expand All @@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments
- `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
`TEXT_GENERATION_INFERENCE`.
:param token: The Hugging Face token to use as HTTP bearer authorization.
:param token:
The Hugging Face token to use as HTTP bearer authorization.
Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
:param generation_kwargs:
A dictionary with keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_p`.
For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
:param stop_words:
An optional list of strings representing the stop words.
:param streaming_callback:
An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls.
The chosen model should support tool/function calling, according to the model card.
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
unexpected behavior.
"""

huggingface_hub_import.check()
Expand Down Expand Up @@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
msg = f"Unknown api_type {api_type}"
raise ValueError(msg)

if tools:
if streaming_callback is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do this across the board for all CG - just something to write down and not forget

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Most of Generators support Tools + Streaming, I think.

Theoretically also HF could support this but it is highly undocumented and does not behave consistently.
See deepset-ai/haystack-experimental#120 (comment)

raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
generation_kwargs["stop"] = generation_kwargs.get("stop", [])
Expand All @@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
self.tools = tools

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
A dictionary containing the serialized component.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
return default_to_dict(
self,
api_type=str(self.api_type),
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
tools=serialized_tools,
)

@classmethod
Expand All @@ -195,32 +207,53 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.

:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param messages:
A list of ChatMessage objects representing the input messages.
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""

# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
formatted_messages = [convert_message_to_hf_format(message) for message in messages]

tools = tools or self.tools
if tools:
if self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)

return self._run_non_streaming(formatted_messages, generation_kwargs)
hf_tools = None
if tools:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
Expand All @@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict

generated_text = ""

for chunk in api_output: # pylint: disable=not-an-iterable
text = chunk.choices[0].delta.content
for chunk in api_output:
# n is unused, so the API always returns only one choice
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = chunk.choices[0]

text = choice.delta.content
if text:
generated_text += text
finish_reason = chunk.choices[0].finish_reason

finish_reason = choice.finish_reason

meta = {}
if finish_reason:
Expand All @@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)

message = ChatMessage.from_assistant(generated_text)
message.meta.update(
meta.update(
{
"model": self._client.model,
"finish_reason": finish_reason,
"index": 0,
"usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming
}
)

message = ChatMessage.from_assistant(text=generated_text, meta=meta)

return {"replies": [message]}

def _run_non_streaming(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
tools: Optional[List["ChatCompletionInputTool"]] = None,
) -> Dict[str, List[ChatMessage]]:
chat_messages: List[ChatMessage] = []

api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
for choice in api_chat_output.choices:
message = ChatMessage.from_assistant(choice.message.content)
message.meta.update(
{
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
}
)
chat_messages.append(message)

return {"replies": chat_messages}
api_chat_output: ChatCompletionOutput = self._client.chat_completion(
messages=messages, tools=tools, **generation_kwargs
)

if len(api_chat_output.choices) == 0:
return {"replies": []}

# n is unused, so the API always returns only one choice
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = api_chat_output.choices[0]

text = choice.message.content
tool_calls = []

if hfapi_tool_calls := choice.message.tool_calls:
for hfapi_tc in hfapi_tool_calls:
tool_call = ToolCall(
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
)
tool_calls.append(tool_call)

meta = {
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": {
"prompt_tokens": api_chat_output.usage.prompt_tokens,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, I've been burned b4 on these usage structures assuming they are always there. Please check HF pydantic classes to make sure they are not optional.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now taken a more cautious approach. Thank you!

"completion_tokens": api_chat_output.usage.completion_tokens,
},
}

message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
return {"replies": [message]}
2 changes: 1 addition & 1 deletion haystack/components/generators/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import (
InferenceClient,
TextGenerationOutput,
Expand Down
15 changes: 14 additions & 1 deletion haystack/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

from pydantic import create_model

Expand Down Expand Up @@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
del property_schema[key]


def _check_duplicate_tool_names(tools: List[Tool]) -> None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a common check we do, so I moved it to the Tool module

"""
Check for duplicate tool names.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and raises ValueError if they are found.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!


:param tools: The list of tools to check.
:raises ValueError: If duplicate tool names are found.
"""
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")


def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
"""
Deserialize Tools in a dictionary inplace.
Expand Down
42 changes: 40 additions & 2 deletions haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from typing import Any, Callable, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import:
import torch

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import HfApi, InferenceClient, model_info
from huggingface_hub.utils import RepositoryNotFoundError

Expand Down Expand Up @@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
)


def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use this function in the HF Local Chat Generator too (I recently discovered that it does not work properly without the conversion), so I moved it to utils.

(I will open issue + PR to update the HF Local Chat Generator.)

"""
Convert a message to the format expected by Hugging Face.
"""
text_contents = message.texts
tool_calls = message.tool_calls
tool_call_results = message.tool_call_results

if not text_contents and not tool_calls and not tool_call_results:
raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
if len(text_contents) + len(tool_call_results) > 1:
raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

# HF always expects a content field, even if it is empty
hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""}

if tool_call_results:
result = tool_call_results[0]
hf_msg["content"] = result.result
if tc_id := result.origin.id:
hf_msg["tool_call_id"] = tc_id
# HF does not provide a way to communicate errors in tool invocations, so we ignore the error field
return hf_msg

if text_contents:
hf_msg["content"] = text_contents[0]
if tool_calls:
hf_tool_calls = []
for tc in tool_calls:
hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
if tc.id is not None:
hf_tool_call["id"] = tc.id
hf_tool_calls.append(hf_tool_call)
hf_msg["tool_calls"] = hf_tool_calls

return hf_msg


with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ extra-dependencies = [
"numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x

"transformers[torch,sentencepiece]==4.44.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
"huggingface_hub>=0.23.0", # Hugging Face API Generators and Embedders
"huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders
"sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/hfapi-tools-a7224150bce52564.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add support for Tools in the Hugging Face API Chat Generator.
Loading
Loading