diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py
index a552744..d16cd35 100644
--- a/app/api/gemini_routes.py
+++ b/app/api/gemini_routes.py
@@ -1,5 +1,4 @@
-from http.client import HTTPException
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
 from fastapi.responses import StreamingResponse
 
 from app.core.config import settings
@@ -10,6 +9,7 @@ from app.services.key_manager import KeyManager
 from app.services.model_service import ModelService
 from app.services.chat.retry_handler import RetryHandler
 
+
 router = APIRouter(prefix="/gemini/v1beta")
 router_v1beta = APIRouter(prefix="/v1beta")
 logger = get_gemini_logger()
@@ -22,26 +22,29 @@
 
 @router.get("/models")
 @router_v1beta.get("/models")
-async def list_models(
-    key: str = None,
-    token: str = Depends(security_service.verify_key),
-):
+async def list_models(_=Depends(security_service.verify_key)):
     """获取可用的Gemini模型列表"""
     logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
     logger.info("Handling Gemini models list request")
     api_key = await key_manager.get_next_working_key()
     logger.info(f"Using API key: {api_key}")
     models_json = model_service.get_gemini_models(api_key)
-    models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0", "displayName": "Gemini 2.0 Flash Search Experimental", "description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, "outputTokenLimit": 8192, "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, "topP": 0.95, "topK": 64, "maxTemperature": 2})
+    models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
+                                  "displayName": "Gemini 2.0 Flash Search Experimental",
+                                  "description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
+                                  "outputTokenLimit": 8192,
+                                  "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
+                                  "topP": 0.95, "topK": 64, "maxTemperature": 2})
     return models_json
 
