Skip to content

Commit

Permalink
Update clients.py
Browse files Browse the repository at this point in the history
The _create_client method now handles both synchronous and asynchronous client creation, reducing code duplication.
Added a check for the presence of an API key and raise a ValueError if it's missing.
 Added a get_model_list method to fetch available models, which can be useful for debugging or providing model options to users.
  • Loading branch information
Yash-2707 authored Oct 22, 2024
1 parent db85fb4 commit 3329bf0
Showing 1 changed file with 56 additions and 20 deletions.
76 changes: 56 additions & 20 deletions backend/app/services/clients.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,72 @@
import os
from typing import Union

from app.core.config import settings
from dotenv import load_dotenv
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

load_dotenv()


class Clients:
def __init__(self):
self.client_azure_4o = self._create_client(sync=True)
self.aclient_azure_4o = self._create_client(sync=False)

self.client_azure_4o = self._create_azure_client()
self.aclient_azure_4o = self._create_azure_aclient()
def _create_client(self, sync: bool = True) -> Union[OpenAI, AzureOpenAI, AsyncOpenAI, AsyncAzureOpenAI]:
"""
Create and return an OpenAI client based on the environment configuration.
Args:
sync (bool): If True, return a synchronous client. If False, return an asynchronous client.
Returns:
Union[OpenAI, AzureOpenAI, AsyncOpenAI, AsyncAzureOpenAI]: The appropriate OpenAI client.
"""
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

def _create_azure_client(self):
# if os.getenv(OPENAI_API_KEY) exists, use it
if os.getenv("OPENAI_API_KEY"):
return OpenAI()
else:
return AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version="2024-02-01",
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
)
if not api_key:
raise ValueError("API key not found in environment variables.")

def _create_azure_aclient(self):
# if os.getenv(OPENAI_API_KEY) exists, use it
if os.getenv("OPENAI_API_KEY"):
return AsyncOpenAI()
return OpenAI() if sync else AsyncOpenAI()
else:
return AsyncAzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version="2024-02-01",
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
client_class = AzureOpenAI if sync else AsyncAzureOpenAI
return client_class(
api_key=api_key,
api_version=api_version,
azure_endpoint=azure_endpoint,
)

@property
def default_client(self) -> Union[OpenAI, AzureOpenAI]:
"""
Returns the default synchronous client.
"""
return self.client_azure_4o

@property
def default_async_client(self) -> Union[AsyncOpenAI, AsyncAzureOpenAI]:
"""
Returns the default asynchronous client.
"""
return self.aclient_azure_4o

def get_model_list(self) -> list:
"""
Fetch and return a list of available models.
"""
return self.default_client.models.list()

# Example usage
if __name__ == "__main__":
clients = Clients()
print(f"Default client: {type(clients.default_client)}")
print(f"Default async client: {type(clients.default_async_client)}")

try:
models = clients.get_model_list()
print(f"Available models: {models}")
except Exception as e:
print(f"Error fetching models: {e}")

0 comments on commit 3329bf0

Please sign in to comment.