From dec6cd0a346f47ae82d05987c051664f9627e7aa Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 10:09:42 +0200 Subject: [PATCH 1/6] add HuggingFaceAPITextEmbedder --- docs/pydoc/config/embedders_api.yml | 1 + haystack/components/embedders/__init__.py | 2 + .../hugging_face_api_text_embedder.py | 197 ++++++++++++++++++ .../hugging_face_tei_text_embedder.py | 7 + haystack/utils/hf.py | 22 ++ haystack/utils/url_validation.py | 6 + .../hfapitextembedder-97bf5f739f413f3e.yaml | 13 ++ .../test_hugging_face_api_text_embedder.py | 174 ++++++++++++++++ 8 files changed, 422 insertions(+) create mode 100644 haystack/components/embedders/hugging_face_api_text_embedder.py create mode 100644 haystack/utils/url_validation.py create mode 100644 releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml create mode 100644 test/components/embedders/test_hugging_face_api_text_embedder.py diff --git a/docs/pydoc/config/embedders_api.yml b/docs/pydoc/config/embedders_api.yml index c5b5b8a906..326d98c881 100644 --- a/docs/pydoc/config/embedders_api.yml +++ b/docs/pydoc/config/embedders_api.yml @@ -7,6 +7,7 @@ loaders: "azure_text_embedder", "hugging_face_tei_document_embedder", "hugging_face_tei_text_embedder", + "hugging_face_api_text_embedder", "openai_document_embedder", "openai_text_embedder", "sentence_transformers_document_embedder", diff --git a/haystack/components/embedders/__init__.py b/haystack/components/embedders/__init__.py index 6ff4e339a7..a2e3d15a4f 100644 --- a/haystack/components/embedders/__init__.py +++ b/haystack/components/embedders/__init__.py @@ -1,5 +1,6 @@ from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder +from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder @@ -10,6 +11,7 @@ __all__ = [ "HuggingFaceTEITextEmbedder", "HuggingFaceTEIDocumentEmbedder", + "HuggingFaceAPITextEmbedder", "SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder", "OpenAITextEmbedder", diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py new file mode 100644 index 0000000000..d95ef7c6d2 --- /dev/null +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -0,0 +1,197 @@ +import json +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model +from haystack.utils.url_validation import is_valid_http_url + +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import: + from huggingface_hub import InferenceClient + +logger = logging.getLogger(__name__) + + +@component +class HuggingFaceAPITextEmbedder: + """ + This component can be used to embedding strings using different Hugging Face APIs: + - [free Serverless Inference API]((https://huggingface.co/inference-api) + - [paid Inference Endpoints](https://huggingface.co/inference-endpoints) + - [self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) + + + Example usage with the free Serverless Inference API: + ```python + from haystack.components.embedders import HuggingFaceAPITextEmbedder + from haystack.utils import Secret + + text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api", + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("")) + + print(text_embedder.run("I love pizza!")) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + ``` + + Example usage with paid Inference Endpoints: + ```python + from haystack.components.embedders import HuggingFaceAPITextEmbedder + from haystack.utils import Secret + text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api", + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("")) + + print(text_embedder.run("I love pizza!")) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + ``` + + Example usage with self-hosted Text Embeddings Inference: + ```python + from haystack.components.embedders import HuggingFaceAPITextEmbedder + from haystack.utils import Secret + + text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference", + api_params={"url": "http://localhost:8080"}) + + print(text_embedder.run("I love pizza!")) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + ``` + """ + + def __init__( + self, + api_type: Union[HFEmbeddingAPIType, str] = HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params: Optional[Dict[str, str]] = None, + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), + prefix: str = "", + suffix: str = "", + truncate: bool = True, + normalize: bool = False, + ): + """ + Create an HuggingFaceAPITextEmbedder component. + + :param api_type: + The type of Hugging Face API to use. + :param api_params: + A dictionary containing the following keys: + - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`. + - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_EMBEDDINGS_INFERENCE`. + :param token: The HuggingFace token to use as HTTP bearer authorization + You can find your HF token in your [account settings](https://huggingface.co/settings/tokens) + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param truncate: + Truncate input text from the end to the maximum length supported by the model. + This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`. + It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference. + This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True` and cannot be changed). + :param normalize: + Normalize the embeddings to unit length. + This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`. + It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference. + This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False` and cannot be changed). + """ + huggingface_hub_import.check() + + if isinstance(api_type, str): + api_type = HFEmbeddingAPIType.from_str(api_type) + + api_params = api_params or {} + + if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + model = api_params.get("model") + if model is None: + raise ValueError( + "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`." + ) + check_valid_model(model, HFModelType.EMBEDDING, token) + model_or_url = model + elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]: + url = api_params.get("url") + if url is None: + raise ValueError( + "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`." + ) + if not is_valid_http_url(url): + raise ValueError(f"Invalid URL: {url}") + model_or_url = url + else: + raise ValueError( + f"Unsupported API type: {api_type}. Supported types are: {[e.value for e in HFEmbeddingAPIType]}" + ) + + self.api_type = api_type + self.api_params = api_params + self.token = token + self.prefix = prefix + self.suffix = suffix + self.truncate = truncate + self.normalize = normalize + self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + api_type=self.api_type, + api_params=self.api_params, + prefix=self.prefix, + suffix=self.suffix, + token=self.token.to_dict() if self.token else None, + truncate=self.truncate, + normalize=self.normalize, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPITextEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + return default_from_dict(cls, data) + + @component.output_types(embedding=List[float]) + def run(self, text: str): + """ + Embed a single string. + + :param text: + Text to embed. + + :returns: + A dictionary with the following keys: + - `embedding`: The embedding of the input text. + """ + if not isinstance(text, str): + raise TypeError( + "HuggingFaceAPITextEmbedder expects a string as an input." + "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder." + ) + + text_to_embed = self.prefix + text + self.suffix + + response = self._client.post( + json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize}, + task="feature-extraction", + ) + embedding = json.loads(response.decode())[0] + + return {"embedding": embedding} diff --git a/haystack/components/embedders/hugging_face_tei_text_embedder.py b/haystack/components/embedders/hugging_face_tei_text_embedder.py index f618214e30..5956ee1d83 100644 --- a/haystack/components/embedders/hugging_face_tei_text_embedder.py +++ b/haystack/components/embedders/hugging_face_tei_text_embedder.py @@ -1,4 +1,5 @@ import json +import warnings from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -74,6 +75,12 @@ def __init__( Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI. """ + warnings.warn( + "`HuggingFaceTEITextEmbedder` is deprecated and will be removed in Haystack 2.3.0." + "Use `HuggingFaceAPITextEmbedder` instead.", + DeprecationWarning, + ) + huggingface_hub_import.check() if url: diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 14f2bfacbf..deeca89360 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -21,6 +21,28 @@ logger = logging.getLogger(__name__) +class HFEmbeddingAPIType(Enum): + """ + API type to use for Hugging Face API Embedders. + """ + + TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference" + INFERENCE_ENDPOINTS = "inference_endpoints" + SERVERLESS_INFERENCE_API = "serverless_inference_api" + + def __str__(self): + return self.value + + @staticmethod + def from_str(string: str) -> "HFEmbeddingAPIType": + enum_map = {e.value: e for e in HFEmbeddingAPIType} + mode = enum_map.get(string) + if mode is None: + msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {list(enum_map.keys())}" + raise ValueError(msg) + return mode + + class HFModelType(Enum): EMBEDDING = 1 GENERATION = 2 diff --git a/haystack/utils/url_validation.py b/haystack/utils/url_validation.py new file mode 100644 index 0000000000..6d4c9e69d5 --- /dev/null +++ b/haystack/utils/url_validation.py @@ -0,0 +1,6 @@ +from urllib.parse import urlparse + + +def is_valid_http_url(url) -> bool: + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) diff --git a/releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml b/releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml new file mode 100644 index 0000000000..1f8ac29dee --- /dev/null +++ b/releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Introduce `HuggingFaceAPITextEmbedder`. + This component can be used to embed strings using different Hugging Face APIs: + - free Serverless Inference API + - paid Inference Endpoints + - self-hosted Text Embeddings Inference. + This embedder will replace the `HuggingFaceTEITextEmbedder` in the future. +deprecations: + - | + Deprecate `HuggingFaceTEITextEmbedder`. This component will be removed in Haystack 2.3.0. + Use `HuggingFaceAPITextEmbedder` instead. diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py new file mode 100644 index 0000000000..0acb5a12fa --- /dev/null +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -0,0 +1,174 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput +from huggingface_hub.utils import RepositoryNotFoundError +from numpy import array, random + +from haystack.components.embedders import HuggingFaceAPITextEmbedder +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFEmbeddingAPIType + + +@pytest.fixture +def mock_check_valid_model(): + with patch( + "haystack.components.embedders.hugging_face_api_text_embedder.check_valid_model", MagicMock(return_value=None) + ) as mock: + yield mock + + +def mock_embedding_generation(json, **kwargs): + response = str(array([random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode() + return response + + +class TestHuggingFaceAPITextEmbedder: + def test_init_invalid_api_type(self): + with pytest.raises(ValueError): + HuggingFaceAPITextEmbedder(api_type="invalid_api_type") + + def test_init_serverless(self, mock_check_valid_model): + model = "BAAI/bge-small-en-v1.5" + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model} + ) + + assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API + assert embedder.api_params == {"model": model} + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate + assert not embedder.normalize + + def test_init_serverless_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} + ) + + def test_init_serverless_no_model(self): + with pytest.raises(ValueError): + HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} + ) + + def test_init_tei(self): + url = "https://some_model.com" + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url} + ) + + assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE + assert embedder.api_params == {"url": url} + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate + assert not embedder.normalize + + def test_init_tei_invalid_url(self): + with pytest.raises(ValueError): + HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"} + ) + + def test_init_tei_no_url(self): + with pytest.raises(ValueError): + HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"} + ) + + def test_to_dict(self, mock_check_valid_model): + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + prefix="prefix", + suffix="suffix", + truncate=False, + normalize=True, + ) + + data = embedder.to_dict() + + assert data == { + "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder", + "init_parameters": { + "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_params": {"model": "BAAI/bge-small-en-v1.5"}, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + }, + } + + def test_from_dict(self, mock_check_valid_model): + data = { + "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder", + "init_parameters": { + "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_params": {"model": "BAAI/bge-small-en-v1.5"}, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + }, + } + + embedder = HuggingFaceAPITextEmbedder.from_dict(data) + + assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API + assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"} + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert not embedder.truncate + assert embedder.normalize + + def test_run_wrong_input_format(self, mock_check_valid_model): + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError): + embedder.run(text=list_integers_input) + + def test_run(self, mock_check_valid_model): + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + mock_embedding_patch.side_effect = mock_embedding_generation + + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + prefix="prefix ", + suffix=" suffix", + ) + + result = embedder.run(text="The food was delicious") + + mock_embedding_patch.assert_called_once_with( + json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False}, + task="feature-extraction", + ) + + assert len(result["embedding"]) == 384 + assert all(isinstance(x, float) for x in result["embedding"]) + + @pytest.mark.flaky(reruns=5, reruns_delay=5) + @pytest.mark.integration + def test_live_run_serverless(self): + embedder = HuggingFaceAPITextEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"}, + ) + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 384 + assert all(isinstance(x, float) for x in result["embedding"]) From 045fe1cce038ca3fb79a55ae757f5fc1d6566440 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 10:10:21 +0200 Subject: [PATCH 2/6] add HuggingFaceAPITextEmbedder --- haystack/components/embedders/hugging_face_api_text_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index d95ef7c6d2..295dd29cac 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -16,7 +16,7 @@ @component class HuggingFaceAPITextEmbedder: """ - This component can be used to embedding strings using different Hugging Face APIs: + This component can be used to embed strings using different Hugging Face APIs: - [free Serverless Inference API]((https://huggingface.co/inference-api) - [paid Inference Endpoints](https://huggingface.co/inference-endpoints) - [self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) From fb21201a017fb1390b26d93224ad541e248f93e7 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 10:51:11 +0200 Subject: [PATCH 3/6] rm unneeded else --- .../components/embedders/hugging_face_api_text_embedder.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index 295dd29cac..0eb9b29feb 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -123,10 +123,6 @@ def __init__( if not is_valid_http_url(url): raise ValueError(f"Invalid URL: {url}") model_or_url = url - else: - raise ValueError( - f"Unsupported API type: {api_type}. Supported types are: {[e.value for e in HFEmbeddingAPIType]}" - ) self.api_type = api_type self.api_params = api_params From 14cedf4059440f8475534ca0cfd24ec1f5c91839 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 12:14:19 +0200 Subject: [PATCH 4/6] small fixes --- .../components/embedders/hugging_face_api_text_embedder.py | 2 +- .../embedders/test_hugging_face_api_text_embedder.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index 0eb9b29feb..e4dcf895f3 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -40,7 +40,7 @@ class HuggingFaceAPITextEmbedder: ```python from haystack.components.embedders import HuggingFaceAPITextEmbedder from haystack.utils import Secret - text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api", + text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints", api_params={"model": "BAAI/bge-small-en-v1.5"}, token=Secret.from_token("")) diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 0acb5a12fa..33468b736d 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -1,12 +1,10 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest -from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput from huggingface_hub.utils import RepositoryNotFoundError from numpy import array, random from haystack.components.embedders import HuggingFaceAPITextEmbedder -from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret from haystack.utils.hf import HFEmbeddingAPIType From e467febdd0d9f1e340b68f466c5b20af0fa834f2 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 16:52:10 +0200 Subject: [PATCH 5/6] changes requested --- .../embedders/hugging_face_api_text_embedder.py | 12 +++++------- haystack/utils/hf.py | 5 +++++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index e4dcf895f3..de7c3097b2 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -17,9 +17,9 @@ class HuggingFaceAPITextEmbedder: """ This component can be used to embed strings using different Hugging Face APIs: - - [free Serverless Inference API]((https://huggingface.co/inference-api) - - [paid Inference Endpoints](https://huggingface.co/inference-endpoints) - - [self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) + - [Free Serverless Inference API]((https://huggingface.co/inference-api) + - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints) + - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) Example usage with the free Serverless Inference API: @@ -65,8 +65,8 @@ class HuggingFaceAPITextEmbedder: def __init__( self, - api_type: Union[HFEmbeddingAPIType, str] = HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, - api_params: Optional[Dict[str, str]] = None, + api_type: Union[HFEmbeddingAPIType, str], + api_params: Dict[str, str], token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), prefix: str = "", suffix: str = "", @@ -104,8 +104,6 @@ def __init__( if isinstance(api_type, str): api_type = HFEmbeddingAPIType.from_str(api_type) - api_params = api_params or {} - if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: model = api_params.get("model") if model is None: diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index deeca89360..486d457551 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -26,8 +26,13 @@ class HFEmbeddingAPIType(Enum): API type to use for Hugging Face API Embedders. """ + # HF [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference). TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference" + + # HF [Inference Endpoints](https://huggingface.co/inference-endpoints). INFERENCE_ENDPOINTS = "inference_endpoints" + + # HF [Serverless Inference API](https://huggingface.co/inference-api). SERVERLESS_INFERENCE_API = "serverless_inference_api" def __str__(self): From 7aa15720767bf37dd5aea72f357805d4817a92d2 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 5 Apr 2024 17:56:02 +0200 Subject: [PATCH 6/6] fix test --- .../components/embedders/test_hugging_face_api_text_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 33468b736d..c5a14db481 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -25,7 +25,7 @@ def mock_embedding_generation(json, **kwargs): class TestHuggingFaceAPITextEmbedder: def test_init_invalid_api_type(self): with pytest.raises(ValueError): - HuggingFaceAPITextEmbedder(api_type="invalid_api_type") + HuggingFaceAPITextEmbedder(api_type="invalid_api_type", api_params={}) def test_init_serverless(self, mock_check_valid_model): model = "BAAI/bge-small-en-v1.5"