diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 5cd7a1430c..3727d234af 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -4,7 +4,7 @@ # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators import OpenAIGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) +@component class AzureOpenAIGenerator(OpenAIGenerator): """ A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text. diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index e1d3029ee7..697eb58b89 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -4,7 +4,7 @@ # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) +@component class AzureOpenAIChatGenerator(OpenAIChatGenerator): """ A Chat Generator component that uses the Azure OpenAI API to generate text. diff --git a/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml new file mode 100644 index 0000000000..071969e62e --- /dev/null +++ b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Azure generators components fixed, they were missing the `@component` decorator. diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 85e0e2c12c..868eaba6d7 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -3,6 +3,7 @@ import pytest from openai import OpenAIError +from haystack import Pipeline from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage @@ -77,6 +78,15 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index 0999e450ee..6870dae892 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -1,4 +1,6 @@ import os + +from haystack import Pipeline from haystack.utils.auth import Secret import pytest @@ -80,6 +82,15 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),