feat: Add AzureOpenAIGenerator and AzureOpenAIChatGenerator #6648
Merged
Commits (13)
94ce488 Add AzureGenerator and AzureChatGenerator (vblagoje)
e6b9d85 Add release note (vblagoje)
3d507dc Fix minor errors (vblagoje)
4f751af Fix mypy (vblagoje)
cd811a7 Rename AzureGenerator->AzureOpenAIGenerator, AzureChatGenerator->Azur… (vblagoje)
e3c11a1 Add comment about not calling super().__init__ (vblagoje)
ed7538a Add unit tests (vblagoje)
da6f21a Add integration test (vblagoje)
d98b046 Add organization param (vblagoje)
928a5e5 Merge branch 'main' into azure_generators (vblagoje)
7e88d0a Update release notes (vblagoje)
c9a7683 Merge branch 'main' into azure_generators (vblagoje)
5826ef3 Fix unit test (vblagoje)
haystack/components/generators/__init__.py
@@ -1,5 +1,12 @@
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
 from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
 from haystack.components.generators.openai import OpenAIGenerator, GPTGenerator
+from haystack.components.generators.azure import AzureOpenAIGenerator

-__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "GPTGenerator"]
+__all__ = [
+    "HuggingFaceLocalGenerator",
+    "HuggingFaceTGIGenerator",
+    "OpenAIGenerator",
+    "GPTGenerator",
+    "AzureOpenAIGenerator",
+]
haystack/components/generators/azure.py (new file)
@@ -0,0 +1,158 @@
import logging
import os
from typing import Optional, Callable, Dict, Any

# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import default_to_dict, default_from_dict
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


class AzureOpenAIGenerator(OpenAIGenerator):
    """
    Enables text generation using OpenAI's large language models (LLMs) on Azure. It supports the gpt-4 and
    gpt-3.5-turbo family of models.

    Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in the `run` method.

    For more details on OpenAI models deployed on Azure, refer to the Microsoft
    [documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/).

    ```python
    from haystack.components.generators import AzureOpenAIGenerator
    client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
                                  api_key="<your api key>",
                                  azure_deployment="<a model name, e.g. gpt-35-turbo>")
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)

    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model':
    >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with the gpt-4 and gpt-3.5-turbo family of models.
    - **Streaming Support**: Supports streaming responses from the OpenAI API.
    - **Customizability**: Supports all parameters supported by the OpenAI API.

    Input and Output Format:
    - **String Format**: This component uses strings for both input and output.
    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-35-turbo",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param system_prompt: The prompt to use for the system role. If not provided, the system prompt is omitted.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
              Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
              considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
              comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
              it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present. Bigger values mean
              the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
              Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
              values are the bias to add to that token.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to
        # interact with the API.

        # Why is this here?
        # The AzureOpenAI init forces us to pass either base_url or azure_endpoint as a non-None init parameter.
        # This way we accommodate the use case where the env var AZURE_OPENAI_ENDPOINT is set instead of passing
        # it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.generation_kwargs = generation_kwargs or {}
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model_name: str = azure_deployment or "gpt-35-turbo"

        self.client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            system_prompt=self.system_prompt,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIGenerator":
        """
        Deserialize this component from a dictionary.
        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)
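For readers who want to try the new component end to end, a minimal usage sketch follows. The endpoint, key, and deployment values are placeholders; note also that, per the code above, `to_dict()` deliberately omits the API key, so rebuilding a component via `from_dict()` relies on the underlying `AzureOpenAI` client picking up credentials from the environment (e.g. `AZURE_OPENAI_API_KEY`).

```python
from haystack.components.generators import AzureOpenAIGenerator

# Placeholder values; substitute your own Azure resource details.
client = AzureOpenAIGenerator(
    azure_endpoint="https://example-resource.azure.openai.com/",
    api_key="<your api key>",
    azure_deployment="gpt-35-turbo",
    generation_kwargs={"max_tokens": 64, "temperature": 0.2},
)
response = client.run("What's Natural Language Processing? Be brief.")
print(response["replies"][0])

# Serialization round trip: to_dict() captures the init parameters but not
# the API key, so from_dict() only works if the key is available through the
# environment when the component is rebuilt.
data = client.to_dict()
restored = AzureOpenAIGenerator.from_dict(data)
```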
haystack/components/generators/chat/__init__.py
@@ -1,4 +1,5 @@
 from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
 from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
+from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator

-__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator"]
+__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator", "AzureOpenAIChatGenerator"]
haystack/components/generators/chat/azure.py (new file)
@@ -0,0 +1,161 @@
import logging
import os
from typing import Optional, Callable, Dict, Any

# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import default_to_dict, default_from_dict
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


class AzureOpenAIChatGenerator(OpenAIChatGenerator):
    """
    Enables text generation using OpenAI's large language models (LLMs) on Azure. It supports the gpt-4 and
    gpt-3.5-turbo family of models accessed through the chat completions API endpoint.

    Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in the `run` method.

    For more details on OpenAI models deployed on Azure, refer to the Microsoft
    [documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/).

    ```python
    from haystack.components.generators.chat import AzureOpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = AzureOpenAIChatGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
                                      api_key="<your api key>",
                                      azure_deployment="<a model name, e.g. gpt-35-turbo>")
    response = client.run(messages)
    print(response)

    >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence
    >> that focuses on enabling computers to understand, interpret, and generate human language in a way that is
    >> meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
    >> meta={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop',
    >> 'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with the OpenAI API Chat Completion endpoint
      and the gpt-4 and gpt-3.5-turbo family of models.
    - **Streaming Support**: Supports streaming responses from the OpenAI API Chat Completion endpoint.
    - **Customizability**: Supports all parameters supported by the OpenAI API Chat Completion endpoint.

    Input and Output Format:
    - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output,
      ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
      ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-35-turbo",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
              Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
              considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
              comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
              it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present. Bigger values mean
              the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
              Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
              values are the bias to add to that token.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to
        # interact with the API.

        # Why is this here?
        # The AzureOpenAI init forces us to pass either base_url or azure_endpoint as a non-None init parameter.
        # This way we accommodate the use case where the env var AZURE_OPENAI_ENDPOINT is set instead of passing
        # it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model_name = azure_deployment or "gpt-35-turbo"

        self.client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator":
        """
        Deserialize this component from a dictionary.
        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)
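As with the non-chat variant, here is a minimal sketch of how the `streaming_callback` parameter might be used with this component; the endpoint, key, and deployment are placeholders.

```python
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

# Print each chunk's text as it arrives from the stream.
def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)

client = AzureOpenAIChatGenerator(
    azure_endpoint="https://example-resource.azure.openai.com/",  # placeholder
    api_key="<your api key>",
    azure_deployment="gpt-35-turbo",
    streaming_callback=print_chunk,
)
response = client.run([ChatMessage.from_user("What's Natural Language Processing? Be brief.")])
```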
releasenotes/notes/add-azure-generators-a30c786204b22e48.yaml (new file, 4 additions)
@@ -0,0 +1,4 @@
---
features:
  - |
    Adds support for Azure OpenAI models with AzureGenerator and AzureChatGenerator components.
Review comment (on the release note): Let's rename here to AzureOpenAIGenerator and AzureOpenAIChatGenerator.
Review comment (on the `__init__` method): It can be helpful to add a comment here explaining that the super().__init__() call is skipped on purpose, because we need to initialize an AzureOpenAI client instead of the OpenAI client.
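In miniature, the pattern that this comment (added in commit e3c11a1) documents looks like the following generic sketch, not the PR's exact code: the subclass skips `super().__init__()` because the parent constructor would build the wrong client type, and instead sets up every attribute the inherited methods rely on.

```python
class OpenAIStyleGenerator:
    def __init__(self, api_key: str):
        # The parent constructor builds a plain OpenAI-style client...
        self.client = ("openai", api_key)

class AzureStyleGenerator(OpenAIStyleGenerator):
    # pylint: disable=super-init-not-called
    def __init__(self, azure_endpoint: str, api_key: str):
        # ...so super().__init__() is skipped on purpose: this subclass needs
        # an Azure-specific client, and it sets all attributes that the
        # inherited methods would read.
        self.client = ("azure", azure_endpoint, api_key)
```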