Async version in aisuite.AsyncClient class #185

Open · wants to merge 13 commits into base: main
1 change: 1 addition & 0 deletions aisuite/__init__.py
@@ -1,3 +1,4 @@
from .client import Client
from .async_client import AsyncClient
from .framework.message import Message
from .utils.tools import Tools
56 changes: 56 additions & 0 deletions aisuite/async_client.py
@@ -0,0 +1,56 @@
from .client import Client, Chat, Completions
from .provider import ProviderFactory


class AsyncClient(Client):
@property
def chat(self):
"""Return the async chat API interface."""
if not self._chat:
self._chat = AsyncChat(self)
return self._chat


class AsyncChat(Chat):
def __init__(self, client: "AsyncClient"):
self.client = client
self._completions = AsyncCompletions(self.client)


class AsyncCompletions(Completions):
async def create(self, model: str, messages: list, **kwargs):
"""
        Create an async chat completion for the given model, messages, and any extra arguments.
        """
        # Check that the correct "provider:model" format is used
if ":" not in model:
raise ValueError(
f"Invalid model format. Expected 'provider:model', got '{model}'"
)

# Extract the provider key from the model identifier, e.g., "google:gemini-xx"
provider_key, model_name = model.split(":", 1)

# Validate if the provider is supported
supported_providers = ProviderFactory.get_supported_providers()
if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

# Initialize provider if not already initialized
if provider_key not in self.client.providers:
config = self.client.provider_configs.get(provider_key, {})
self.client.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config, is_async=True
)

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for '{provider_key}'.")

# Delegate the chat completion to the correct provider's async implementation
return await provider.chat_completions_create_async(
model_name, messages, **kwargs
)
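
For orientation, a minimal usage sketch of the new class (the model name and environment setup are illustrative, not part of this diff):

import asyncio

import aisuite

async def main():
    client = aisuite.AsyncClient()
    # Model strings use the "provider:model" convention validated above.
    response = await client.chat.completions.create(
        model="openai:gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())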
11 changes: 9 additions & 2 deletions aisuite/provider.py
@@ -18,17 +18,24 @@ def chat_completions_create(self, model, messages):
"""Abstract method for chat completion calls, to be implemented by each provider."""
pass

class AsyncProvider(ABC):
@abstractmethod
async def chat_completions_create_async(self, model, messages, **kwargs):
"""Method for async chat completion calls, to be implemented by each provider."""
raise NotImplementedError("Async chat completion calls are not implemented for this provider.")


class ProviderFactory:
"""Factory to dynamically load provider instances based on naming conventions."""

PROVIDERS_DIR = Path(__file__).parent / "providers"

@classmethod
def create_provider(cls, provider_key, config):
def create_provider(cls, provider_key, config, is_async=False):
"""Dynamically load and create an instance of a provider based on the naming convention."""
# Convert provider_key to the expected module and class names
provider_class_name = f"{provider_key.capitalize()}Provider"
async_suffix = "Async" if is_async else ""
provider_class_name = f"{provider_key.capitalize()}{async_suffix}Provider"
provider_module_name = f"{provider_key}_provider"

module_path = f"aisuite.providers.{provider_module_name}"
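Since each async class lives in the same module as its sync sibling, the factory resolves it purely by name; for example (illustrative):

# provider_key = "anthropic", is_async=True
#   -> provider_class_name = "AnthropicAsyncProvider"
#   -> loaded from aisuite/providers/anthropic_provider.py
provider = ProviderFactory.create_provider("anthropic", config={}, is_async=True)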
28 changes: 27 additions & 1 deletion aisuite/providers/anthropic_provider.py
@@ -4,7 +4,7 @@

import anthropic
import json
from aisuite.provider import Provider
from aisuite.provider import Provider, AsyncProvider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

@@ -222,3 +222,29 @@ def _prepare_kwargs(self, kwargs):
kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])

return kwargs

class AnthropicAsyncProvider(AsyncProvider):
def __init__(self, **config):
"""Initialize the Anthropic provider with the given configuration."""
self.async_client = anthropic.AsyncAnthropic(**config)
self.converter = AnthropicMessageConverter()

async def chat_completions_create_async(self, model, messages, **kwargs):
"""Create a chat completion using the async Anthropic API."""
kwargs = self._prepare_kwargs(kwargs)
system_message, converted_messages = self.converter.convert_request(messages)

response = await self.async_client.messages.create(
model=model, system=system_message, messages=converted_messages, **kwargs
)
return self.converter.convert_response(response)

def _prepare_kwargs(self, kwargs):
"""Prepare kwargs for the API call."""
kwargs = kwargs.copy()
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

if "tools" in kwargs:
kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])

return kwargs
76 changes: 75 additions & 1 deletion aisuite/providers/fireworks_provider.py
@@ -1,7 +1,7 @@
import os
import httpx
import json
from aisuite.provider import Provider, LLMError
from aisuite.provider import Provider, AsyncProvider, LLMError
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall

@@ -130,6 +130,80 @@ def chat_completions_create(self, model, messages, **kwargs):
except Exception as e:
raise LLMError(f"An error occurred: {e}")

class FireworksAsyncProvider(AsyncProvider):
"""
Fireworks AI Provider using httpx for direct API calls.
"""

BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

def __init__(self, **config):
"""
Initialize the Fireworks provider with the given configuration.
The API key is fetched from the config or environment variables.
"""
self.api_key = config.get("api_key", os.getenv("FIREWORKS_API_KEY"))
if not self.api_key:
raise ValueError(
"Fireworks API key is missing. Please provide it in the config or set the FIREWORKS_API_KEY environment variable."
)

        # Optional custom timeout in seconds (defaults to 30)
self.timeout = config.get("timeout", 30)
self.transformer = FireworksMessageConverter()

async def chat_completions_create_async(self, model, messages, **kwargs):
"""
Makes an async request to the Fireworks AI chat completions endpoint.
"""
# Remove 'stream' from kwargs if present
kwargs.pop("stream", None)

# Transform messages using converter
transformed_messages = self.transformer.convert_request(messages)

# Prepare the request payload
data = {
"model": model,
"messages": transformed_messages,
}

# Add tools if provided
if "tools" in kwargs:
data["tools"] = kwargs["tools"]
kwargs.pop("tools")

# Add tool_choice if provided
if "tool_choice" in kwargs:
data["tool_choice"] = kwargs["tool_choice"]
kwargs.pop("tool_choice")

# Add remaining kwargs
data.update(kwargs)

headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}

async with httpx.AsyncClient() as client:
try:
# Make the async request to Fireworks AI endpoint
response = await client.post(
self.BASE_URL, json=data, headers=headers, timeout=self.timeout
)
response.raise_for_status()
return self.transformer.convert_response(response.json())
            except httpx.HTTPStatusError as error:
                # httpx exposes status/headers on error.response, not on the error itself
                error_message = (
                    f"The request failed with status code: {error.response.status_code}\n"
                )
                error_message += f"Headers: {error.response.headers}\n"
                error_message += error.response.text
raise LLMError(error_message)
except Exception as e:
raise LLMError(f"An error occurred: {e}")

def _normalize_response(self, response_data):
"""
Normalize the response to a common format (ChatCompletionResponse).
37 changes: 36 additions & 1 deletion aisuite/providers/mistral_provider.py
@@ -2,7 +2,7 @@
from mistralai import Mistral
from aisuite.framework.message import Message
from aisuite.framework import ChatCompletionResponse
from aisuite.provider import Provider, LLMError
from aisuite.provider import Provider, AsyncProvider, LLMError
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


@@ -72,3 +72,38 @@ def chat_completions_create(self, model, messages, **kwargs):
return self.transformer.convert_response(response)
except Exception as e:
raise LLMError(f"An error occurred: {e}")

class MistralAsyncProvider(AsyncProvider):
"""
Mistral AI Provider using the official Mistral client.
"""

def __init__(self, **config):
"""
Initialize the Mistral provider with the given configuration.
Pass the entire configuration dictionary to the Mistral client constructor.
"""
# Ensure API key is provided either in config or via environment variable
config.setdefault("api_key", os.getenv("MISTRAL_API_KEY"))
if not config["api_key"]:
raise ValueError(
"Mistral API key is missing. Please provide it in the config or set the MISTRAL_API_KEY environment variable."
)
self.client = Mistral(**config)
self.transformer = MistralMessageConverter()

async def chat_completions_create_async(self, model, messages, **kwargs):
"""
Makes a request to Mistral using the official client.
"""
try:
# Transform messages using converter
transformed_messages = self.transformer.convert_request(messages)

response = await self.client.chat.complete_async(
model=model, messages=transformed_messages, **kwargs
)

return self.transformer.convert_response(response)
except Exception as e:
raise LLMError(f"An error occurred: {e}")
37 changes: 36 additions & 1 deletion aisuite/providers/openai_provider.py
@@ -1,6 +1,6 @@
import openai
import os
from aisuite.provider import Provider, LLMError
from aisuite.provider import Provider, AsyncProvider, LLMError
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


@@ -38,3 +38,38 @@ def chat_completions_create(self, model, messages, **kwargs):
return response
except Exception as e:
raise LLMError(f"An error occurred: {e}")

class OpenaiAsyncProvider(AsyncProvider):
def __init__(self, **config):
"""
Initialize the OpenAI provider with the given configuration.
Pass the entire configuration dictionary to the OpenAI client constructor.
"""
# Ensure API key is provided either in config or via environment variable
config.setdefault("api_key", os.getenv("OPENAI_API_KEY"))
if not config["api_key"]:
raise ValueError(
"OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
)

# NOTE: We could choose to remove above lines for api_key since OpenAI will automatically
# infer certain values from the environment variables.
# Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc.

# Pass the entire config to the OpenAI client constructor
self.async_client = openai.AsyncOpenAI(**config)
self.transformer = OpenAICompliantMessageConverter()

async def chat_completions_create_async(self, model, messages, **kwargs):
        # Exceptions raised by the OpenAI client are caught below and
        # re-raised as a provider-agnostic LLMError.
try:
transformed_messages = self.transformer.convert_request(messages)
response = await self.async_client.chat.completions.create(
model=model,
messages=transformed_messages,
**kwargs # Pass any additional arguments to the OpenAI API
)
return response
except Exception as e:
raise LLMError(f"An error occurred: {e}")
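
Taken together, the async providers make concurrent fan-out the natural usage pattern; a hedged sketch (provider keys and model names are illustrative, and each call assumes the matching API key is set):

import asyncio

import aisuite

async def fan_out(prompt: str):
    client = aisuite.AsyncClient()
    messages = [{"role": "user", "content": prompt}]
    # Issue the same prompt to several providers concurrently.
    responses = await asyncio.gather(
        client.chat.completions.create(model="openai:gpt-4o", messages=messages),
        client.chat.completions.create(
            model="anthropic:claude-3-5-sonnet-20241022", messages=messages
        ),
        client.chat.completions.create(
            model="mistral:mistral-large-latest", messages=messages
        ),
    )
    return [r.choices[0].message.content for r in responses]

print(asyncio.run(fan_out("Summarize asyncio in one sentence.")))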
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ optional = true
[tool.poetry.group.test.dependencies]
pytest = "^8.2.2"
pytest-cov = "^6.0.0"
pytest-asyncio = "^0.25.3"

[build-system]
requires = ["poetry-core"]
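
pytest-asyncio allows test functions to be coroutines; a minimal sketch of how the new code path might be exercised (the test and its mocking strategy are hypothetical, not part of this PR):

import pytest
from unittest.mock import AsyncMock, patch

from aisuite import AsyncClient

@pytest.mark.asyncio
async def test_create_delegates_to_async_provider():
    client = AsyncClient()
    with patch("aisuite.provider.ProviderFactory.create_provider") as factory:
        provider = AsyncMock()
        factory.return_value = provider
        await client.chat.completions.create(
            model="openai:gpt-4o",
            messages=[{"role": "user", "content": "hi"}],
        )
        # The request should be awaited on the async provider exactly once.
        provider.chat_completions_create_async.assert_awaited_once()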