-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: 添加重试机制和消息转换器,并支持Gemini v1beta API
- Loading branch information
Showing
12 changed files
with
742 additions
and
584 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# app/services/chat/api_client.py | ||
|
||
from typing import Dict, Any, AsyncGenerator | ||
import httpx | ||
from abc import ABC, abstractmethod | ||
|
||
class ApiClient(ABC): | ||
"""API客户端基类""" | ||
|
||
@abstractmethod | ||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: | ||
pass | ||
|
||
@abstractmethod | ||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: | ||
pass | ||
|
||
class GeminiApiClient(ApiClient): | ||
"""Gemini API客户端""" | ||
|
||
def __init__(self, base_url: str, timeout: int = 300): | ||
self.base_url = base_url | ||
self.timeout = timeout | ||
|
||
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: | ||
timeout = httpx.Timeout(self.timeout, read=self.timeout) | ||
if model.endswith("-search"): | ||
model = model[:-7] | ||
with httpx.Client(timeout=timeout) as client: | ||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" | ||
response = client.post(url, json=payload) | ||
if response.status_code != 200: | ||
error_content = response.text | ||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}") | ||
return response.json() | ||
|
||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: | ||
timeout = httpx.Timeout(self.timeout, read=self.timeout) | ||
if model.endswith("-search"): | ||
model = model[:-7] | ||
async with httpx.AsyncClient(timeout=timeout) as client: | ||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" | ||
async with client.stream("POST", url, json=payload) as response: | ||
if response.status_code != 200: | ||
error_content = await response.aread() | ||
error_msg = error_content.decode("utf-8") | ||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}") | ||
async for line in response.aiter_lines(): | ||
yield line |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# app/services/chat/message_converter.py | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List, Dict, Any | ||
|
||
class MessageConverter(ABC): | ||
"""消息转换器基类""" | ||
|
||
@abstractmethod | ||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | ||
pass | ||
|
||
class OpenAIMessageConverter(MessageConverter): | ||
"""OpenAI消息格式转换器""" | ||
|
||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | ||
converted_messages = [] | ||
for msg in messages: | ||
role = "user" if msg["role"] == "user" else "model" | ||
parts = [] | ||
|
||
if isinstance(msg["content"], str): | ||
parts.append({"text": msg["content"]}) | ||
elif isinstance(msg["content"], list): | ||
for content in msg["content"]: | ||
if isinstance(content, str): | ||
parts.append({"text": content}) | ||
elif isinstance(content, dict): | ||
if content["type"] == "text": | ||
parts.append({"text": content["text"]}) | ||
elif content["type"] == "image_url": | ||
parts.append(self._convert_image(content["image_url"]["url"])) | ||
|
||
converted_messages.append({"role": role, "parts": parts}) | ||
|
||
return converted_messages | ||
|
||
def _convert_image(self, image_url: str) -> Dict[str, Any]: | ||
if image_url.startswith("data:image"): | ||
return { | ||
"inline_data": { | ||
"mime_type": "image/jpeg", | ||
"data": image_url.split(",")[1] | ||
} | ||
} | ||
return { | ||
"image_url": { | ||
"url": image_url | ||
} | ||
} |
Oops, something went wrong.