Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support baidu llm model #168

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions aisuite/providers/baidu_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import httpx
import os
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse

class BaiduProvider(Provider):
    """Provider adapter for Baidu Qianfan's OpenAI-compatible chat-completions API."""

    # Qianfan's OpenAI-compatible chat-completions endpoint.
    _API_URL = "https://qianfan.baidubce.com/v2/chat/completions"

    def __init__(self, **config):
        """
        Initialize the provider with the given configuration.

        The bearer token is read from ``config["token"]`` or, if absent,
        from the ``BAIDU_TOKEN`` environment variable.

        Raises:
            ValueError: If no token is provided by either source.
        """
        # Ensure API key is provided either in config or via environment variable
        self.token = config.get("token") or os.getenv("BAIDU_TOKEN")
        if not self.token:
            raise ValueError(
                "Baidu token is missing. Please provide it in the config or set the BAIDU_TOKEN environment variable."
            )
        # Request timeout in seconds; generous default because long generations can be slow.
        self.timeout = config.get("timeout", 200)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Make a chat-completions request to the Qianfan inference endpoint using httpx.

        Args:
            model: Qianfan model identifier.
            messages: Chat messages in OpenAI message-dict format.
            **kwargs: Additional parameters forwarded verbatim in the request body.

        Returns:
            ChatCompletionResponse: The normalized response.

        Raises:
            LLMError: On an HTTP error status or any other request failure.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
        }

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }
        try:
            # Make the request to Baidu Qianfan endpoint.
            response = httpx.post(
                self._API_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            # Chain the cause so callers can inspect the underlying HTTP error.
            raise LLMError(f"Baidu qianfan request failed: {http_err}") from http_err
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the raw Qianfan JSON payload to a ChatCompletionResponse.

        Raises:
            LLMError: If the payload does not have the expected
                ``choices[0].message.content`` shape (e.g. an error payload).
        """
        try:
            content = response_data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as err:
            # Surface malformed payloads as the provider's error type instead
            # of a raw KeyError/IndexError.
            raise LLMError(
                f"Unexpected response format from Baidu: {response_data}"
            ) from err
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = content
        return normalized_response