Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add buffer mode #269

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 83 additions & 37 deletions bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,22 @@ async def register_user_if_not_exists(update: Update, context: CallbackContext,


async def is_bot_mentioned(update: Update, context: CallbackContext):
try:
message = update.message
try:
message = update.message

if message.chat.type == "private":
return True
if message.chat.type == "private":
return True

if message.text is not None and ("@" + context.bot.username) in message.text:
return True
if message.text is not None and ("@" + context.bot.username) in message.text:
return True

if message.reply_to_message is not None:
if message.reply_to_message.from_user.id == context.bot.id:
return True
except:
return True
else:
return False
if message.reply_to_message is not None:
if message.reply_to_message.from_user.id == context.bot.id:
return True
except:
return True
else:
return False


async def start_handle(update: Update, context: CallbackContext):
Expand All @@ -150,14 +150,14 @@ async def help_handle(update: Update, context: CallbackContext):


async def help_group_chat_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
await register_user_if_not_exists(update, context, update.message.from_user)
user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())

text = HELP_GROUP_CHAT_MESSAGE.format(bot_username="@" + context.bot.username)
text = HELP_GROUP_CHAT_MESSAGE.format(bot_username="@" + context.bot.username)

await update.message.reply_text(text, parse_mode=ParseMode.HTML)
await update.message.reply_video(config.help_group_chat_video_path)
await update.message.reply_text(text, parse_mode=ParseMode.HTML)
await update.message.reply_video(config.help_group_chat_video_path)


async def retry_handle(update: Update, context: CallbackContext):
Expand Down Expand Up @@ -204,8 +204,17 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
await generate_image_handle(update, context, message=message)
return

if db.get_user_attribute(user_id, "current_buffer_setting") is True:
if _message != "Done":
existing_buffer = db.get_buffer_message(user_id)
db.set_buffer_message(user_id, _message if existing_buffer is None else existing_buffer + _message)
await update.message.reply_text("Buffered 📝, send <b>Done</b> to generate a response", parse_mode=ParseMode.HTML)
return


async def message_handle_fn():
# new dialog timeout
nonlocal _message
if use_new_dialog_timeout:
if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
db.start_new_dialog(user_id)
Expand All @@ -223,9 +232,15 @@ async def message_handle_fn():
# send typing action
await update.message.chat.send_action(action="typing")

existing_buffer = db.get_buffer_message(user_id)
if existing_buffer is not None:
_message = existing_buffer
# clear buffer after use
db.set_buffer_message(user_id, None)

if _message is None or len(_message) == 0:
await update.message.reply_text("🥲 You sent <b>empty message</b>. Please, try again!", parse_mode=ParseMode.HTML)
return
await update.message.reply_text("🥲 You sent <b>empty message</b>. Please, try again!", parse_mode=ParseMode.HTML)
return

dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
parse_mode = {
Expand Down Expand Up @@ -462,25 +477,25 @@ async def show_chat_modes_handle(update: Update, context: CallbackContext):


async def show_chat_modes_callback_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
if await is_previous_message_not_answered_yet(update.callback_query, context): return
await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
if await is_previous_message_not_answered_yet(update.callback_query, context): return

user_id = update.callback_query.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
user_id = update.callback_query.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())

query = update.callback_query
await query.answer()
query = update.callback_query
await query.answer()

page_index = int(query.data.split("|")[1])
if page_index < 0:
return
page_index = int(query.data.split("|")[1])
if page_index < 0:
return

text, reply_markup = get_chat_mode_menu(page_index)
try:
await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
except telegram.error.BadRequest as e:
if str(e).startswith("Message is not modified"):
pass
text, reply_markup = get_chat_mode_menu(page_index)
try:
await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
except telegram.error.BadRequest as e:
if str(e).startswith("Message is not modified"):
pass