+
 @router.post("/models/{model_name}:generateContent")
 @router_v1beta.post("/models/{model_name}:generateContent")
 @RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
 async def generate_content(
     model_name: str,
     request: GeminiRequest,
-    x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
+    _=Depends(security_service.verify_goog_api_key),
     api_key: str = Depends(key_manager.get_next_working_key),
 ):
     chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -70,7 +73,7 @@ async def generate_content(
 async def stream_generate_content(
     model_name: str,
     request: GeminiRequest,
-    x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
+    _=Depends(security_service.verify_goog_api_key),
     api_key: str = Depends(key_manager.get_next_working_key),
 ):
     chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -81,7 +84,7 @@ async def stream_generate_content(
     logger.info(f"Using API key: {api_key}")
 
     try:
-        response_stream =chat_service.stream_generate_content(
+        response_stream = chat_service.stream_generate_content(
             model=model_name,
             request=request,
             api_key=api_key
diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py
index 32e03b7..d3d6a2c 100644
--- a/app/api/openai_routes.py
+++ b/app/api/openai_routes.py
@@ -1,16 +1,15 @@
-from http.client import HTTPException
-from fastapi import APIRouter, Depends, Header
+from fastapi import HTTPException, APIRouter, Depends
 from fastapi.responses import StreamingResponse
 
+from app.core.config import settings
+from app.core.logger import get_openai_logger
 from app.core.security import SecurityService
+from app.schemas.openai_models import ChatRequest, EmbeddingRequest
 from app.services.chat.retry_handler import RetryHandler
+from app.services.embedding_service import EmbeddingService
 from app.services.key_manager import KeyManager
 from app.services.model_service import ModelService
 from app.services.openai_chat_service import OpenAIChatService
-from app.services.embedding_service import EmbeddingService
-from app.schemas.openai_models import ChatRequest, EmbeddingRequest
-from app.core.config import settings
-from app.core.logger import get_openai_logger
 
 router = APIRouter()
 logger = get_openai_logger()
@@ -24,10 +23,7 @@
 
 @router.get("/v1/models")
 @router.get("/hf/v1/models")
-async def list_models(
-    authorization: str = Header(None),
-    token: str = Depends(security_service.verify_authorization),
-):
+async def list_models(_=Depends(security_service.verify_authorization)):
     logger.info("-" * 50 + "list_models" + "-" * 50)
     logger.info("Handling models list request")
     api_key = await key_manager.get_next_working_key()
@@ -43,10 +39,9 @@ async def list_models(
 @router.post("/hf/v1/chat/completions")
 @RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
 async def chat_completion(
-    request: ChatRequest,
-    authorization: str = Header(None),
-    token: str = Depends(security_service.verify_authorization),
-    api_key: str = Depends(key_manager.get_next_working_key),
+        request: ChatRequest,
+        _=Depends(security_service.verify_authorization),
+        api_key: str = Depends(key_manager.get_next_working_key),
 ):
     chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
     logger.info("-" * 50 + "chat_completion" + "-" * 50)
@@ -67,15 +62,13 @@ async def chat_completion(
     except Exception as e:
         logger.error(f"Chat completion failed after retries: {str(e)}")
         raise HTTPException(status_code=500, detail="Chat completion failed") from e
-
 
 
 @router.post("/v1/embeddings")
 @router.post("/hf/v1/embeddings")
 async def embedding(
-    request: EmbeddingRequest,
-    authorization: str = Header(None),
-    token: str = Depends(security_service.verify_authorization),
+        request: EmbeddingRequest,
+        _=Depends(security_service.verify_authorization),
 ):
     logger.info("-" * 50 + "embedding" + "-" * 50)
     logger.info(f"Handling embedding request for model: {request.model}")
@@ -95,8 +88,7 @@ async def embedding(
 @router.get("/v1/keys/list")
 @router.get("/hf/v1/keys/list")
 async def get_keys_list(
-    authorization: str = Header(None),
-    token: str = Depends(security_service.verify_auth_token),
+        _=Depends(security_service.verify_auth_token),
 ):
     """获取有效和无效的API key列表"""
     logger.info("-" * 50 + "get_keys_list" + "-" * 50)
diff --git a/app/core/logger.py b/app/core/logger.py
index 62f722a..66aec33 100644
--- a/app/core/logger.py
+++ b/app/core/logger.py
@@ -128,4 +128,4 @@ def get_request_logger():
 
 
 def get_retry_logger():
-    return Logger.setup_logger("retry")
\ No newline at end of file
+    return Logger.setup_logger("retry")
diff --git a/app/core/security.py b/app/core/security.py
index f2508fa..4e80c4d 100644
--- a/app/core/security.py
+++ b/app/core/security.py
@@ -17,7 +17,7 @@ async def verify_key(self, key: str):
         return key
 
     async def verify_authorization(
-        self, authorization: Optional[str] = Header(None)
+            self, authorization: Optional[str] = Header(None)
     ) -> str:
         if not authorization:
             logger.error("Missing Authorization header")
@@ -45,7 +45,7 @@ async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)
         if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
             logger.error("Invalid x-goog-api-key")
             raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
-        
+
         return x_goog_api_key
 
     async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
@@ -56,5 +56,5 @@ async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
         if token != self.auth_token:
             logger.error("Invalid auth_token")
             raise HTTPException(status_code=401, detail="Invalid auth_token")
-        
-        return token
\ No newline at end of file
+
+        return token
diff --git a/app/schemas/gemini_models.py b/app/schemas/gemini_models.py
index 2addb13..6869379 100644
--- a/app/schemas/gemini_models.py
+++ b/app/schemas/gemini_models.py
@@ -3,10 +3,8 @@
 
 
 class SafetySetting(BaseModel):
-    category: Optional[Literal[
-        "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
-    threshold: Optional[Literal[
-        "HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
+    category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
+    threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
 
 
 class GenerationConfig(BaseModel):
diff --git a/app/services/chat/api_client.py b/app/services/chat/api_client.py
index a27302b..df7f0cb 100644
--- a/app/services/chat/api_client.py
+++ b/app/services/chat/api_client.py
@@ -4,24 +4,26 @@
 import httpx
 from abc import ABC, abstractmethod
 
+
 class ApiClient(ABC):
     """API客户端基类"""
-    
+
     @abstractmethod
     async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
         pass
-    
+
     @abstractmethod
     async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
         pass
 
+
 class GeminiApiClient(ApiClient):
     """Gemini API客户端"""
-    
+
     def __init__(self, base_url: str, timeout: int = 300):
         self.base_url = base_url
         self.timeout = timeout
-    
+
     def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
         timeout = httpx.Timeout(self.timeout, read=self.timeout)
         if model.endswith("-search"):
@@ -33,14 +35,14 @@ def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
             error_content = response.text
             raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
         return response.json()
-    
+
     async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
         timeout = httpx.Timeout(self.timeout, read=self.timeout)
         if model.endswith("-search"):
             model = model[:-7]
         async with httpx.AsyncClient(timeout=timeout) as client:
             url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
-            async with client.stream("POST", url, json=payload) as response:
+            async with client.stream(method="POST", url=url, json=payload) as response:
                 if response.status_code != 200:
                     error_content = await response.aread()
                     error_msg = error_content.decode("utf-8")
diff --git a/app/services/chat/message_converter.py b/app/services/chat/message_converter.py
index 17372ae..ff7e795 100644
--- a/app/services/chat/message_converter.py
+++ b/app/services/chat/message_converter.py
@@ -3,22 +3,39 @@
 from abc import ABC, abstractmethod
 from typing import List, Dict, Any
 
+
 class MessageConverter(ABC):
     """消息转换器基类"""
-    
+
     @abstractmethod
     def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         pass
 
+
+def _convert_image(image_url: str) -> Dict[str, Any]:
+    if image_url.startswith("data:image"):
+        return {
+            "inline_data": {
+                "mime_type": "image/jpeg",
+                "data": image_url.split(",")[1]
+            }
+        }
+    return {
+        "image_url": {
+            "url": image_url
+        }
+    }
+
+
 class OpenAIMessageConverter(MessageConverter):
     """OpenAI消息格式转换器"""
-    
+
     def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         converted_messages = []
         for msg in messages:
             role = "user" if msg["role"] == "user" else "model"
             parts = []
-            
+
             if isinstance(msg["content"], str):
                 parts.append({"text": msg["content"]})
             elif isinstance(msg["content"], list):
@@ -29,22 +46,8 @@ def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     if content["type"] == "text":
                         parts.append({"text": content["text"]})
                     elif content["type"] == "image_url":
-                        parts.append(self._convert_image(content["image_url"]["url"]))
-            
+                        parts.append(_convert_image(content["image_url"]["url"]))
+
             converted_messages.append({"role": role, "parts": parts})
-        
+
         return converted_messages
-
-    def _convert_image(self, image_url: str) -> Dict[str, Any]:
-        if image_url.startswith("data:image"):
-            return {
-                "inline_data": {
-                    "mime_type": "image/jpeg",
-                    "data": image_url.split(",")[1]
-                }
-            }
-        return {
-            "image_url": {
-                "url": image_url
-            }
-        }
\ No newline at end of file
diff --git a/app/services/chat/response_handler.py b/app/services/chat/response_handler.py
index 1d6ede4..c669323 100644
--- a/app/services/chat/response_handler.py
+++ b/app/services/chat/response_handler.py
@@ -6,331 +6,225 @@
 import uuid
 
 from app.core.config import settings
 
+
 class ResponseHandler(ABC):
     """响应处理器基类"""
-    
+
     @abstractmethod
     def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
         pass
 
+
 class GeminiResponseHandler(ResponseHandler):
     """Gemini响应处理器"""
-    
+
     def __init__(self):
         self.thinking_first = True
         self.thinking_status = False
-    
+
     def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
         if stream:
-            return self._handle_stream_response(response, model, stream)
-        return self._handle_normal_response(response, model, stream)
-
-    def _handle_stream_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
-        text = self._extract_text(response, model, stream=stream)
-        content = {"parts": [{"text": text}],"role": "model"}
-        response["candidates"][0]["content"] = content
-        return response
-
-    def _handle_normal_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
-        text = self._extract_text(response, model, stream=stream)
-        content = {"parts": [{"text": text}],"role": "model"}
-        response["candidates"][0]["content"] = content
-        return response
-
-    def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
-        text = ""
-        if stream:
-            if response.get("candidates"):
-                candidate = response["candidates"][0]
-                content = candidate.get("content", {})
-                parts = content.get("parts", [])
-                # if "thinking" in model:
-                #     if settings.SHOW_THINKING_PROCESS:
-                #         if len(parts) == 1:
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 self.thinking_status = True
-                #                 text = "> thinking\n\n" + parts[0].get("text")
-                #             else:
-                #                 text = parts[0].get("text")
-
-                #         if len(parts) == 2:
-                #             self.thinking_status = False
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 text = (
-                #                     "> thinking\n\n"
-                #                     + parts[0].get("text")
-                #                     + "\n\n---\n> output\n\n"
-                #                     + parts[1].get("text")
-                #                 )
-                #             else:
-                #                 text = (
-                #                     parts[0].get("text")
-                #                     + "\n\n---\n> output\n\n"
-                #                     + parts[1].get("text")
-                #                 )
-                #     else:
-                #         if len(parts) == 1:
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 self.thinking_status = True
-                #                 text = ""
-                #             elif self.thinking_status:
-                #                 text = ""
-                #             else:
-                #                 text = parts[0].get("text")
-
-                #         if len(parts) == 2:
-                #             self.thinking_status = False
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 text = parts[1].get("text")
-                #             else:
-                #                 text = parts[1].get("text")
-                # else:
-                #     if "text" in parts[0]:
-                #         text = parts[0].get("text")
-                #     elif "executableCode" in parts[0]:
-                #         text = _format_code_block(parts[0]["executableCode"])
-                #     elif "codeExecution" in parts[0]:
-                #         text = _format_code_block(parts[0]["codeExecution"])
-                #     elif "executableCodeResult" in parts[0]:
-                #         text = _format_execution_result(
-                #             parts[0]["executableCodeResult"]
-                #         )
-                #     elif "codeExecutionResult" in parts[0]:
-                #         text = _format_execution_result(
-                #             parts[0]["codeExecutionResult"]
-                #         )
-                #     else:
-                #         text = ""
-                if "text" in parts[0]:
-                    text = parts[0].get("text")
-                elif "executableCode" in parts[0]:
-                    text = _format_code_block(parts[0]["executableCode"])
-                elif "codeExecution" in parts[0]:
-                    text = _format_code_block(parts[0]["codeExecution"])
-                elif "executableCodeResult" in parts[0]:
-                    text = _format_execution_result(
-                        parts[0]["executableCodeResult"]
-                    )
-                elif "codeExecutionResult" in parts[0]:
-                    text = _format_execution_result(
-                        parts[0]["codeExecutionResult"]
-                    )
-                else:
-                    text = ""
-                text = _add_search_link_text(model, candidate, text)
-        else:
-            if response.get("candidates"):
-                candidate = response["candidates"][0]
-                if "thinking" in model:
-                    if settings.SHOW_THINKING_PROCESS:
-                        if len(candidate["content"]["parts"]) == 2:
-                            text = (
-                                "> thinking\n\n"
-                                + candidate["content"]["parts"][0]["text"]
-                                + "\n\n---\n> output\n\n"
-                                + candidate["content"]["parts"][1]["text"]
-                            )
-                        else:
-                            text = candidate["content"]["parts"][0]["text"]
-                    else:
-                        if len(candidate["content"]["parts"]) == 2:
-                            text = candidate["content"]["parts"][1]["text"]
-                        else:
-                            text = candidate["content"]["parts"][0]["text"]
-                else:
-                    text = candidate["content"]["parts"][0]["text"]
-                text = _add_search_link_text(model, candidate, text)
-            else:
-                text = "暂无返回"
-        return text
-
+            return _handle_gemini_stream_response(response, model, stream)
+        return _handle_gemini_normal_response(response, model, stream)
+
+
+def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
+    text = _extract_text(response, model, stream=True)
+    return {
+        "id": f"chatcmpl-{uuid.uuid4()}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [{
+            "index": 0,
+            "delta": {"content": text} if text else {},
+            "finish_reason": finish_reason
+        }]
+    }
+
+
+def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
+    text = _extract_text(response, model, stream=False)
+    return {
+        "id": f"chatcmpl-{uuid.uuid4()}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": text
+            },
+            "finish_reason": finish_reason
+        }],
+        "usage": {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0
+        }
+    }
 
 class OpenAIResponseHandler(ResponseHandler):
     """OpenAI响应处理器"""
-    
+
     def __init__(self, config):
         self.config = config
         self.thinking_first = True
         self.thinking_status = False
-    
+
     def handle_response(
-        self,
-        response: Dict[str, Any],
-        model: str,
-        stream: bool = False,
-        finish_reason: str = None
+            self,
+            response: Dict[str, Any],
+            model: str,
+            stream: bool = False,
+            finish_reason: str = None
     ) -> Optional[Dict[str, Any]]:
         if stream:
-            return self._handle_stream_response(response, model, finish_reason)
-        return self._handle_normal_response(response, model, finish_reason)
-
-    def _handle_stream_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
-        text = self._extract_text(response, model, stream=True)
-        return {
-            "id": f"chatcmpl-{uuid.uuid4()}",
-            "object": "chat.completion.chunk",
-            "created": int(time.time()),
-            "model": model,
-            "choices": [{
-                "index": 0,
-                "delta": {"content": text} if text else {},
-                "finish_reason": finish_reason
-            }]
-        }
-
-    def _handle_normal_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
-        text = self._extract_text(response, model, stream=False)
-        return {
-            "id": f"chatcmpl-{uuid.uuid4()}",
-            "object": "chat.completion",
-            "created": int(time.time()),
-            "model": model,
-            "choices": [{
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": text
-                },
-                "finish_reason": finish_reason
-            }],
-            "usage": {
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0
-            }
-        }
-
-    def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
-        text = ""
-        if stream:
-            if response.get("candidates"):
-                candidate = response["candidates"][0]
-                content = candidate.get("content", {})
-                parts = content.get("parts", [])
-                # if "thinking" in model:
-                #     if settings.SHOW_THINKING_PROCESS:
-                #         if len(parts) == 1:
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 self.thinking_status = True
-                #                 text = "> thinking\n\n" + parts[0].get("text")
-                #             else:
-                #                 text = parts[0].get("text")
-
-                #         if len(parts) == 2:
-                #             self.thinking_status = False
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 text = (
-                #                     "> thinking\n\n"
-                #                     + parts[0].get("text")
-                #                     + "\n\n---\n> output\n\n"
-                #                     + parts[1].get("text")
-                #                 )
-                #             else:
-                #                 text = (
-                #                     parts[0].get("text")
-                #                     + "\n\n---\n> output\n\n"
-                #                     + parts[1].get("text")
-                #                 )
-                #     else:
-                #         if len(parts) == 1:
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 self.thinking_status = True
-                #                 text = ""
-                #             elif self.thinking_status:
-                #                 text = ""
-                #             else:
-                #                 text = parts[0].get("text")
-
-                #         if len(parts) == 2:
-                #             self.thinking_status = False
-                #             if self.thinking_first:
-                #                 self.thinking_first = False
-                #                 text = parts[1].get("text")
-                #             else:
-                #                 text = parts[1].get("text")
-                # else:
-                #     if "text" in parts[0]:
-                #         text = parts[0].get("text")
-                #     elif "executableCode" in parts[0]:
-                #         text = _format_code_block(parts[0]["executableCode"])
-                #     elif "codeExecution" in parts[0]:
-                #         text = _format_code_block(parts[0]["codeExecution"])
-                #     elif "executableCodeResult" in parts[0]:
-                #         text = _format_execution_result(
-                #             parts[0]["executableCodeResult"]
-                #         )
-                #     elif "codeExecutionResult" in parts[0]:
-                #         text = _format_execution_result(
-                #             parts[0]["codeExecutionResult"]
-                #         )
-                #     else:
-                #         text = ""
-                #     text = _add_search_link_text(model, candidate, text)
-                if "text" in parts[0]:
-                    text = parts[0].get("text")
-                elif "executableCode" in parts[0]:
-                    text = _format_code_block(parts[0]["executableCode"])
-                elif "codeExecution" in parts[0]:
-                    text = _format_code_block(parts[0]["codeExecution"])
-                elif "executableCodeResult" in parts[0]:
-                    text = _format_execution_result(
-                        parts[0]["executableCodeResult"]
-                    )
-                elif "codeExecutionResult" in parts[0]:
-                    text = _format_execution_result(
-                        parts[0]["codeExecutionResult"]
-                    )
-                else:
-                    text = ""
-                text = _add_search_link_text(model, candidate, text)
-        else:
-            if response.get("candidates"):
-                candidate = response["candidates"][0]
-                if "thinking" in model:
-                    if settings.SHOW_THINKING_PROCESS:
-                        if len(candidate["content"]["parts"]) == 2:
-                            text = (
-                                "> thinking\n\n"
-                                + candidate["content"]["parts"][0]["text"]
-                                + "\n\n---\n> output\n\n"
-                                + candidate["content"]["parts"][1]["text"]
-                            )
-                        else:
-                            text = candidate["content"]["parts"][0]["text"]
-                    else:
-                        if len(candidate["content"]["parts"]) == 2:
-                            text = candidate["content"]["parts"][1]["text"]
-                        else:
-                            text = candidate["content"]["parts"][0]["text"]
-                else:
-                    text = candidate["content"]["parts"][0]["text"]
-                text = _add_search_link_text(model, candidate, text)
-            else:
-                text = "暂无返回"
-        return text
+            return _handle_openai_stream_response(response, model, finish_reason)
+        return _handle_openai_normal_response(response, model, finish_reason)
+
+
+def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
+    text = ""
+    if stream:
+        if response.get("candidates"):
+            candidate = response["candidates"][0]
+            content = candidate.get("content", {})
+            parts = content.get("parts", [])
+            # if "thinking" in model:
+            #     if settings.SHOW_THINKING_PROCESS:
+            #         if len(parts) == 1:
+            #             if self.thinking_first:
+            #                 self.thinking_first = False
+            #                 self.thinking_status = True
+            #                 text = "> thinking\n\n" + parts[0].get("text")
+            #             else:
+            #                 text = parts[0].get("text")
+
+            #         if len(parts) == 2:
+            #             self.thinking_status = False
+            #             if self.thinking_first:
+            #                 self.thinking_first = False
+            #                 text = (
+            #                     "> thinking\n\n"
+            #                     + parts[0].get("text")
+            #                     + "\n\n---\n> output\n\n"
+            #                     + parts[1].get("text")
+            #                 )
+            #             else:
+            #                 text = (
+            #                     parts[0].get("text")
+            #                     + "\n\n---\n> output\n\n"
+            #                     + parts[1].get("text")
+            #                 )
+            #     else:
+            #         if len(parts) == 1:
+            #             if self.thinking_first:
+            #                 self.thinking_first = False
+            #                 self.thinking_status = True
+            #                 text = ""
+            #             elif self.thinking_status:
+            #                 text = ""
+            #             else:
+            #                 text = parts[0].get("text")
+
+            #         if len(parts) == 2:
+            #             self.thinking_status = False
+            #             if self.thinking_first:
+            #                 self.thinking_first = False
+            #                 text = parts[1].get("text")
+            #             else:
+            #                 text = parts[1].get("text")
+            # else:
+            #     if "text" in parts[0]:
+            #         text = parts[0].get("text")
+            #     elif "executableCode" in parts[0]:
+            #         text = _format_code_block(parts[0]["executableCode"])
+            #     elif "codeExecution" in parts[0]:
+            #         text = _format_code_block(parts[0]["codeExecution"])
+            #     elif "executableCodeResult" in parts[0]:
+            #         text = _format_execution_result(
+            #             parts[0]["executableCodeResult"]
+            #         )
+            #     elif "codeExecutionResult" in parts[0]:
+            #         text = _format_execution_result(
+            #             parts[0]["codeExecutionResult"]
+            #         )
+            #     else:
+            #         text = ""
+            if "text" in parts[0]:
+                text = parts[0].get("text")
+            elif "executableCode" in parts[0]:
+                text = _format_code_block(parts[0]["executableCode"])
+            elif "codeExecution" in parts[0]:
+                text = _format_code_block(parts[0]["codeExecution"])
+            elif "executableCodeResult" in parts[0]:
+                text = _format_execution_result(
+                    parts[0]["executableCodeResult"]
+                )
+            elif "codeExecutionResult" in parts[0]:
+                text = _format_execution_result(
+                    parts[0]["codeExecutionResult"]
+                )
+            else:
+                text = ""
+            text = _add_search_link_text(model, candidate, text)
+    else:
+        if response.get("candidates"):
+            candidate = response["candidates"][0]
+            if "thinking" in model:
+                if settings.SHOW_THINKING_PROCESS:
+                    if len(candidate["content"]["parts"]) == 2:
+                        text = (
candidate["content"]["parts"][0]["text"] + "\n\n---\n> output\n\n" + candidate["content"]["parts"][1]["text"] - ) - else: - text = candidate["content"]["parts"][0]["text"] + ) else: - if len(candidate["content"]["parts"]) == 2: - text = candidate["content"]["parts"][1]["text"] - else: - text = candidate["content"]["parts"][0]["text"] + text = candidate["content"]["parts"][0]["text"] else: - text = candidate["content"]["parts"][0]["text"] - text = _add_search_link_text(model, candidate, text) + if len(candidate["content"]["parts"]) == 2: + text = candidate["content"]["parts"][1]["text"] + else: + text = candidate["content"]["parts"][0]["text"] else: - text = "暂无返回" - return text + text = candidate["content"]["parts"][0]["text"] + text = _add_search_link_text(model, candidate, text) + else: + text = "暂无返回" + return text + + +def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]: + text = _extract_text(response, model, stream=stream) + content = {"parts": [{"text": text}], "role": "model"} + response["candidates"][0]["content"] = content + return response + + +def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]: + text = _extract_text(response, model, stream=stream) + content = {"parts": [{"text": text}], "role": "model"} + response["candidates"][0]["content"] = content + return response + - def _format_code_block(code_data: dict) -> str: """格式化代码块输出""" language = code_data.get("language", "").lower() code = code_data.get("code", "").strip() return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n""" - - -def _add_search_link_text(model:str, candidate:dict, text:str) -> str: + + +def _add_search_link_text(model: str, candidate: dict, text: str) -> str: if ( - settings.SHOW_SEARCH_LINK - and model.endswith("-search") - and "groundingMetadata" in candidate - and "groundingChunks" in candidate["groundingMetadata"] + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] ): grounding_chunks = candidate["groundingMetadata"]["groundingChunks"] text += "\n\n---\n\n" @@ -351,4 +245,4 @@ def _format_execution_result(result_data: dict) -> str: """格式化执行结果输出""" outcome = result_data.get("outcome", "") output = result_data.get("output", "").strip() - return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n""" \ No newline at end of file + return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n""" diff --git a/app/services/chat/retry_handler.py b/app/services/chat/retry_handler.py index a8915aa..28985ff 100644 --- a/app/services/chat/retry_handler.py +++ b/app/services/chat/retry_handler.py @@ -8,26 +8,27 @@ T = TypeVar('T') logger = get_retry_logger() + class RetryHandler: """重试处理装饰器""" - + def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"): self.max_retries = max_retries self.key_manager = key_manager self.key_arg = key_arg - + def __call__(self, func: Callable[..., T]) -> Callable[..., T]: @wraps(func) async def wrapper(*args, **kwargs) -> T: last_exception = None - + for attempt in range(self.max_retries): try: return await func(*args, **kwargs) except Exception as e: last_exception = e logger.warning(f"API call failed with error: {str(e)}. 
Attempt {attempt + 1} of {self.max_retries}") - + if self.key_manager: old_key = kwargs.get(self.key_arg) new_key = await self.key_manager.handle_api_failure(old_key) @@ -36,5 +37,5 @@ async def wrapper(*args, **kwargs) -> T: logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}") raise last_exception - - return wrapper \ No newline at end of file + + return wrapper diff --git a/app/services/gemini_chat_service.py b/app/services/gemini_chat_service.py index 2674b11..c8a949a 100644 --- a/app/services/gemini_chat_service.py +++ b/app/services/gemini_chat_service.py @@ -10,6 +10,61 @@ from app.services.key_manager import KeyManager logger = get_gemini_logger() + + +def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + +def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: + """构建工具""" + tools = [] + if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( + model.endswith("-search") or "-thinking" in model + ) and not _has_image_parts(payload.get("contents", [])): + tools.append({"code_execution": {}}) + if model.endswith("-search"): + tools.append({"googleSearch": {}}) + return tools + + +def _get_safety_settings(model: str) -> List[Dict[str, str]]: + """获取安全设置""" + if model == "gemini-2.0-flash-exp": + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"} + ] + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} + ] + + +def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: + """构建请求payload""" + payload = request.model_dump() + return { + "contents": payload.get("contents", []), + "tools": _build_tools(model, payload), + "safetySettings": _get_safety_settings(model), + "generationConfig": payload.get("generationConfig", {}), + "systemInstruction": payload.get("systemInstruction", []) + } + + class GeminiChatService: """聊天服务""" @@ -17,18 +72,18 @@ def __init__(self, base_url: str, key_manager: KeyManager): self.api_client = GeminiApiClient(base_url) self.key_manager = key_manager self.response_handler = GeminiResponseHandler() - + def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]: """生成内容""" - payload = self._build_payload(model, request) + payload = _build_payload(model, request) response = self.api_client.generate_content(payload, model, api_key) return self.response_handler.handle_response(response, model, stream=False) - + async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]: """流式生成内容""" retries = 0 max_retries = 3 - payload = self._build_payload(model, request) + payload = _build_payload(model, request) while retries < max_retries: try: async for line in 
self.api_client.stream_generate_content(payload, model, api_key): @@ -47,52 +102,3 @@ async def stream_generate_content(self, model: str, request: GeminiRequest, api_ if retries >= max_retries: logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error") break - - def _build_payload(self,model: str, request: GeminiRequest) -> Dict[str, Any]: - """构建请求payload""" - payload = request.model_dump() - return { - "contents": payload.get("contents", []), - "tools": self._build_tools(model, payload), - "safetySettings": self._get_safety_settings(model), - "generationConfig": payload.get("generationConfig", {}), - "systemInstruction": payload.get("systemInstruction", []) - } - - def _build_tools(self, model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: - """构建工具""" - tools = [] - if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( - model.endswith("-search") or "-thinking" in model - ) and not self._has_image_parts(payload.get("contents", [])): - tools.append({"code_execution": {}}) - if model.endswith("-search"): - tools.append({"googleSearch": {}}) - return tools - - def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool: - """判断消息是否包含图片部分""" - for content in contents: - if "parts" in content: - for part in content["parts"]: - if "image_url" in part or "inline_data" in part: - return True - return False - - def _get_safety_settings(self, model: str) -> List[Dict[str, str]]: - """获取安全设置""" - if model == "gemini-2.0-flash-exp": - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"} - ] - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} - ] \ No newline at end of file diff --git a/app/services/model_service.py b/app/services/model_service.py index e3e0c7b..586d1b6 100644 --- a/app/services/model_service.py +++ b/app/services/model_service.py @@ -36,9 +36,9 @@ def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]: return None def convert_to_openai_models_format( - self, gemini_models: Dict[str, Any] + self, gemini_models: Dict[str, Any] ) -> Dict[str, Any]: - openai_format = {"object": "list", "data": [],"success": True} + openai_format = {"object": "list", "data": [], "success": True} for model in gemini_models.get("models", []): model_id = model["name"].split("/")[-1] diff --git a/app/services/openai_chat_service.py b/app/services/openai_chat_service.py index 79300cd..fd25833 100644 --- a/app/services/openai_chat_service.py +++ b/app/services/openai_chat_service.py @@ -13,6 +13,76 @@ logger = get_openai_logger() +def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + +def _build_tools( + request: ChatRequest, messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: + """构建工具""" + tools = [] + model = request.model + + if ( + 
settings.TOOLS_CODE_EXECUTION_ENABLED + and not (model.endswith("-search") or "-thinking" in model) + and not _has_image_parts(messages) + ): + tools.append({"code_execution": {}}) + if model.endswith("-search"): + tools.append({"googleSearch": {}}) + return tools + + +def _get_safety_settings(model: str) -> List[Dict[str, str]]: + """获取安全设置""" + # if ( + # "2.0" in model + # and "gemini-2.0-flash-thinking-exp" not in model + # and "gemini-2.0-pro-exp" not in model + # ): + if model == "gemini-2.0-flash-exp": + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, + ] + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ] + + +def _build_payload( + request: ChatRequest, messages: List[Dict[str, Any]] +) -> Dict[str, Any]: + """构建请求payload""" + return { + "contents": messages, + "generationConfig": { + "temperature": request.temperature, + "maxOutputTokens": request.max_tokens, + "stopSequences": request.stop, + "topP": request.top_p, + "topK": request.top_k, + }, + "tools": _build_tools(request, messages), + "safetySettings": _get_safety_settings(request.model), + } + + class OpenAIChatService: """聊天服务""" @@ -32,7 +102,7 @@ async def create_chat_completion( messages = self.message_converter.convert(request.messages) # 构建请求payload - payload = self._build_payload(request, messages) + payload = _build_payload(request, messages) if request.stream: return self._handle_stream_completion(request.model, payload, api_key) @@ -84,69 +154,3 @@ async def _handle_stream_completion( yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n" yield "data: [DONE]\n\n" break - - def _build_payload( - self, request: ChatRequest, messages: List[Dict[str, Any]] - ) -> Dict[str, Any]: - """构建请求payload""" - return { - "contents": messages, - "generationConfig": { - "temperature": request.temperature, - "maxOutputTokens": request.max_tokens, - "stopSequences": request.stop, - "topP": request.top_p, - "topK": request.top_k, - }, - "tools": self._build_tools(request, messages), - "safetySettings": self._get_safety_settings(request.model), - } - - def _build_tools( - self, request: ChatRequest, messages: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """构建工具""" - tools = [] - model = request.model - - if ( - settings.TOOLS_CODE_EXECUTION_ENABLED - and not (model.endswith("-search") or "-thinking" in model) - and not self._has_image_parts(messages) - ): - tools.append({"code_execution": {}}) - if model.endswith("-search"): - tools.append({"googleSearch": {}}) - return tools - - def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool: - """判断消息是否包含图片部分""" - for content in contents: - if "parts" in content: - for part in content["parts"]: - if "image_url" in part or "inline_data" in part: - return True - return False - - def _get_safety_settings(self, model: str) -> List[Dict[str, str]]: - """获取安全设置""" - # if ( - # "2.0" in model - # and 
"gemini-2.0-flash-thinking-exp" not in model - # and "gemini-2.0-pro-exp" not in model - # ): - if model == "gemini-2.0-flash-exp": - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, - ] - return [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, - ]