Skip to content

Commit

Permalink
feat: 添加重试机制和消息转换器,并支持Gemini v1beta API
Browse files Browse the repository at this point in the history
  • Loading branch information
snailyp committed Dec 27, 2024
1 parent 6e90463 commit 870b1ec
Show file tree
Hide file tree
Showing 12 changed files with 742 additions and 584 deletions.
56 changes: 25 additions & 31 deletions app/api/gemini_routes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from http.client import HTTPException
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.core.security import SecurityService
from app.schemas.gemini_models import GeminiRequest
from app.services.chat_service import ChatService
from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService

from app.services.chat.retry_handler import RetryHandler
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
logger = get_gemini_logger()

# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)


@router.get("/models")
Expand All @@ -34,58 +35,51 @@ async def list_models(
return models_json

@router.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
# x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")

api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
MAX_RETRIES = 3

while retries < MAX_RETRIES:
try:
response = await chat_service.generate_content(
model_name=model_name,
request=request,
api_key=api_key
)
return response
try:
response = chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response

except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= MAX_RETRIES:
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e


@router.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
# x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")

api_key = await key_manager.get_next_working_key()
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")

try:
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
response_stream = chat_service.stream_generate_content(
model_name=model_name,
response_stream =chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
Expand Down
56 changes: 26 additions & 30 deletions app/api/openai_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from fastapi.responses import StreamingResponse

from app.core.security import SecurityService
from app.services.chat.retry_handler import RetryHandler
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat_service import ChatService
from app.services.openai_chat_service import OpenAIChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.core.config import settings
Expand All @@ -31,47 +32,42 @@ async def list_models(
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return model_service.get_gemini_openai_models(api_key)
try:
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e


@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = ChatService(settings.BASE_URL, key_manager)
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "chat_completion" + "-" * 50)
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
max_retries = 3

while retries < max_retries:
try:
response = await chat_service.create_chat_completion(
request=request,
api_key=api_key,
)

# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
try:
response = await chat_service.create_chat_completion(
request=request,
api_key=api_key,
)
# 处理流式响应
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
logger.info("Chat completion request successful")
return response

except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {max_retries}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached. Raising error")
raise
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e



@router.post("/v1/embeddings")
Expand All @@ -93,7 +89,7 @@ async def embedding(
return response
except Exception as e:
logger.error(f"Embedding request failed: {str(e)}")
raise
raise HTTPException(status_code=500, detail="Embedding request failed") from e


@router.get("/v1/keys/list")
Expand All @@ -120,4 +116,4 @@ async def get_keys_list(
raise HTTPException(
status_code=500,
detail="Internal server error while fetching keys list"
)
) from e
4 changes: 4 additions & 0 deletions app/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ def get_embeddings_logger():

def get_request_logger():
    """Return the logger used for request-level logging (channel name "request")."""
    return Logger.setup_logger("request")


def get_retry_logger():
    """Return the logger used by the retry handler (channel name "retry")."""
    return Logger.setup_logger("retry")
1 change: 1 addition & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# 包含所有路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)


@app.get("/health")
Expand Down
49 changes: 49 additions & 0 deletions app/services/chat/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# app/services/chat/api_client.py

from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod

class ApiClient(ABC):
    """Abstract base class for chat API clients.

    Concrete clients implement one blocking request/response call
    (``generate_content``) and one asynchronous streaming call
    (``stream_generate_content``).
    """

    @abstractmethod
    def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """Perform a blocking generation request and return the parsed JSON response.

        NOTE(review): this was declared ``async`` although the only
        implementation (``GeminiApiClient``) is synchronous and its callers do
        not await it; the abstract signature now matches the concrete one.
        """

    @abstractmethod
    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """Yield response lines (SSE) from the API as they arrive."""


class GeminiApiClient(ApiClient):
    """HTTP client for the Gemini ``generateContent`` family of endpoints."""

    # Internal marker for "search"-enabled model variants; it must be removed
    # from the model name before calling the upstream API.
    _SEARCH_SUFFIX = "-search"

    def __init__(self, base_url: str, timeout: int = 300):
        # base_url: Gemini API root (e.g. ".../v1beta").
        # timeout: seconds, applied to both the overall request and the read.
        self.base_url = base_url
        self.timeout = timeout

    @classmethod
    def _upstream_model_name(cls, model: str) -> str:
        """Strip the internal "-search" suffix, if present."""
        if model.endswith(cls._SEARCH_SUFFIX):
            return model[: -len(cls._SEARCH_SUFFIX)]
        return model

    def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """Non-streaming generation request; returns the response JSON.

        Raises:
            Exception: if the API responds with a non-200 status code.
        """
        model = self._upstream_model_name(model)
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        with httpx.Client(timeout=timeout) as client:
            url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
            response = client.post(url, json=payload)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """Streaming (SSE) generation request; yields raw response lines.

        Raises:
            Exception: if the API responds with a non-200 status code.
        """
        model = self._upstream_model_name(model)
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        async with httpx.AsyncClient(timeout=timeout) as client:
            url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
            async with client.stream("POST", url, json=payload) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_msg = error_content.decode("utf-8")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line
50 changes: 50 additions & 0 deletions app/services/chat/message_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# app/services/chat/message_converter.py

from abc import ABC, abstractmethod
from typing import List, Dict, Any

class MessageConverter(ABC):
    """Base class for converting chat messages between provider formats."""

    @abstractmethod
    def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform *messages* into the target provider's message structure."""


class OpenAIMessageConverter(MessageConverter):
    """Converts OpenAI-style chat messages into Gemini ``contents`` entries."""

    def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Map each OpenAI message to ``{"role": ..., "parts": [...]}``.

        Gemini distinguishes only "user" and "model" roles, so every non-user
        role (assistant, system, ...) maps to "model". String content becomes
        a single text part; list content may mix plain strings, ``text`` dicts
        and ``image_url`` dicts. Malformed items without a "type" key are
        skipped instead of raising KeyError.
        """
        converted_messages = []
        for msg in messages:
            role = "user" if msg["role"] == "user" else "model"
            parts = []
            content = msg["content"]

            if isinstance(content, str):
                parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        parts.append({"text": item})
                    elif isinstance(item, dict):
                        # .get() avoids a KeyError on items lacking "type".
                        item_type = item.get("type")
                        if item_type == "text":
                            parts.append({"text": item["text"]})
                        elif item_type == "image_url":
                            parts.append(self._convert_image(item["image_url"]["url"]))

            converted_messages.append({"role": role, "parts": parts})

        return converted_messages

    def _convert_image(self, image_url: str) -> Dict[str, Any]:
        """Convert one image reference to a Gemini part.

        Data URIs become ``inline_data`` with the mime type parsed from the
        URI header (the previous version hard-coded "image/jpeg", mislabeling
        PNG/WebP uploads); any other URL is passed through as ``image_url``.
        """
        if image_url.startswith("data:image"):
            # "data:image/png;base64,<payload>" -> mime "image/png", data "<payload>"
            header, _, data = image_url.partition(",")
            mime_type = header[len("data:"):].split(";", 1)[0] or "image/jpeg"
            return {
                "inline_data": {
                    "mime_type": mime_type,
                    "data": data
                }
            }
        return {
            "image_url": {
                "url": image_url
            }
        }
Loading

0 comments on commit 870b1ec

Please sign in to comment.