From e3dccf44062b4dae03f92806ac5d27574af797f4 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 23 May 2024 16:28:24 +0200 Subject: [PATCH] add timeout to AzureOpenAIGenerator (#7724) * add timeout to AzureOpenAIGenerator * add to chat also * Update azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml --- haystack/components/generators/azure.py | 5 +++++ haystack/components/generators/chat/azure.py | 3 +++ .../azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml | 4 ++++ test/components/generators/chat/test_azure.py | 3 +++ test/components/generators/test_azure.py | 4 ++++ 5 files changed, 19 insertions(+) create mode 100644 releasenotes/notes/azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 2c432d8231..caa3e0aae0 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -62,6 +62,7 @@ def __init__( organization: Optional[str] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, system_prompt: Optional[str] = None, + timeout: Optional[float] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): """ @@ -77,6 +78,7 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param system_prompt: The prompt to use for the system. If not provided, the system prompt will be + :param timeout: The timeout to be passed to the underlying `AzureOpenAI` client. :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. 
@@ -123,6 +125,7 @@ def __init__( self.azure_deployment = azure_deployment self.organization = organization self.model: str = azure_deployment or "gpt-35-turbo" + self.timeout = timeout self.client = AzureOpenAI( api_version=api_version, @@ -131,6 +134,7 @@ def __init__( api_key=api_key.resolve_value() if api_key is not None else None, azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None, organization=organization, + timeout=timeout, ) def to_dict(self) -> Dict[str, Any]: @@ -152,6 +156,7 @@ def to_dict(self) -> Dict[str, Any]: system_prompt=self.system_prompt, api_key=self.api_key.to_dict() if self.api_key is not None else None, azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, + timeout=self.timeout, ) @classmethod diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index b6cd2e153b..376ebb5ce4 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -80,6 +80,7 @@ def __init__( azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False), organization: Optional[str] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + timeout: Optional[float] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): """ @@ -139,6 +140,7 @@ def __init__( self.azure_deployment = azure_deployment self.organization = organization self.model = azure_deployment or "gpt-35-turbo" + self.timeout = timeout self.client = AzureOpenAI( api_version=api_version, @@ -165,6 +167,7 @@ def to_dict(self) -> Dict[str, Any]: api_version=self.api_version, streaming_callback=callback_name, generation_kwargs=self.generation_kwargs, + timeout=self.timeout, api_key=self.api_key.to_dict() if self.api_key is not None else None, azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, ) diff --git 
a/releasenotes/notes/azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml b/releasenotes/notes/azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml new file mode 100644 index 0000000000..49217671fc --- /dev/null +++ b/releasenotes/notes/azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + `AzureOpenAIGenerator` and `AzureOpenAIChatGenerator` can now be configured by passing a timeout for the underlying `AzureOpenAI` client. diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index c9693caac8..80ffb95641 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -55,6 +55,7 @@ def test_to_dict_default(self, monkeypatch): "organization": None, "streaming_callback": None, "generation_kwargs": {}, + "timeout": None, }, } @@ -64,6 +65,7 @@ def test_to_dict_with_parameters(self, monkeypatch): api_key=Secret.from_env_var("ENV_VAR", strict=False), azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False), azure_endpoint="some-non-existing-endpoint", + timeout=2.5, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) data = component.to_dict() @@ -77,6 +79,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "azure_deployment": "gpt-35-turbo", "organization": None, "streaming_callback": None, + "timeout": 2.5, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index d5d52b6f26..39bef98ae3 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -39,6 +39,7 @@ def test_init_with_parameters(self): assert component.client.api_key == "fake-api-key" assert component.azure_deployment == "gpt-35-turbo" assert component.streaming_callback is print_streaming_chunk + assert component.timeout is None assert component.generation_kwargs == 
{"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): @@ -56,6 +57,7 @@ def test_to_dict_default(self, monkeypatch): "azure_endpoint": "some-non-existing-endpoint", "organization": None, "system_prompt": None, + "timeout": None, "generation_kwargs": {}, }, } @@ -66,6 +68,7 @@ def test_to_dict_with_parameters(self, monkeypatch): api_key=Secret.from_env_var("ENV_VAR", strict=False), azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False), azure_endpoint="some-non-existing-endpoint", + timeout=3.5, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) @@ -81,6 +84,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "azure_endpoint": "some-non-existing-endpoint", "organization": None, "system_prompt": None, + "timeout": 3.5, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, }