From 447c55e51d8e17791f887ca9b93ba2e3626f8b81 Mon Sep 17 00:00:00 2001 From: allthatjazzleo Date: Sat, 13 Jan 2024 18:54:51 +0800 Subject: [PATCH] rebase main --- bot/bot.py | 120 +++++++++++++++++++++++++++++++++--------------- bot/database.py | 17 +++++++ 2 files changed, 100 insertions(+), 37 deletions(-) diff --git a/bot/bot.py b/bot/bot.py index f4510a6dc..7336addd1 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -110,22 +110,22 @@ async def register_user_if_not_exists(update: Update, context: CallbackContext, async def is_bot_mentioned(update: Update, context: CallbackContext): - try: - message = update.message + try: + message = update.message - if message.chat.type == "private": - return True + if message.chat.type == "private": + return True - if message.text is not None and ("@" + context.bot.username) in message.text: - return True + if message.text is not None and ("@" + context.bot.username) in message.text: + return True - if message.reply_to_message is not None: - if message.reply_to_message.from_user.id == context.bot.id: - return True - except: - return True - else: - return False + if message.reply_to_message is not None: + if message.reply_to_message.from_user.id == context.bot.id: + return True + except: + return True + else: + return False async def start_handle(update: Update, context: CallbackContext): @@ -150,14 +150,14 @@ async def help_handle(update: Update, context: CallbackContext): async def help_group_chat_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update, context, update.message.from_user) - user_id = update.message.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + await register_user_if_not_exists(update, context, update.message.from_user) + user_id = update.message.from_user.id + db.set_user_attribute(user_id, "last_interaction", datetime.now()) - text = HELP_GROUP_CHAT_MESSAGE.format(bot_username="@" + context.bot.username) + text = HELP_GROUP_CHAT_MESSAGE.format(bot_username="@" + context.bot.username) - await update.message.reply_text(text, parse_mode=ParseMode.HTML) - await update.message.reply_video(config.help_group_chat_video_path) + await update.message.reply_text(text, parse_mode=ParseMode.HTML) + await update.message.reply_video(config.help_group_chat_video_path) async def retry_handle(update: Update, context: CallbackContext): @@ -204,8 +204,17 @@ async def message_handle(update: Update, context: CallbackContext, message=None, await generate_image_handle(update, context, message=message) return + if db.get_user_attribute(user_id, "current_buffer_setting") is True: + if _message != "Done": + existing_buffer = db.get_buffer_message(user_id) + db.set_buffer_message(user_id, _message if existing_buffer is None else existing_buffer + _message) + await update.message.reply_text("Buffered 📝, send Done to generate a response", parse_mode=ParseMode.HTML) + return + + async def message_handle_fn(): # new dialog timeout + nonlocal _message if use_new_dialog_timeout: if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0: db.start_new_dialog(user_id) @@ -223,9 +232,15 @@ async def message_handle_fn(): # send typing action await update.message.chat.send_action(action="typing") + existing_buffer = db.get_buffer_message(user_id) + if existing_buffer is not None: + _message = existing_buffer + # clear buffer after use + db.set_buffer_message(user_id, None) + if _message is None or len(_message) == 0: - await update.message.reply_text("🥲 You sent empty message. Please, try again!", parse_mode=ParseMode.HTML) - return + await update.message.reply_text("🥲 You sent empty message. Please, try again!", parse_mode=ParseMode.HTML) + return dialog_messages = db.get_dialog_messages(user_id, dialog_id=None) parse_mode = { @@ -462,25 +477,25 @@ async def show_chat_modes_handle(update: Update, context: CallbackContext): async def show_chat_modes_callback_handle(update: Update, context: CallbackContext): - await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user) - if await is_previous_message_not_answered_yet(update.callback_query, context): return + await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user) + if await is_previous_message_not_answered_yet(update.callback_query, context): return - user_id = update.callback_query.from_user.id - db.set_user_attribute(user_id, "last_interaction", datetime.now()) + user_id = update.callback_query.from_user.id + db.set_user_attribute(user_id, "last_interaction", datetime.now()) - query = update.callback_query - await query.answer() + query = update.callback_query + await query.answer() - page_index = int(query.data.split("|")[1]) - if page_index < 0: - return + page_index = int(query.data.split("|")[1]) + if page_index < 0: + return - text, reply_markup = get_chat_mode_menu(page_index) - try: - await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML) - except telegram.error.BadRequest as e: - if str(e).startswith("Message is not modified"): - pass + text, reply_markup = get_chat_mode_menu(page_index) + try: + await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML) + except telegram.error.BadRequest as e: + if str(e).startswith("Message is not modified"): + pass async def set_chat_mode_handle(update: Update, context: CallbackContext): @@ -523,7 +538,19 @@ def get_settings_menu(user_id: int): buttons.append( InlineKeyboardButton(title, callback_data=f"set_settings|{model_key}") ) - reply_markup = InlineKeyboardMarkup([buttons]) + + # buttons to choose buffer mode + current_buffer_setting = db.get_user_attribute(user_id, "current_buffer_setting") + if current_buffer_setting is None or current_buffer_setting is False: + title = "❌ Buffer Mode" + action = "set_buffer|True" + else: + title = "✅ Buffer Mode" + action = "set_buffer|False" + buffer_buttons = [ + InlineKeyboardButton(title, callback_data=action) + ] + reply_markup = InlineKeyboardMarkup([buttons, buffer_buttons]) return text, reply_markup @@ -557,6 +584,24 @@ async def set_settings_handle(update: Update, context: CallbackContext): if str(e).startswith("Message is not modified"): pass +async def set_buffer_handle(update: Update, context: CallbackContext): + await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user) + user_id = update.callback_query.from_user.id + + query = update.callback_query + await query.answer() + + _, buffer_setting = query.data.split("|") + db.set_user_attribute(user_id, "current_buffer_setting", buffer_setting == "True") + db.start_new_dialog(user_id) + + text, reply_markup = get_settings_menu(user_id) + try: + await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML) + except telegram.error.BadRequest as e: + if str(e).startswith("Message is not modified"): + pass + async def show_balance_handle(update: Update, context: CallbackContext): await register_user_if_not_exists(update, context, update.message.from_user) @@ -684,6 +729,7 @@ def run_bot() -> None: application.add_handler(CommandHandler("settings", settings_handle, filters=user_filter)) application.add_handler(CallbackQueryHandler(set_settings_handle, pattern="^set_settings")) + application.add_handler(CallbackQueryHandler(set_buffer_handle, pattern="^set_buffer")) application.add_handler(CommandHandler("balance", show_balance_handle, filters=user_filter)) diff --git a/bot/database.py b/bot/database.py index b6bafe358..4d83f17ad 100644 --- a/bot/database.py +++ b/bot/database.py @@ -14,6 +14,7 @@ def __init__(self): self.user_collection = self.db["user"] self.dialog_collection = self.db["dialog"] + self.buffer_collection = self.db["buffer"] def check_if_user_exists(self, user_id: int, raise_exception: bool = False): if self.user_collection.count_documents({"_id": user_id}) > 0: @@ -43,6 +44,7 @@ def add_new_user( "last_interaction": datetime.now(), "first_seen": datetime.now(), + "current_buffer_setting": False, "current_dialog_id": None, "current_chat_mode": "assistant", "current_model": config.models["available_text_models"][0], @@ -78,6 +80,9 @@ def start_new_dialog(self, user_id: int): {"$set": {"current_dialog_id": dialog_id}} ) + # clear buffer if any + self.buffer_collection.delete_many({"_id": user_id}) + return dialog_id def get_user_attribute(self, user_id: int, key: str): @@ -107,6 +112,18 @@ def update_n_used_tokens(self, user_id: int, model: str, n_input_tokens: int, n_ self.set_user_attribute(user_id, "n_used_tokens", n_used_tokens_dict) + def set_buffer_message(self, user_id: int, value: Any): + self.check_if_user_exists(user_id, raise_exception=True) + self.buffer_collection.update_one({"_id": user_id}, {"$set": {"buffer": value}}, upsert=True) + + def get_buffer_message(self, user_id: int): + self.check_if_user_exists(user_id, raise_exception=True) + + buffer_dict = self.buffer_collection.find_one({"_id": user_id}) + if buffer_dict is None or "buffer" not in buffer_dict: + return None + return buffer_dict["buffer"] + def get_dialog_messages(self, user_id: int, dialog_id: Optional[str] = None): self.check_if_user_exists(user_id, raise_exception=True)