From 7c9532b20012af63484725d229db47dcc0894c6d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 8 May 2024 17:14:37 +0200 Subject: [PATCH] fix broken serialization of HFAPI components (#7661) --- .../embedders/hugging_face_api_document_embedder.py | 2 +- .../components/embedders/hugging_face_api_text_embedder.py | 2 +- haystack/components/generators/chat/hugging_face_api.py | 2 +- haystack/components/generators/hugging_face_api.py | 2 +- .../notes/fix-hf-api-serialization-026b84de29827c57.yaml | 5 +++++ .../embedders/test_hugging_face_api_document_embedder.py | 2 +- .../embedders/test_hugging_face_api_text_embedder.py | 2 +- test/components/generators/chat/test_hugging_face_api.py | 2 +- test/components/generators/test_hugging_face_api.py | 2 +- 9 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/fix-hf-api-serialization-026b84de29827c57.yaml diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 188d449821..6aa221ec02 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -175,7 +175,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - api_type=self.api_type, + api_type=str(self.api_type), api_params=self.api_params, prefix=self.prefix, suffix=self.suffix, diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index 10db3c0121..53ed7d20e4 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -142,7 +142,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - api_type=self.api_type, + api_type=str(self.api_type), api_params=self.api_params, prefix=self.prefix, suffix=self.suffix, diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 0deea74ab4..8c108dbb81 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -158,7 +158,7 @@ def to_dict(self) -> Dict[str, Any]: callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, - api_type=self.api_type, + api_type=str(self.api_type), api_params=self.api_params, token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index b0dd1aafb4..47ae35ed0e 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -142,7 +142,7 @@ def to_dict(self) -> Dict[str, Any]: callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, - api_type=self.api_type, + api_type=str(self.api_type), api_params=self.api_params, token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, diff --git a/releasenotes/notes/fix-hf-api-serialization-026b84de29827c57.yaml b/releasenotes/notes/fix-hf-api-serialization-026b84de29827c57.yaml new file mode 100644 index 0000000000..5fa7d9c3cf --- /dev/null +++ b/releasenotes/notes/fix-hf-api-serialization-026b84de29827c57.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix the broken serialization of HuggingFaceAPITextEmbedder, HuggingFaceAPIDocumentEmbedder, + HuggingFaceAPIGenerator, and HuggingFaceAPIChatGenerator. diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py index 4f8c6796f8..9ffae748aa 100644 --- a/test/components/embedders/test_hugging_face_api_document_embedder.py +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -109,7 +109,7 @@ def test_to_dict(self, mock_check_valid_model): assert data == { "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder", "init_parameters": { - "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_type": "serverless_inference_api", "api_params": {"model": "BAAI/bge-small-en-v1.5"}, "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "prefix": "prefix", diff --git a/test/components/embedders/test_hugging_face_api_text_embedder.py b/test/components/embedders/test_hugging_face_api_text_embedder.py index 3299ea17c9..aa00112f13 100644 --- a/test/components/embedders/test_hugging_face_api_text_embedder.py +++ b/test/components/embedders/test_hugging_face_api_text_embedder.py @@ -95,7 +95,7 @@ def test_to_dict(self, mock_check_valid_model): assert data == { "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder", "init_parameters": { - "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_type": "serverless_inference_api", "api_params": {"model": "BAAI/bge-small-en-v1.5"}, "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "prefix": "prefix", diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index 2c9d523c19..3cc7820150 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -138,7 +138,7 @@ def test_to_dict(self, mock_check_valid_model): result = generator.to_dict() init_params = result["init_parameters"] - assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 21bca849b6..d88d90645c 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -131,7 +131,7 @@ def test_to_dict(self, mock_check_valid_model): result = generator.to_dict() init_params = result["init_parameters"] - assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {