Skip to content

Commit

Permalink
refactor: 修改安全设置逻辑以匹配特定模型
Browse files Browse the repository at this point in the history
  • Loading branch information
snailyp committed Feb 5, 2025
1 parent b465fd1 commit b60b063
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 42 deletions.
2 changes: 1 addition & 1 deletion app/services/gemini_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:

def _get_safety_settings(self, model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if "2.0" in model and "gemini-2.0-flash-thinking-exp" not in model:
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
Expand Down
89 changes: 48 additions & 41 deletions app/services/openai_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from typing import Dict, Any, AsyncGenerator, List, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.schemas.openai_models import ChatRequest
from app.core.config import settings
from app.services.key_manager import KeyManager

logger = get_openai_logger()


class OpenAIChatService:
"""聊天服务"""

Expand All @@ -19,7 +21,7 @@ def __init__(self, base_url: str, key_manager: KeyManager):
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager

async def create_chat_completion(
self,
request: ChatRequest,
Expand All @@ -28,68 +30,64 @@ async def create_chat_completion(
"""创建聊天完成"""
# 转换消息格式
messages = self.message_converter.convert(request.messages)

# 构建请求payload
payload = self._build_payload(request, messages)

if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return self._handle_normal_completion(request.model, payload, api_key)

def _handle_normal_completion(
self,
model: str,
payload: Dict[str, Any],
api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response,
model,
stream=False,
finish_reason="stop"
response, model, stream=False, finish_reason="stop"
)

async def _handle_stream_completion(
self,
model: str,
payload: Dict[str, Any],
api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
max_retries = 3
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
chunk = json.loads(line[6:])
openai_chunk = self.response_handler.handle_response(
chunk,
model,
stream=True,
finish_reason=None
chunk, model, stream=True, finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
break # 成功后退出循环
break # 成功后退出循环
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break

def _build_payload(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> Dict[str, Any]:

def _build_payload(
self, request: ChatRequest, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""构建请求payload"""
return {
"contents": messages,
Expand All @@ -98,25 +96,29 @@ def _build_payload(self, request: ChatRequest, messages: List[Dict[str, Any]]) -
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k
"topK": request.top_k,
},
"tools": self._build_tools(request, messages),
"safetySettings": self._get_safety_settings(request.model)
"safetySettings": self._get_safety_settings(request.model),
}

def _build_tools(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

def _build_tools(
self, request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model

if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not self._has_image_parts(messages):
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not self._has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools

def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
Expand All @@ -125,21 +127,26 @@ def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
if "image_url" in part or "inline_data" in part:
return True
return False

def _get_safety_settings(self, model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if "2.0" in model and "gemini-2.0-flash-thinking-exp" not in model:
# if (
# "2.0" in model
# and "gemini-2.0-flash-thinking-exp" not in model
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]

0 comments on commit b60b063

Please sign in to comment.