refactor: Restructure the Gemini and OpenAI chat services to support tools and safety settings
- Extract the `_build_payload`, `_build_tools`, `_get_safety_settings`, and `_has_image_parts` functions from the `OpenAIChatService` and `GeminiChatService` classes into standalone functions.
- Extract the `_handle_stream_response` and `_handle_normal_response` functions from the `GeminiResponseHandler` and `OpenAIResponseHandler` classes into standalone functions.
- Extract the `_extract_text` function from the `OpenAIResponseHandler` class into a standalone function and reuse it in `GeminiResponseHandler`.
- Extract the `_convert_image` function from the `OpenAIMessageConverter` class into a standalone function (a condensed sketch of this pattern follows the list).
- Tidy up the code structure of `OpenAIChatService` and `GeminiChatService` so it reads more clearly.
- Simplify the route functions in `app/api/openai_routes.py` and `app/api/gemini_routes.py` by removing unnecessary parameters.
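The move that repeats across these bullets is small: a helper that used to live on a class and be reached through `self` becomes a plain module-level function that any class in the module can call. Below is a condensed before/after sketch using the `_convert_image` helper from `app/services/chat/message_converter.py` shown further down; the `OpenAIMessageConverterBefore` name and the `build_image_part` wrapper are illustrative only, not code from this commit.

```python
from typing import Any, Dict


# Before (illustrative class name): the helper is an instance method,
# reachable only through an OpenAIMessageConverter object.
class OpenAIMessageConverterBefore:
    def _convert_image(self, image_url: str) -> Dict[str, Any]:
        if image_url.startswith("data:image"):
            return {"inline_data": {"mime_type": "image/jpeg",
                                    "data": image_url.split(",")[1]}}
        return {"image_url": {"url": image_url}}


# After: the same logic as a module-level function, as in the diff below.
def _convert_image(image_url: str) -> Dict[str, Any]:
    if image_url.startswith("data:image"):
        return {"inline_data": {"mime_type": "image/jpeg",
                                "data": image_url.split(",")[1]}}
    return {"image_url": {"url": image_url}}


# Callers drop the self. prefix; the real convert() in the diff calls
# _convert_image(content["image_url"]["url"]) directly. build_image_part
# is just an illustrative wrapper showing the call site.
def build_image_part(content: Dict[str, Any]) -> Dict[str, Any]:
    return _convert_image(content["image_url"]["url"])
```

The same move is what lets `GeminiResponseHandler` reuse `_extract_text` without depending on `OpenAIResponseHandler`. The route-level cleanup in the last bullet is visible in the diffs as well: unused `authorization`/`token` parameters give way to a bare `_=Depends(...)`, which still runs the verification dependency but no longer binds its result to a named argument.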
snailyp committed Feb 6, 2025
1 parent b60b063 commit cd45f4b
Showing 12 changed files with 384 additions and 481 deletions.
23 changes: 13 additions & 10 deletions app/api/gemini_routes.py
@@ -1,5 +1,4 @@
from http.client import HTTPException
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

from app.core.config import settings
@@ -10,6 +9,7 @@
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler

router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
logger = get_gemini_logger()
@@ -22,26 +22,29 @@

@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
key: str = None,
token: str = Depends(security_service.verify_key),
):
async def list_models(_=Depends(security_service.verify_key)):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0", "displayName": "Gemini 2.0 Flash Search Experimental", "description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, "outputTokenLimit": 8192, "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, "topP": 0.95, "topK": 64, "maxTemperature": 2})
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental",
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2})
return models_json


@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -70,7 +73,7 @@ async def generate_content(
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -81,7 +84,7 @@ async def stream_generate_content(
logger.info(f"Using API key: {api_key}")

try:
response_stream =chat_service.stream_generate_content(
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
32 changes: 12 additions & 20 deletions app/api/openai_routes.py
@@ -1,16 +1,15 @@
from http.client import HTTPException
from fastapi import APIRouter, Depends, Header
from fastapi import HTTPException, APIRouter, Depends
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.core.logger import get_openai_logger
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.core.config import settings
from app.core.logger import get_openai_logger

router = APIRouter()
logger = get_openai_logger()
@@ -24,10 +23,7 @@

@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
):
async def list_models(_=Depends(security_service.verify_authorization)):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
@@ -43,10 +39,9 @@ async def list_models(
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "chat_completion" + "-" * 50)
@@ -67,15 +62,13 @@ async def chat_completion(
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e



@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
@@ -95,8 +88,7 @@ async def embedding(
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
authorization: str = Header(None),
token: str = Depends(security_service.verify_auth_token),
_=Depends(security_service.verify_auth_token),
):
"""获取有效和无效的API key列表"""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
2 changes: 1 addition & 1 deletion app/core/logger.py
@@ -128,4 +128,4 @@ def get_request_logger():


def get_retry_logger():
return Logger.setup_logger("retry")
return Logger.setup_logger("retry")
8 changes: 4 additions & 4 deletions app/core/security.py
@@ -17,7 +17,7 @@ async def verify_key(self, key: str):
return key

async def verify_authorization(
self, authorization: Optional[str] = Header(None)
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
@@ -45,7 +45,7 @@ async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
logger.error("Invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")

return x_goog_api_key

async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
@@ -56,5 +56,5 @@ async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -
if token != self.auth_token:
logger.error("Invalid auth_token")
raise HTTPException(status_code=401, detail="Invalid auth_token")
return token

return token
6 changes: 2 additions & 4 deletions app/schemas/gemini_models.py
@@ -3,10 +3,8 @@


class SafetySetting(BaseModel):
category: Optional[Literal[
"HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
threshold: Optional[Literal[
"HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None


class GenerationConfig(BaseModel):
14 changes: 8 additions & 6 deletions app/services/chat/api_client.py
@@ -4,24 +4,26 @@
import httpx
from abc import ABC, abstractmethod


class ApiClient(ABC):
"""API客户端基类"""

@abstractmethod
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
pass

@abstractmethod
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
pass


class GeminiApiClient(ApiClient):
"""Gemini API客户端"""

def __init__(self, base_url: str, timeout: int = 300):
self.base_url = base_url
self.timeout = timeout

def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
@@ -33,14 +35,14 @@ def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) ->
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()

async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream("POST", url, json=payload) as response:
async with client.stream(method="POST", url=url, json=payload) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
43 changes: 23 additions & 20 deletions app/services/chat/message_converter.py
@@ -3,22 +3,39 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any


class MessageConverter(ABC):
"""消息转换器基类"""

@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
pass


def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}


class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""

def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []

if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
@@ -29,22 +46,8 @@ def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
if content["type"] == "text":
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(self._convert_image(content["image_url"]["url"]))
parts.append(_convert_image(content["image_url"]["url"]))

converted_messages.append({"role": role, "parts": parts})

return converted_messages

def _convert_image(self, image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}