diff --git a/main.py b/main.py
index 0c79409ee..93602eb65 100644
--- a/main.py
+++ b/main.py
@@ -74,10 +74,10 @@ def handle_text_message(event):
         elif text.startswith('/圖像'):
             prompt = text[3:].strip()
+            memory.append(user_id, 'user', prompt)
             is_successful, response, error_message = model_management[user_id].image_generations(prompt)
             if not is_successful:
                 raise Exception(error_message)
-            memory.append(user_id, 'user', prompt)
             msg = ImageSendMessage(
                 original_content_url=response,
                 preview_image_url=response
@@ -85,17 +85,18 @@ def handle_text_message(event):
             memory.append(user_id, 'assistant', response)
 
         else:
+            memory.append(user_id, 'user', text)
             is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), os.getenv('OPENAI_MODEL_ENGINE'))
             if not is_successful:
                 raise Exception(error_message)
-            memory.append(user_id, 'user', text)
             role, response = get_role_and_content(response)
             msg = TextSendMessage(text=response)
             memory.append(user_id, role, response)
 
     except ValueError:
-        msg = TextSendMessage(text='Token 無效,請重新註冊,注意格式有空格,格式為 /註冊 sk-xxxxx')
+        msg = TextSendMessage(text='Token 無效,請重新註冊,格式為 /註冊 sk-xxxxx')
     except Exception as e:
+        memory.remove(user_id)
         msg = TextSendMessage(text=str(e))
 
     line_bot_api.reply_message(event.reply_token, msg)
@@ -114,18 +115,19 @@ def handle_audio_message(event):
             raise ValueError('Invalid API token')
         else:
             transciption, error_message = model_management[user_id].audio_transcriptions(input_audio_path, 'whisper-1')
+            memory.append(user_id, 'user', transciption)
             if error_message:
                 raise Exception(error_message)
             is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), 'gpt-3.5-turbo')
             if not is_successful:
                 raise Exception(error_message)
-            memory.append(user_id, 'user', transciption)
             role, response = get_role_and_content(response)
             memory.append(user_id, role, response)
             msg = TextSendMessage(text=response)
     except ValueError:
         msg = TextSendMessage(text='請先註冊你的 API Token,格式為 /註冊 [API TOKEN]')
     except Exception as e:
+        memory.remove(user_id)
         msg = TextSendMessage(text=str(e))
     os.remove(input_audio_path)
     line_bot_api.reply_message(event.reply_token, msg)
diff --git a/src/models.py b/src/models.py
index 973d78d62..87acc3d67 100644
--- a/src/models.py
+++ b/src/models.py
@@ -6,10 +6,10 @@ class ModelInterface:
     def check_token_valid(self) -> bool:
         pass
 
-    def chat_completions(self, messages: List[Dict]) -> str:
+    def chat_completions(self, messages: List[Dict], model_engine: str) -> str:
         pass
 
-    def audio_transcriptions(self, file) -> str:
+    def audio_transcriptions(self, file, model_engine: str) -> str:
         pass
 
     def image_generations(self, prompt: str) -> str:
@@ -62,4 +62,4 @@ def image_generations(self, prompt: str) -> str:
             "n": 1,
             "size": "512x512"
         }
-        return self._request('/images/generations', body=json_body)
+        return self._request('POST', '/images/generations', body=json_body)
diff --git a/src/utils.py b/src/utils.py
index dccad92d9..032ffdf25 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -4,8 +4,8 @@
 t2s_converter = opencc.OpenCC('t2s.json')
 
 
-def get_role_and_content(response):
+def get_role_and_content(response: str):
     role = response['choices'][0]['message']['role']
     content = response['choices'][0]['message']['content'].strip()
-    response = s2t_converter.convert(content)
-    return role, response
+    content = s2t_converter.convert(content)
+    return role, content