-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The _create_client method now handles both synchronous and asynchronous client creation, reducing code duplication. Added a check for the presence of an API key and raise a ValueError if it's missing. Added a get_model_list method to fetch available models, which can be useful for debugging or providing model options to users.
- Loading branch information
Showing
1 changed file
with
56 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,72 @@ | ||
import os | ||
from typing import Union | ||
|
||
from app.core.config import settings | ||
from dotenv import load_dotenv | ||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI | ||
|
||
load_dotenv() | ||
|
||
|
||
class Clients: | ||
def __init__(self): | ||
self.client_azure_4o = self._create_client(sync=True) | ||
self.aclient_azure_4o = self._create_client(sync=False) | ||
|
||
self.client_azure_4o = self._create_azure_client() | ||
self.aclient_azure_4o = self._create_azure_aclient() | ||
def _create_client(self, sync: bool = True) -> Union[OpenAI, AzureOpenAI, AsyncOpenAI, AsyncAzureOpenAI]: | ||
""" | ||
Create and return an OpenAI client based on the environment configuration. | ||
Args: | ||
sync (bool): If True, return a synchronous client. If False, return an asynchronous client. | ||
Returns: | ||
Union[OpenAI, AzureOpenAI, AsyncOpenAI, AsyncAzureOpenAI]: The appropriate OpenAI client. | ||
""" | ||
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY") | ||
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01") | ||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") | ||
|
||
def _create_azure_client(self): | ||
# if os.getenv(OPENAI_API_KEY) exists, use it | ||
if os.getenv("OPENAI_API_KEY"): | ||
return OpenAI() | ||
else: | ||
return AzureOpenAI( | ||
api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||
api_version="2024-02-01", | ||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||
) | ||
if not api_key: | ||
raise ValueError("API key not found in environment variables.") | ||
|
||
def _create_azure_aclient(self): | ||
# if os.getenv(OPENAI_API_KEY) exists, use it | ||
if os.getenv("OPENAI_API_KEY"): | ||
return AsyncOpenAI() | ||
return OpenAI() if sync else AsyncOpenAI() | ||
else: | ||
return AsyncAzureOpenAI( | ||
api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||
api_version="2024-02-01", | ||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||
client_class = AzureOpenAI if sync else AsyncAzureOpenAI | ||
return client_class( | ||
api_key=api_key, | ||
api_version=api_version, | ||
azure_endpoint=azure_endpoint, | ||
) | ||
|
||
@property | ||
def default_client(self) -> Union[OpenAI, AzureOpenAI]: | ||
""" | ||
Returns the default synchronous client. | ||
""" | ||
return self.client_azure_4o | ||
|
||
@property | ||
def default_async_client(self) -> Union[AsyncOpenAI, AsyncAzureOpenAI]: | ||
""" | ||
Returns the default asynchronous client. | ||
""" | ||
return self.aclient_azure_4o | ||
|
||
def get_model_list(self) -> list: | ||
""" | ||
Fetch and return a list of available models. | ||
""" | ||
return self.default_client.models.list() | ||
|
||
# Example usage | ||
if __name__ == "__main__": | ||
clients = Clients() | ||
print(f"Default client: {type(clients.default_client)}") | ||
print(f"Default async client: {type(clients.default_async_client)}") | ||
|
||
try: | ||
models = clients.get_model_list() | ||
print(f"Available models: {models}") | ||
except Exception as e: | ||
print(f"Error fetching models: {e}") |