async def set_chat_mode_handle(update: Update, context: CallbackContext):
Expand Down Expand Up @@ -523,7 +538,19 @@ def get_settings_menu(user_id: int):
buttons.append(
InlineKeyboardButton(title, callback_data=f"set_settings|{model_key}")
)
reply_markup = InlineKeyboardMarkup([buttons])

# buttons to choose buffer mode
current_buffer_setting = db.get_user_attribute(user_id, "current_buffer_setting")
if current_buffer_setting is None or current_buffer_setting is False:
title = "❌ Buffer Mode"
action = "set_buffer|True"
else:
title = "✅ Buffer Mode"
action = "set_buffer|False"
buffer_buttons = [
InlineKeyboardButton(title, callback_data=action)
]
reply_markup = InlineKeyboardMarkup([buttons, buffer_buttons])

return text, reply_markup

Expand Down Expand Up @@ -557,6 +584,24 @@ async def set_settings_handle(update: Update, context: CallbackContext):
if str(e).startswith("Message is not modified"):
pass

async def set_buffer_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
user_id = update.callback_query.from_user.id

query = update.callback_query
await query.answer()

_, buffer_setting = query.data.split("|")
db.set_user_attribute(user_id, "current_buffer_setting", buffer_setting == "True")
db.start_new_dialog(user_id)

text, reply_markup = get_settings_menu(user_id)
try:
await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
except telegram.error.BadRequest as e:
if str(e).startswith("Message is not modified"):
pass


async def show_balance_handle(update: Update, context: CallbackContext):
await register_user_if_not_exists(update, context, update.message.from_user)
Expand Down Expand Up @@ -684,6 +729,7 @@ def run_bot() -> None:

application.add_handler(CommandHandler("settings", settings_handle, filters=user_filter))
application.add_handler(CallbackQueryHandler(set_settings_handle, pattern="^set_settings"))
application.add_handler(CallbackQueryHandler(set_buffer_handle, pattern="^set_buffer"))

application.add_handler(CommandHandler("balance", show_balance_handle, filters=user_filter))

Expand Down
17 changes: 17 additions & 0 deletions bot/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):

self.user_collection = self.db["user"]
self.dialog_collection = self.db["dialog"]
self.buffer_collection = self.db["buffer"]

def check_if_user_exists(self, user_id: int, raise_exception: bool = False):
if self.user_collection.count_documents({"_id": user_id}) > 0:
Expand Down Expand Up @@ -43,6 +44,7 @@ def add_new_user(
"last_interaction": datetime.now(),
"first_seen": datetime.now(),

"current_buffer_setting": False,
"current_dialog_id": None,
"current_chat_mode": "assistant",
"current_model": config.models["available_text_models"][0],
Expand Down Expand Up @@ -78,6 +80,9 @@ def start_new_dialog(self, user_id: int):
{"$set": {"current_dialog_id": dialog_id}}
)

# clear buffer if any
self.buffer_collection.delete_many({"_id": user_id})

return dialog_id

def get_user_attribute(self, user_id: int, key: str):
Expand Down Expand Up @@ -107,6 +112,18 @@ def update_n_used_tokens(self, user_id: int, model: str, n_input_tokens: int, n_

self.set_user_attribute(user_id, "n_used_tokens", n_used_tokens_dict)

def set_buffer_message(self, user_id: int, value: Any):
self.check_if_user_exists(user_id, raise_exception=True)
self.buffer_collection.update_one({"_id": user_id}, {"$set": {"buffer": value}}, upsert=True)

def get_buffer_message(self, user_id: int):
self.check_if_user_exists(user_id, raise_exception=True)

buffer_dict = self.buffer_collection.find_one({"_id": user_id})
if buffer_dict is None or "buffer" not in buffer_dict:
return None
return buffer_dict["buffer"]

def get_dialog_messages(self, user_id: int, dialog_id: Optional[str] = None):
self.check_if_user_exists(user_id, raise_exception=True)

Expand Down