From 2d862d1792a8bd4d68e543179a1a13bb4e816846 Mon Sep 17 00:00:00 2001 From: evgfilim1 Date: Wed, 2 Nov 2022 02:45:48 +0500 Subject: [PATCH] Generalize modules, some minor improvements Create `BaseModule` and `BaseHandler` classes Remove `MessageNotModified` warning Accept varargs instead of list when registering a command Register all shortcuts via decorators Allow hooks to return a value Resolve hooks and shortcuts dependencies Rename `ShortcutTransformersModule` to `ShortcutsModule` Add `MiddlewareManager.__contains__` Rename `__is_prod__` to `is_prod` Bump version to 0.5 --- locales/evgfilim1-userbot.pot | 79 ++--- userbot/__init__.py | 6 +- userbot/__main__.py | 42 ++- userbot/commands/__init__.py | 2 +- userbot/commands/about.py | 4 +- userbot/commands/chat_admin.py | 2 +- userbot/commands/dice.py | 2 +- userbot/commands/download.py | 3 +- userbot/commands/messages.py | 4 +- userbot/commands/notes.py | 8 +- userbot/hooks.py | 10 +- userbot/middleware_manager.py | 3 + userbot/modules/__init__.py | 4 +- userbot/modules/base.py | 230 ++++++++++++ userbot/modules/commands.py | 618 ++++++++++++++++----------------- userbot/modules/hooks.py | 213 ++++++++---- userbot/modules/shortcuts.py | 169 ++++++--- userbot/shortcuts.py | 11 +- 18 files changed, 892 insertions(+), 518 deletions(-) create mode 100644 userbot/modules/base.py diff --git a/locales/evgfilim1-userbot.pot b/locales/evgfilim1-userbot.pot index 799f5be..744bb4a 100644 --- a/locales/evgfilim1-userbot.pot +++ b/locales/evgfilim1-userbot.pot @@ -6,9 +6,9 @@ #, fuzzy msgid "" msgstr "" -"Project-Id-Version: evgfilim1/userbot 0.4.x\n" +"Project-Id-Version: evgfilim1/userbot 0.5.x\n" "Report-Msgid-Bugs-To: https://github.com/evgfilim1/userbot/issues\n" -"POT-Creation-Date: 2022-10-28 23:38+0500\n" +"POT-Creation-Date: 2022-11-02 02:36+0500\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -152,7 +152,7 @@ msgstr "" msgid "{icon} The file has been downloaded to {output}" msgstr "" -#: userbot/commands/download.py:77 +#: userbot/commands/download.py:78 msgid "Downloading file(s)..." msgstr "" @@ -234,14 +234,14 @@ msgstr "" msgid "{icon} Message with key={key!r} not found" msgstr "" -#: userbot/commands/notes.py:52 +#: userbot/commands/notes.py:53 #, python-brace-format msgid "" "{icon} File reference expired, please save the note again.\n" "Note key: {key}" msgstr "" -#: userbot/commands/notes.py:77 userbot/commands/notes.py:112 +#: userbot/commands/notes.py:78 userbot/commands/notes.py:113 #, python-brace-format msgid "" "{icon} Please specify note key\n" @@ -249,19 +249,19 @@ msgid "" "Possible fix: {message_text} key" msgstr "" -#: userbot/commands/notes.py:87 +#: userbot/commands/notes.py:88 #, python-brace-format msgid "{icon} Note {key} saved" msgstr "" -#: userbot/commands/notes.py:98 +#: userbot/commands/notes.py:99 #, python-brace-format msgid "" "{icon} Saved notes:\n" "{t}" msgstr "" -#: userbot/commands/notes.py:116 +#: userbot/commands/notes.py:117 #, python-brace-format msgid "{icon} Note {key} deleted" msgstr "" @@ -331,7 +331,25 @@ msgstr "" msgid "{icon} Stopping userbot..." msgstr "" -#: userbot/modules/commands.py:118 +#: userbot/modules/base.py:103 +msgid "Userbot is processing the message..." +msgstr "" + +#: userbot/modules/base.py:122 +#, python-brace-format +msgid "" +"{icon} Timed out after {timeout} while processing the message.\n" +"More info can be found in logs." +msgstr "" + +#: userbot/modules/base.py:127 +#, python-brace-format +msgid "{timeout} second" +msgid_plural "{timeout} seconds" +msgstr[0] "" +msgstr[1] "" + +#: userbot/modules/commands.py:206 #, python-brace-format msgid "" "{icon} An error occurred during executing command.\n" @@ -343,7 +361,7 @@ msgid "" "More info can be found in logs." msgstr "" -#: userbot/modules/commands.py:136 +#: userbot/modules/commands.py:222 #, python-brace-format msgid "" "{icon} Successfully executed.\n" @@ -353,63 +371,32 @@ msgid "" "Result:" msgstr "" -#: userbot/modules/commands.py:145 +#: userbot/modules/commands.py:233 #, python-brace-format msgid "{text} See reply." msgstr "" -#: userbot/modules/commands.py:234 -#, python-brace-format -msgid "" -"{icon} Command timed out after {timeout}.\n" -"\n" -"Command: {message_text}\n" -"\n" -"More info can be found in logs." -msgstr "" - -#: userbot/modules/commands.py:240 -#, python-brace-format -msgid "{timeout} second" -msgid_plural "{timeout} seconds" -msgstr[0] "" -msgstr[1] "" - -#: userbot/modules/commands.py:259 -#, python-brace-format -msgid "Executing {command}..." -msgstr "" - -#: userbot/modules/commands.py:301 -#, python-brace-format -msgid "" -"{result}\n" -"\n" -"{icon} MessageNotModified was raised, check that there's only one instance of userbot " -"is running." -msgstr "" - -#: userbot/modules/commands.py:460 +#: userbot/modules/commands.py:400 #, python-brace-format msgid "" "Help for {args}:\n" "{usage}" msgstr "" -#: userbot/modules/commands.py:466 +#: userbot/modules/commands.py:406 msgid "" "List of userbot commands available:\n" "\n" msgstr "" -#: userbot/modules/hooks.py:27 +#: userbot/modules/hooks.py:34 #, python-brace-format msgid "" "Hooks in this chat:\n" "{hooks}" msgstr "" -#: userbot/modules/hooks.py:128 +#: userbot/modules/hooks.py:213 #, python-brace-format msgid "" "Available hooks:\n" diff --git a/userbot/__init__.py b/userbot/__init__.py index 459104b..a638531 100644 --- a/userbot/__init__.py +++ b/userbot/__init__.py @@ -1,10 +1,10 @@ __all__ = [ - "__is_prod__", + "is_prod", "__version__", ] from os import environ from typing import Final -__is_prod__: Final[bool] = bool(environ.get("GITHUB_SHA", "")) -__version__: Final[str] = "0.4.x" + ("-dev" if not __is_prod__ else "") +is_prod: Final[bool] = bool(environ.get("GITHUB_SHA", "")) +__version__: Final[str] = "0.5.x" + ("-dev" if not is_prod else "") diff --git a/userbot/__main__.py b/userbot/__main__.py index 36c0973..95227ac 100644 --- a/userbot/__main__.py +++ b/userbot/__main__.py @@ -9,21 +9,21 @@ from pyrogram.handlers import RawUpdateHandler from pyrogram.methods.utilities.idle import idle -from userbot import __is_prod__, __version__ +from userbot import __version__, is_prod from userbot.commands import commands from userbot.commands.chat_admin import react2ban_raw_reaction_handler from userbot.config import Config, RedisConfig -from userbot.constants import GH_PATTERN from userbot.hooks import hooks from userbot.job_manager import AsyncJobManager from userbot.middlewares import KwargsMiddleware, icon_middleware, translate_middleware -from userbot.shortcuts import get_note, github, shortcuts +from userbot.modules import HooksModule +from userbot.shortcuts import shortcuts from userbot.storage import RedisStorage, Storage from userbot.utils import GitHubClient, fetch_stickers logging.basicConfig(level=logging.WARNING) _log = logging.getLogger(__name__) -_log.setLevel(logging.INFO if __is_prod__ else logging.DEBUG) +_log.setLevel(logging.INFO if is_prod else logging.DEBUG) async def _main( @@ -69,27 +69,35 @@ def main() -> None: github_client = GitHubClient(AsyncClient(http2=True)) _log.debug("Registering handlers...") - shortcuts.add_handler(partial(github, github_client=github_client), GH_PATTERN) - shortcuts.add_handler(partial(get_note, storage=storage), r"n://(.+?)/") client.add_handler( RawUpdateHandler(partial(react2ban_raw_reaction_handler, storage=storage)), group=1, ) - commands.add_middleware( - KwargsMiddleware( - { - "storage": storage, - "data_dir": config.data_location, - "notes_chat": config.media_notes_chat, - } - ) + + root_hooks = HooksModule(commands=commands, storage=storage) + root_hooks.add_submodule(hooks) + + kwargs_middleware = KwargsMiddleware( + { + "storage": storage, + "data_dir": config.data_location, + "notes_chat": config.media_notes_chat, + "github_client": github_client, + } ) - commands.add_middleware(icon_middleware) + commands.add_middleware(kwargs_middleware) commands.add_middleware(translate_middleware) + commands.add_middleware(icon_middleware) + root_hooks.add_middleware(kwargs_middleware) + root_hooks.add_middleware(translate_middleware) + root_hooks.add_middleware(icon_middleware) + shortcuts.add_middleware(kwargs_middleware) + shortcuts.add_middleware(translate_middleware) + shortcuts.add_middleware(icon_middleware) # `HooksModule` must be registered before `CommandsModule` because it adds some commands - hooks.register(client, storage, commands) - commands.register(client, with_help=True) + root_hooks.register(client) + commands.register(client) shortcuts.register(client) job_manager = AsyncJobManager() diff --git a/userbot/commands/__init__.py b/userbot/commands/__init__.py index 4e71f05..998f0be 100644 --- a/userbot/commands/__init__.py +++ b/userbot/commands/__init__.py @@ -25,7 +25,7 @@ except ImportError: test_commands = None -commands = CommandsModule() +commands = CommandsModule(root=True) for submodule in ( about_commands, diff --git a/userbot/commands/about.py b/userbot/commands/about.py index dd46194..5696b2c 100644 --- a/userbot/commands/about.py +++ b/userbot/commands/about.py @@ -7,7 +7,7 @@ from pyrogram import Client -from .. import __is_prod__ +from .. import is_prod from ..constants import Icons, PremiumIcons from ..modules import CommandsModule from ..translation import Translation @@ -33,7 +33,7 @@ async def about(client: Client, icons: Type[Icons], tr: Translation) -> str: f"{github_icon} evgfilim1/userbot\n" f"{commit_icon} {commit}" ) - if __is_prod__: + if is_prod: t += _(" (deployments)").format( base_url=base_url, ) diff --git a/userbot/commands/chat_admin.py b/userbot/commands/chat_admin.py index f24cbb9..e77bb74 100644 --- a/userbot/commands/chat_admin.py +++ b/userbot/commands/chat_admin.py @@ -189,7 +189,7 @@ async def react2ban( return _(_REACT2BAN_TEXT) -@commands.add(["no_react2ban", "noreact2ban"], usage="") +@commands.add("no_react2ban", "noreact2ban", usage="") async def no_react2ban( message: Message, storage: Storage, diff --git a/userbot/commands/dice.py b/userbot/commands/dice.py index 5d082a7..8096a7c 100644 --- a/userbot/commands/dice.py +++ b/userbot/commands/dice.py @@ -40,7 +40,7 @@ def _str_die(self, node): return ", ".join(the_rolls) -@commands.add(["roll", "dice"], usage="") +@commands.add("roll", "dice", usage="") async def dice(command: CommandObject) -> str: """Rolls dice according to d20.roll syntax diff --git a/userbot/commands/download.py b/userbot/commands/download.py index 9bf9f62..e7cf44a 100644 --- a/userbot/commands/download.py +++ b/userbot/commands/download.py @@ -72,7 +72,8 @@ async def _downloader( @commands.add( - ["download", "dl"], + "download", + "dl", usage="[reply] [filename]", waiting_message=_("Downloading file(s)..."), ) diff --git a/userbot/commands/messages.py b/userbot/commands/messages.py index 696d804..7ec9682 100644 --- a/userbot/commands/messages.py +++ b/userbot/commands/messages.py @@ -21,7 +21,7 @@ commands = CommandsModule("Messages") -@commands.add(["delete", "delet", "del"], usage="") +@commands.add("delete", "delet", "del", usage="") async def delete_this(message: Message) -> None: """Deletes replied message for everyone""" try: @@ -128,7 +128,7 @@ async def user_first_message( await message.delete() -@commands.add(["copyhere", "cphere", "cph"], usage="") +@commands.add("copyhere", "cphere", "cph", usage="") async def copy_here(message: Message) -> None: """Copies replied message to current chat""" await message.reply_to_message.copy(message.chat.id) diff --git a/userbot/commands/notes.py b/userbot/commands/notes.py index 20fbcac..ab0e13e 100644 --- a/userbot/commands/notes.py +++ b/userbot/commands/notes.py @@ -19,7 +19,7 @@ commands = CommandsModule("Notes") -@commands.add(["get", "note", "n"], usage="") +@commands.add("get", "note", "n", usage="") async def get_note( client: Client, message: Message, @@ -63,7 +63,7 @@ async def get_note( await message.delete() -@commands.add(["save", "note_add", "nadd"], usage=" ") +@commands.add("save", "note_add", "nadd", usage=" ") async def save_note( message: Message, command: CommandObject, @@ -88,7 +88,7 @@ async def save_note( return _("{icon} Note {key} saved").format(icon=icons.BOOKMARK, key=key) -@commands.add(["notes", "ns"]) +@commands.add("notes", "ns") async def saved_notes(storage: Storage, icons: Type[Icons], tr: Translation) -> str: """Shows all saved notes""" _ = tr.gettext @@ -99,7 +99,7 @@ async def saved_notes(storage: Storage, icons: Type[Icons], tr: Translation) -> return _("{icon} Saved notes:\n{t}").format(icon=icons.BOOKMARK, t=t) -@commands.add(["note_del", "ndel"], usage="") +@commands.add("note_del", "ndel", usage="") async def delete_note( message: Message, command: CommandObject, diff --git a/userbot/hooks.py b/userbot/hooks.py index 1f01faa..0693ddb 100644 --- a/userbot/hooks.py +++ b/userbot/hooks.py @@ -23,7 +23,7 @@ @hooks.add("emojis", filters.regex(r"\b((?:дак\b|кря(?:к.?|\b))|блин)", flags=re.I)) -async def on_emojis(_: Client, message: Message) -> None: +async def on_emojis(message: Message) -> str: t = "" for match in message.matches: m = match[1].lower() @@ -31,11 +31,11 @@ async def on_emojis(_: Client, message: Message) -> None: t += "🦆" elif m == "блин": t += "🥞" - await message.reply(t) + return t @hooks.add("tap", (filters.regex(r"\b(?:тык|nsr)\b", flags=re.I) | sticker(TAP_FLT))) -async def on_tap(_: Client, message: Message) -> None: +async def on_tap(message: Message) -> None: await message.reply_sticker(TAP_STICKER) @@ -47,10 +47,10 @@ async def mibib(client: Client, message: Message) -> None: @hooks.add("bra", filters.regex(r"\b(?:бра|bra)\b", flags=re.I)) -async def on_bra(_: Client, message: Message) -> None: +async def on_bra(message: Message) -> None: await message.reply_photo(BRA_MEME_PICTURE) @hooks.add("uwu", filters.regex(r"\b(?:uwu|owo|уву|ово)\b", flags=re.I)) -async def on_uwu(_: Client, message: Message) -> None: +async def on_uwu(message: Message) -> None: await message.reply_photo(UWU_MEME_PICTURE) diff --git a/userbot/middleware_manager.py b/userbot/middleware_manager.py index 6745a7a..02c9d0d 100644 --- a/userbot/middleware_manager.py +++ b/userbot/middleware_manager.py @@ -49,3 +49,6 @@ async def __call__( @property def has_handlers(self) -> bool: return len(self._middlewares) > 0 + + def __contains__(self, item: Middleware[_ReturnT]) -> bool: + return item in self._middlewares diff --git a/userbot/modules/__init__.py b/userbot/modules/__init__.py index 243218d..435ce51 100644 --- a/userbot/modules/__init__.py +++ b/userbot/modules/__init__.py @@ -2,9 +2,9 @@ "CommandObject", "CommandsModule", "HooksModule", - "ShortcutTransformersModule", + "ShortcutsModule", ] from .commands import CommandObject, CommandsModule from .hooks import HooksModule -from .shortcuts import ShortcutTransformersModule +from .shortcuts import ShortcutsModule diff --git a/userbot/modules/base.py b/userbot/modules/base.py new file mode 100644 index 0000000..b34141b --- /dev/null +++ b/userbot/modules/base.py @@ -0,0 +1,230 @@ +__all__ = [ + "BaseHandler", + "BaseModule", + "HandlerT", +] + +import asyncio +import inspect +import logging +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, Generic, NamedTuple, Type, TypeVar + +from pyrogram import Client +from pyrogram.enums import ParseMode +from pyrogram.errors import MessageNotModified, MessageTooLong +from pyrogram.filters import Filter +from pyrogram.handlers.handler import Handler +from pyrogram.types import Message + +from ..constants import Icons +from ..middleware_manager import Middleware, MiddlewareManager +from ..translation import Translation +from ..utils import async_partial + +_log = logging.getLogger(__name__) + +HandlerT = TypeVar("HandlerT", bound=Callable[..., Awaitable[str | None]]) + + +class _NewMessage(NamedTuple): + message: Message + edited: bool + + +class BaseHandler(ABC): + def __init__( + self, + *, + handler: HandlerT, + handle_edits: bool, + waiting_message: str | None, + timeout: int | None, + ) -> None: + self._handler = handler + self._handle_edits = handle_edits + self._waiting_message = waiting_message + self._timeout = timeout + + self._signature = inspect.signature(self.handler) + + @property + def handler(self) -> HandlerT: + return self._handler + + @property + def handle_edits(self) -> bool: + return self._handle_edits + + @property + def waiting_message(self) -> str | None: + return self._waiting_message + + @property + def timeout(self) -> int | None: + return self._timeout + + @staticmethod + async def _edit_or_reply_html_text(message: Message, text: str, **kwargs: Any) -> _NewMessage: + """Edit a message if it's outgoing, otherwise reply to it.""" + if message.outgoing or message.from_user is not None and message.from_user.is_self: + return _NewMessage( + message=await message.edit_text(text, parse_mode=ParseMode.HTML, **kwargs), + edited=True, + ) + return _NewMessage( + message=await message.reply_text(text, parse_mode=ParseMode.HTML, **kwargs), + edited=False, + ) + + async def _invoke_handler(self, data: dict[str, Any]) -> str | None: + """Filter data and call the handler.""" + suitable_kwargs = {} + # TODO (2022-11-01): check all params are passed, otherwise raise an error + for name, param in self._signature.parameters.items(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + suitable_kwargs = data # pass all kwargs + break + if name in data: + suitable_kwargs[name] = data[name] + return await self.handler(**suitable_kwargs) + + async def _send_waiting_message(self, data: dict[str, Any]) -> None: + """Edit a message after some time to show that the bot is working on the message.""" + await asyncio.sleep(0.75) + message: Message = data["message"] + icons: Type[Icons] = data["icons"] + tr: Translation = data["tr"] + _ = tr.gettext + if self.waiting_message is not None: + # Waiting messages may be marked for translation, so we need to translate them here. + text = _(self.waiting_message).strip() + else: + text = _("Userbot is processing the message...") + message, edited = await self._edit_or_reply_html_text(message, f"{icons.WATCH} {text}") + if not edited: + data["new_message"] = message + + async def _timed_out_handler(self, data: dict[str, Any]) -> str | None: + """Handles the case when the handler times out. Reports an error by default.""" + _log.warning( + f"Handler %r timed out", + self, + exc_info=True, + extra={"data": data}, + ) + icons: Type[Icons] = data["icons"] + tr: Translation = data["tr"] + _ = tr.gettext + __ = tr.ngettext + # cannot be None as `TimeoutError` is raised only if timeout is not None + timeout: int = self.timeout + return _( + "{icon} Timed out after {timeout} while processing the message.\n" + "More info can be found in logs." + ).format( + icon=icons.STOP, + timeout=__("{timeout} second", "{timeout} seconds", timeout).format(timeout=timeout), + ) + + async def _invoke_with_timeout(self, data: dict[str, Any]) -> str | None: + """Call the handler with a timeout.""" + waiting_task = asyncio.create_task(self._send_waiting_message(data)) + try: + return await asyncio.wait_for(self._invoke_handler(data), timeout=self.timeout) + except asyncio.TimeoutError: + waiting_task.cancel() + return await self._timed_out_handler(data) + finally: + waiting_task.cancel() + + async def _exception_handler(self, e: Exception, data: dict[str, Any]) -> str | None: + """Handle exceptions raised by the handler. Re-raises an exception by default.""" + raise + + async def _message_too_long_handler(self, result: str, data: dict[str, Any]) -> None: + """Handles the case when the result is too long to be sent. + Re-raises an exception by default.""" + raise + + async def _message_not_modified_handler(self, result: str, data: dict[str, Any]) -> None: + """Handles the case when the result doesn't modify the message. + Logs a warning to console by default.""" + _log.warning( + "Message was not modified while executing handler %r", + self, + exc_info=True, + extra={"data": data}, + ) + return + + async def _result_handler(self, result: str, data: dict[str, Any]) -> None: + """Handle the result of the handler. Edits a message by default.""" + actual_message: Message = data.get("new_message", data["message"]) + try: + await self._edit_or_reply_html_text(actual_message, result) + except MessageTooLong: + await self._message_too_long_handler(result, data) + except MessageNotModified: + await self._message_not_modified_handler(result, data) + + async def __call__( + self, + client: Client, + message: Message, + *, + middleware: MiddlewareManager[str | None], + ) -> None: + data = { + "client": client, + "message": message, + } + try: + result = await middleware(self._invoke_with_timeout, data) + except Exception as e: + result = await self._exception_handler(e, data) + if not result: # empty string or None + return # nothing to send or edit + await self._result_handler(result, data) + + +_HT = TypeVar("_HT", bound=BaseHandler) + + +class BaseModule(Generic[_HT]): + """Base class for modules""" + + def __init__(self): + self._handlers: list[_HT] = [] + self._middleware: MiddlewareManager[str | None] = MiddlewareManager() + + def add_handler(self, handler: _HT) -> None: + self._handlers.append(handler) + + def add_submodule(self, module: "BaseModule[_HT]") -> None: + if module is self: + raise ValueError("Cannot add a module to itself") + self._handlers.extend(module._handlers) + if module._middleware.has_handlers: + raise NotImplementedError( + "Submodule has middlewares registered, this is not supported yet" + ) + + def add_middleware(self, middleware: Middleware[str | None]) -> None: + self._middleware.register(middleware) + + @abstractmethod + def _create_handlers_filters(self, handler: _HT) -> tuple[list[Type[Handler]], Filter]: + """Create Pyrogram handlers and filters for the given handler.""" + pass + + def register(self, client: Client) -> None: + for handler in self._handlers: + handlers, filters = self._create_handlers_filters(handler) + for handler_cls in handlers: + client.add_handler( + handler_cls( + async_partial(handler.__call__, middleware=self._middleware), + filters, + ), + ) diff --git a/userbot/modules/commands.py b/userbot/modules/commands.py index 399011f..3f7f0c8 100644 --- a/userbot/modules/commands.py +++ b/userbot/modules/commands.py @@ -1,46 +1,43 @@ from __future__ import annotations __all__ = [ - "CommandObject", "CommandsModule", + "CommandObject", ] -import asyncio +import functools import html import inspect import logging +import operator import re from dataclasses import dataclass from io import BytesIO from pathlib import Path from traceback import FrameSummary, extract_tb from types import TracebackType -from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeAlias, TypeVar +from typing import Any, Callable, Iterable, Type, TypeAlias, overload from httpx import AsyncClient, HTTPError -from pyrogram import Client -from pyrogram import filters as flt +from pyrogram import Client, filters from pyrogram.enums import ParseMode -from pyrogram.errors import MessageNotModified, MessageTooLong +from pyrogram.filters import Filter from pyrogram.handlers import EditedMessageHandler, MessageHandler +from pyrogram.handlers.handler import Handler from pyrogram.types import Message -from .. import __is_prod__ +from .. import is_prod from ..constants import DefaultIcons, Icons, PremiumIcons -from ..middleware_manager import Middleware, MiddlewareManager +from ..middlewares import icon_middleware, translate_middleware from ..translation import Translation -from ..utils import async_partial +from .base import BaseHandler, BaseModule, HandlerT -_DEFAULT_PREFIX = "." if __is_prod__ else "," +_DEFAULT_PREFIX = "." if is_prod else "," _DEFAULT_TIMEOUT = 30 -_PS = ParamSpec("_PS") -_RT = TypeVar("_RT") -_CommandT: TypeAlias = list[str] | re.Pattern[str] -_CommandHandlerT: TypeAlias = Callable[_PS, Awaitable[str | None]] - _log = logging.getLogger(__name__) -_nekobin = AsyncClient(base_url="https://nekobin.com/") + +_CommandT: TypeAlias = str | re.Pattern[str] def _extract_frames(traceback: TracebackType) -> tuple[FrameSummary, FrameSummary]: @@ -88,77 +85,10 @@ def _format_exception(exc: Exception) -> str: return res -async def _post_to_nekobin(text: str) -> str: - """Posts the text to nekobin. Returns the URL.""" - res = await _nekobin.post("/api/documents", json={"content": text}) - res.raise_for_status() - return f"{_nekobin.base_url.join(res.json()['result']['key'])}.html" - - -def _handle_command_exception(e: Exception, data: dict[str, Any]) -> str: - """Handles the exception raised by the command.""" - client: Client = data["client"] - message: Message = data["message"] - command: CommandObject = data["command"] - # In case one of the middlewares failed, we need to fill in the missing data with the fallback - # values - icons_default = PremiumIcons if client.me.is_premium else DefaultIcons - icons: Type[Icons] = data.setdefault("icons", icons_default) - tr: Translation = data.setdefault("tr", Translation(None)) - _ = tr.gettext - message_text = message.text - _log.exception( - "An error occurred during executing %r", - message_text, - extra={"command": command}, - ) - tb = _format_frames(*_extract_frames(e.__traceback__)) - tb += _format_exception(e) - tb = f"
{html.escape(tb)}
" - return _( - "{icon} An error occurred during executing command.\n\n" - "Command: {message_text}\n" - "Traceback:\n{tb}\n\n" - "More info can be found in logs." - ).format( - icon=icons.STOP, - message_text=html.escape(message_text), - tb=tb, - ) - - -async def _handle_message_too_long(result: str, data: dict[str, Any]) -> None: - """Handles the case when the result is too long to be sent.""" - message: Message = data["message"] - icons: Type[Icons] = data["icons"] - tr: Translation = data["tr"] - _ = tr.gettext - text = _( - "{icon} Successfully executed.\n\n" - "Command: {message_text}\n\n" - "Result:" - ).format(icon=icons.INFO, message_text=html.escape(message.text)) - try: - url = await _post_to_nekobin(result) - except HTTPError: - await message.edit( - _("{text} See reply.").format(text=text), - parse_mode=ParseMode.HTML, - ) - await message.reply_document( - BytesIO(result.encode("utf-8")), - file_name="result.html", - ) - else: - await message.edit( - f"{text} {url}", - parse_mode=ParseMode.HTML, - disable_web_page_preview=True, - ) - - @dataclass() class CommandObject: + """Represents a command object.""" + prefix: str command: str args: str @@ -174,177 +104,212 @@ def __str__(self) -> str: return f"{self.prefix}{self.command} {self.args}" -@dataclass() -class _CommandHandler: - commands: _CommandT - prefix: str - handler: _CommandHandlerT - usage: str - doc: str | None - category: str | None - hidden: bool - handle_edits: bool - waiting_message: str | None - timeout: int | None - - def __post_init__(self) -> None: +class CommandsHandler(BaseHandler): + def __init__( + self, + *, + commands: Iterable[_CommandT], + prefix: str, + handler: HandlerT, + usage: str, + doc: str | None, + category: str | None, + hidden: bool, + handle_edits: bool, + waiting_message: str | None, + timeout: int | None, + ) -> None: + if next(iter(commands), None) is None: + raise ValueError("No commands specified") + + super().__init__( + handler=handler, + handle_edits=handle_edits, + waiting_message=waiting_message, + timeout=timeout, + ) + self.commands = commands + self.prefix = prefix + self.usage = usage + self.doc = doc + self.category = category + self.hidden = hidden + if self.doc is not None: - self.doc = re.sub(r"\n(\n?)\s+", r"\n\1", self.doc).strip() - self._signature = inspect.signature(self.handler) + self.doc = self.doc.strip() + + def __repr__(self) -> str: + commands = self.commands + prefix = self.prefix + usage = self.usage + category = self.category + hidden = self.hidden + handle_edits = self.handle_edits + timeout = self.timeout + return ( + f"<{self.__class__.__name__}" + f" {commands=}" + f" {prefix=}" + f" {usage=}" + f" {category=}" + f" {hidden=}" + f" {handle_edits=}" + f" {timeout=}" + f">" + ) - def _parse_command(self, text: str) -> CommandObject: - """Parses the command from the text.""" - command, _, args = text.partition(" ") - prefix, command = command[0], command[1:] - if isinstance(self.commands, re.Pattern): - m = self.commands.match(command) - else: - m = None - return CommandObject(prefix=prefix, command=command, args=args, match=m) + def format_usage(self, *, full: bool = False) -> str: + """Formats the usage of the command.""" + commands: list[str] = [] + for command in self.commands: + match command: + case re.Pattern(pattern=pattern): + commands.append(pattern) + case str(): + commands.append(command) + case _: + raise AssertionError(f"Unexpected command type: {type(command)}") + commands_str = "|".join(commands) + usage = f" {self.usage}".rstrip() + doc = self.doc or "" + if not full: + doc = doc.strip().split("\n")[0].strip() + description = f" — {doc}" if self.doc else "" + return f"{commands_str}{usage}{description}" - async def _call_handler(self, data: dict[str, Any]) -> str | None: - """Filter data and call the handler.""" - suitable_kwargs = {} - for name, param in self._signature.parameters.items(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - suitable_kwargs = data # pass all kwargs - break - if name in data: - suitable_kwargs[name] = data[name] - return await self.handler(**suitable_kwargs) + async def _invoke_handler(self, data: dict[str, Any]) -> str | None: + data["command"] = self._parse_command(data["message"].text) + return await super()._invoke_handler(data) - async def _call_with_timeout(self, data: dict[str, Any]) -> str | None: - """Call the handler with a timeout.""" + async def _exception_handler(self, e: Exception, data: dict[str, Any]) -> str | None: + """Handles exceptions raised by the command handler.""" + client: Client = data["client"] + message: Message = data["message"] + # In case one of the middlewares failed, we need to fill in the missing data with + # the fallback values + if "icons" not in data: + data["icons"] = PremiumIcons if client.me.is_premium else DefaultIcons + if "tr" not in data: + data["tr"] = Translation(None) icons: Type[Icons] = data["icons"] tr: Translation = data["tr"] _ = tr.gettext - __ = tr.ngettext - waiting_task = asyncio.create_task(self._send_waiting_message(data)) - try: - return await asyncio.wait_for(self._call_handler(data), timeout=self.timeout) - except asyncio.TimeoutError as e: - message: Message = data["message"] - command: CommandObject = data["command"] - _log.warning( - "Command %r timed out", - message.text, - exc_info=e, - extra={"command": command}, - ) - return _( - "{icon} Command timed out after {timeout}.\n\n" - "Command: {message_text}\n\n" - "More info can be found in logs." - ).format( - icon=icons.STOP, - timeout=__("{timeout} second", "{timeout} seconds", self.timeout).format( - timeout=self.timeout, # cannot be None as TimeoutError is raised - ), - message_text=html.escape(message.text), - ) - finally: - waiting_task.cancel() + message_text = message.text + _log.exception( + "An error occurred during executing %r", + message_text, + extra={"data": data}, + ) + tb = _format_frames(*_extract_frames(e.__traceback__)) + tb += _format_exception(e) + tb = f"
{html.escape(tb)}
" + return _( + "{icon} An error occurred during executing command.\n\n" + "Command: {message_text}\n" + "Traceback:\n{tb}\n\n" + "More info can be found in logs." + ).format( + icon=icons.STOP, + message_text=html.escape(message_text), + tb=tb, + ) - async def _send_waiting_message(self, data: dict[str, Any]) -> None: - """Edit a message after some time to show that the bot is working on the command.""" - await asyncio.sleep(0.75) + async def _message_too_long_handler(self, result: str, data: dict[str, Any]) -> None: message: Message = data["message"] icons: Type[Icons] = data["icons"] tr: Translation = data["tr"] _ = tr.gettext - if self.waiting_message is not None: - # Waiting messages are marked for translation, so we need to translate them here. - text = _(self.waiting_message).strip() - else: - text = _("Executing {command}...").format( - command=html.escape(message.text), - ) - await message.edit_text(f"{icons.WATCH} {text}", parse_mode=ParseMode.HTML) - - async def __call__( - self, - client: Client, - message: Message, - *, - middleware: MiddlewareManager[str | None], - ) -> None: - """Entry point for the command handler. Edits and errors are handled here.""" - command = self._parse_command(message.text) - data = { - "client": client, - "message": message, - "command": command, - } - try: - result = await middleware(self._call_with_timeout, data) - except Exception as e: - result = _handle_command_exception(e, data) - if not result: # empty string or None - return # no reply - try: - await message.edit(result, parse_mode=ParseMode.HTML) - except MessageTooLong: - await _handle_message_too_long(result, data) - except MessageNotModified as e: - _log.warning( - "Message was not modified while executing %r", - message.text, - exc_info=e, - extra={"command": command}, - ) - if not __is_prod__: - # data was modified along the way, so we can access attributes from middlewares - icons: Type[Icons] = data["icons"] - tr: Translation = data["tr"] - _ = tr.gettext + text = _( + "{icon} Successfully executed.\n\n" + "Command: {message_text}\n\n" + "Result:" + ).format(icon=icons.INFO, message_text=html.escape(message.text)) + async with AsyncClient(base_url="https://nekobin.com/") as nekobin: + try: + res = await nekobin.post("/api/documents", json={"content": text}) + res.raise_for_status() + except HTTPError: + await message.edit( + _("{text} See reply.").format(text=text), + parse_mode=ParseMode.HTML, + ) + await message.reply_document( + BytesIO(result.encode("utf-8")), + file_name="result.html", + ) + else: + url = f"{nekobin.base_url.join(res.json()['result']['key'])}.html" await message.edit( - _( - "{result}\n\n" - "{icon} MessageNotModified was raised, check that there's only" - " one instance of userbot is running." - ).format(result=result, icon=icons.WARNING), + f"{text} {url}", + parse_mode=ParseMode.HTML, + disable_web_page_preview=True, ) - def format_usage(self, *, full: bool = False) -> str: - match self.commands: - case re.Pattern(pattern=pattern): - commands = pattern - case list(commands): - commands = "|".join(commands) - case _: - raise AssertionError(f"Unexpected command type: {type(self.commands)}") - usage = f" {self.usage}".rstrip() - doc = self.doc or "" - if not full: - doc = doc.strip().split("\n")[0].strip() - description = f" — {doc}" if self.doc else "" - return f"{commands}{usage}{description}" + def _parse_command(self, text: str) -> CommandObject: + """Parses the command from the text.""" + command, _, args = text.partition(" ") + prefix, command = command[0], command[1:] + for cmd in self.commands: + if isinstance(cmd, re.Pattern): + m = cmd.match(command) + break + else: + m = None + return CommandObject(prefix=prefix, command=command, args=args, match=m) - def sort_key(self) -> tuple[str, str]: - """Return a key to sort the commands by.""" + @property + def _sort_key(self) -> tuple[str, str]: + """Return a key to use for sorting commands.""" category = self.category or "" - match self.commands: + command = next(iter(self.commands)) + match command: case re.Pattern(pattern=pattern): cmd = pattern - case list(commands): - cmd = commands[0] + case str(): + cmd = command case _: - raise AssertionError(f"Unexpected command type: {type(self.commands)}") + raise AssertionError(f"Unexpected command type: {type(command)}") return category, cmd + # Implemented for sorting with `sorted(...)` + def __lt__(self, other: CommandsHandler) -> bool | NotImplemented: + if not isinstance(other, CommandsHandler): + return NotImplemented + return self._sort_key < other._sort_key + -class CommandsModule: - def __init__(self, category: str | None = None): - self._handlers: list[_CommandHandler] = [] +class CommandsModule(BaseModule[CommandsHandler]): + def __init__(self, category: str | None = None, *, root: bool = False): + super().__init__() self._category = category - self._middleware: MiddlewareManager[str | None] = MiddlewareManager() + self._root = root - def add_handler( + @overload + def add( self, - handler: _CommandHandlerT, - command: str | _CommandT, + command: _CommandT, + /, + *commands: _CommandT, + prefix: str = _DEFAULT_PREFIX, + usage: str = "", + doc: str | None = None, + category: str | None = None, + hidden: bool = False, + handle_edits: bool = True, + waiting_message: str | None = None, + timeout: int | None = _DEFAULT_TIMEOUT, + ) -> Callable[[HandlerT], HandlerT]: + # usage as a decorator + pass + + @overload + def add( + self, + callable_: HandlerT, + command: _CommandT, + /, + *commands: _CommandT, prefix: str = _DEFAULT_PREFIX, - *, usage: str = "", doc: str | None = None, category: str | None = None, @@ -353,28 +318,15 @@ def add_handler( waiting_message: str | None = None, timeout: int | None = _DEFAULT_TIMEOUT, ) -> None: - if isinstance(command, str): - command = [command] - self._handlers.append( - _CommandHandler( - commands=command, - prefix=prefix, - handler=handler, - usage=usage, - doc=doc or inspect.unwrap(handler).__doc__, - category=category or self._category, - hidden=hidden, - handle_edits=handle_edits, - waiting_message=waiting_message, - timeout=timeout, - ) - ) + # usage as a function + pass def add( self, - command: str | _CommandT, + command_or_callable: _CommandT | HandlerT, + /, + *commands: _CommandT, prefix: str = _DEFAULT_PREFIX, - *, usage: str = "", doc: str | None = None, category: str | None = None, @@ -382,90 +334,78 @@ def add( handle_edits: bool = True, waiting_message: str | None = None, timeout: int | None = _DEFAULT_TIMEOUT, - ) -> Callable[[_CommandHandlerT], _CommandHandlerT]: - def decorator(f: _CommandHandlerT) -> _CommandHandlerT: - self.add_handler( - handler=f, - command=command, - prefix=prefix, - usage=usage, - doc=doc, - category=category, - hidden=hidden, - handle_edits=handle_edits, - waiting_message=waiting_message, - timeout=timeout, - ) - return f + ) -> Callable[[HandlerT], HandlerT] | None: + """Registers a command handler. Can be used as a decorator or as a registration function. - return decorator + Example: + Example usage as a decorator:: - def add_middleware(self, middleware: Middleware[str | None]) -> None: - self._middleware.register(middleware) + commands = CommandsModule() + @commands.add("start", "help") + async def start_command(message: Message, command: CommandObject) -> str: + return "Hello!" - def add_submodule(self, module: CommandsModule) -> None: - self._handlers.extend(module._handlers) - if module._middleware.has_handlers: - raise ValueError("Submodule has middleware, which is not supported") + Example usage as a function:: - def register( - self, - client: Client, - *, - with_help: bool = False, - ): - if with_help: + commands = CommandsModule() + def start_command(message: Message, command: CommandObject) -> str: + return "Hello!" + commands.add(start_command, "start", "help") + + Returns: + The original handler function if used as a decorator, otherwise `None`. + """ + + def decorator(handler: HandlerT) -> HandlerT: self.add_handler( - self._help_handler, - command="help", - usage="[command]", - doc="Sends help for all commands or for a specific one", - category="About", - ) - all_commands = set() - for handler in self._handlers: - if isinstance(handler.commands, re.Pattern): - command_re = re.compile( - f"^[{re.escape(handler.prefix)}]{handler.commands.pattern}", - flags=handler.commands.flags, + CommandsHandler( + commands=commands, + prefix=prefix, + handler=handler, + usage=usage, + doc=doc or inspect.getdoc(inspect.unwrap(handler)), + category=category or self._category, + hidden=hidden, + handle_edits=handle_edits, + waiting_message=waiting_message, + timeout=timeout, ) - f = flt.regex(command_re) - elif isinstance(handler.commands, list): - for c in (commands := handler.commands): - if c in all_commands: - raise ValueError(f"Duplicate command detected: {c}") - all_commands.update(commands) - f = flt.command(handler.commands, prefixes=handler.prefix) - else: - raise AssertionError(f"Unexpected command type: {type(handler.commands)}") - f &= flt.me & ~flt.scheduled - callback = async_partial(handler.__call__, middleware=self._middleware) - client.add_handler(MessageHandler(callback, f)) - if handler.handle_edits: - client.add_handler(EditedMessageHandler(callback, f)) + ) + return handler + + if callable(command_or_callable): + if len(commands) == 0: + raise ValueError("No commands specified") + decorator(command_or_callable) + return + + commands = (command_or_callable, *commands) + return decorator async def _help_handler(self, command: CommandObject, tr: Translation) -> str: + """Sends help for all commands or for a specific one""" _ = tr.gettext if args := command.args: for h in self._handlers: - match h.commands: - case re.Pattern() as pattern: - matches = pattern.fullmatch(args) is not None - case list(cmds): - matches = args in cmds - case _: - raise AssertionError(f"Unexpected command type: {type(h.commands)}") - if matches: - usage = h.format_usage(full=True) - return _("Help for {args}:\n{usage}").format( - args=html.escape(args), - usage=html.escape(usage), - ) + for cmd in h.commands: + match cmd: + case re.Pattern() as pattern: + matches = pattern.fullmatch(args) is not None + case str(): + matches = cmd == args + case _: + raise AssertionError(f"Unexpected command type: {type(cmd)}") + if matches: + usage = h.format_usage(full=True) + return _("Help for {args}:\n{usage}").format( + args=html.escape(args), + usage=html.escape(usage), + ) else: return f"No help found for {args}" text = _("List of userbot commands available:\n\n") prev_cat = "" - for handler in sorted(self._handlers, key=_CommandHandler.sort_key): + for handler in sorted(self._handlers): if handler.hidden: continue usage = handler.format_usage() @@ -476,3 +416,49 @@ async def _help_handler(self, command: CommandObject, tr: Translation) -> str: # This will happen if there are no handlers without category text = text.replace("\n\n\n", "\n\n") return text + + def _check_duplicates(self) -> None: + """Checks for duplicate commands and raises an error if any are found.""" + commands = set() + for handler in self._handlers: + for cmd in handler.commands: + if cmd in commands: + raise ValueError(f"Duplicate command: {cmd}") + commands.add(cmd) + + def _create_handlers_filters( + self, + handler: CommandsHandler, + ) -> tuple[list[Type[Handler]], Filter]: + f: list[Filter] = [] + for cmd in handler.commands: + if isinstance(cmd, re.Pattern): + command_re = re.compile( + f"^[{re.escape(handler.prefix)}]{cmd.pattern}", + flags=cmd.flags, + ) + f.append(filters.regex(command_re)) + elif isinstance(cmd, str): + f.append(filters.command(cmd, prefixes=handler.prefix)) + else: + raise AssertionError(f"Unexpected command type: {type(cmd)}") + h: list[Type[Handler]] = [MessageHandler] + if handler.handle_edits: + h.append(EditedMessageHandler) + return h, functools.reduce(operator.or_, f) & filters.me & ~filters.scheduled + + def register(self, client: Client) -> None: + if self._root: + self.add( + self._help_handler, + "help", + usage="[command]", + category="About", + ) + # These middlewares are expected by the base module to be registered + if icon_middleware not in self._middleware: + self.add_middleware(icon_middleware) + if translate_middleware not in self._middleware: + self.add_middleware(translate_middleware) + self._check_duplicates() + super().register(client) diff --git a/userbot/modules/hooks.py b/userbot/modules/hooks.py index 42079e9..3f1d36d 100644 --- a/userbot/modules/hooks.py +++ b/userbot/modules/hooks.py @@ -2,20 +2,27 @@ "HooksModule", ] -from dataclasses import dataclass -from typing import Awaitable, Callable +from typing import Any, Callable, Type, overload from pyrogram import Client -from pyrogram import filters as flt +from pyrogram import filters as pyrogram_filters from pyrogram.handlers import EditedMessageHandler, MessageHandler +from pyrogram.handlers.handler import Handler from pyrogram.types import Message from ..storage import Storage from ..translation import Translation -from ..utils import async_partial -from .commands import CommandsModule +from . import CommandsModule +from .base import BaseHandler, BaseModule, HandlerT -_HandlerT = Callable[[Client, Message], Awaitable[None]] + +class _HookEnabledFilter(pyrogram_filters.Filter): + def __init__(self, hook_name: str, storage: Storage) -> None: + self.storage = storage + self.hook_name = hook_name + + async def __call__(self, client: Client, message: Message) -> bool: + return await self.storage.is_hook_enabled(self.hook_name, message.chat.id) async def _list_enabled_hooks(message: Message, storage: Storage, tr: Translation) -> str: @@ -27,97 +34,175 @@ async def _list_enabled_hooks(message: Message, storage: Storage, tr: Translatio return _("Hooks in this chat:\n{hooks}").format(hooks=hooks) -@dataclass() -class _HookHandler: - name: str - filters: flt.Filter - handler: _HandlerT - handle_edits: bool +class HooksHandler(BaseHandler): + def __init__( + self, + *, + name: str, + filters: pyrogram_filters.Filter, + handler: HandlerT, + handle_edits: bool, + ) -> None: + super().__init__( + handler=handler, + handle_edits=handle_edits, + waiting_message=None, + timeout=None, + ) + self.name = name + self.filters = filters + + def __repr__(self) -> str: + name = self.name + handle_edits = self.handle_edits + return f"<{self.__class__.__name__} {name=} {handle_edits=}>" async def add_handler(self, message: Message, storage: Storage) -> None: await storage.enable_hook(self.name, message.chat.id) await message.delete() - async def del_handler(self, message: Message, storage: Storage) -> None: + async def remove_handler(self, message: Message, storage: Storage) -> None: await storage.disable_hook(self.name, message.chat.id) await message.delete() - async def __call__(self, client: Client, message: Message, storage: Storage) -> None: - if await storage.is_hook_enabled(self.name, message.chat.id): - await self.handler(client, message) + async def _send_waiting_message(self, data: dict[str, Any]) -> None: + return -class HooksModule: - def __init__(self): - self._handlers: list[_HookHandler] = [] +class HooksModule(BaseModule): + def __init__( + self, + commands: CommandsModule | None = None, + storage: Storage | None = None, + ) -> None: + super().__init__() + self.commands = commands + self.storage = storage + @overload def add( self, name: str, - filters: flt.Filter, + filters_: pyrogram_filters.Filter, + none: None = None, + /, *, handle_edits: bool = False, - ) -> Callable[[_HandlerT], _HandlerT]: - def _decorator(f: _HandlerT) -> _HandlerT: - self.add_handler( - handler=f, - name=name, - filters=filters, - handle_edits=handle_edits, - ) - return f - - return _decorator + ) -> Callable[[HandlerT], HandlerT]: + # usage as a decorator + pass - def add_handler( + @overload + def add( self, - handler: _HandlerT, + callable_: HandlerT, name: str, - filters: flt.Filter, + filters_: pyrogram_filters.Filter, + /, *, handle_edits: bool = False, ) -> None: - self._handlers.append( - _HookHandler( - name=name, - filters=filters, - handler=handler, - handle_edits=handle_edits, + # usage as a function + pass + + def add( + self, + callable_or_name: HandlerT | str, + name_or_filters: str | pyrogram_filters.Filter, + filters_: pyrogram_filters.Filter | None = None, + /, + *, + handle_edits: bool = False, + ) -> Callable[[HandlerT], HandlerT] | None: + """Registers a hook handler. Can be used as a decorator or as a registration function. + + Example: + Example usage as a decorator:: + + hooks = HooksModule() + @hooks.add("hello", filters.regex(r"(hello|hi)")) + async def hello_hook(message: Message, command: CommandObject) -> str: + return "Hello!" + + Example usage as a function:: + + hooks = HooksModule() + def hello_hook(message: Message, command: CommandObject) -> str: + return "Hello!" + hooks.add(hello_hook, "hello", filters.regex(r"(hello|hi)")) + + Returns: + The original handler function if used as a decorator, otherwise `None`. + """ + + def decorator(handler: HandlerT) -> HandlerT: + self.add_handler( + HooksHandler( + name=name_or_filters, + filters=filters_, + handler=handler, + handle_edits=handle_edits, + ) ) - ) + return handler + + if callable(callable_or_name): + if filters_ is None: + raise TypeError("No filters specified") + decorator(callable_or_name) + return - def add_submodule(self, module: "HooksModule") -> None: - self._handlers.extend(module._handlers) + filters_ = name_or_filters + name_or_filters = callable_or_name + return decorator - def register(self, client: Client, storage: Storage, commands: CommandsModule) -> None: - cmds = CommandsModule("Hooks") + def _create_handlers_filters( + self, + handler: HooksHandler, + ) -> tuple[list[Type[Handler]], pyrogram_filters.Filter]: + h: list[Type[Handler]] = [MessageHandler] + if handler.handle_edits: + h.append(EditedMessageHandler) + return ( + h, + pyrogram_filters.incoming + & _HookEnabledFilter(handler.name, self.storage) + & handler.filters, + ) + + def register(self, client: Client) -> None: + if self.commands is None: + raise RuntimeError("Please set commands attribute before registering hooks module") + if self.storage is None: + raise RuntimeError("Please set storage attribute before registering hooks module") + commands = CommandsModule("Hooks") + commands.add( + _list_enabled_hooks, + "hookshere", + "hooks_here", + ) + commands.add( + self._list_hooks, + "hooklist", + "hook_list", + ) for handler in self._handlers: - cmds.add_handler( + commands.add( handler.add_handler, - [f"{handler.name}here", f"{handler.name}_here"], + f"{handler.name}here", + f"{handler.name}_here", doc=f"Enable {handler.name} hook for this chat", hidden=True, ) - cmds.add_handler( - handler.del_handler, - [f"no{handler.name}here", f"no_{handler.name}_here"], + commands.add( + handler.remove_handler, + f"no{handler.name}here", + f"no_{handler.name}_here", doc=f"Disable {handler.name} hook for this chat", hidden=True, ) - f = flt.incoming & handler.filters - callback = async_partial(handler, storage=storage) - client.add_handler(MessageHandler(callback, f)) - if handler.handle_edits: - client.add_handler(EditedMessageHandler(callback, f)) - cmds.add_handler( - _list_enabled_hooks, - ["hookshere", "hooks_here"], - ) - cmds.add_handler( - self._list_hooks, - ["hooklist", "hook_list"], - ) - commands.add_submodule(cmds) + super().register(client) + self.commands.add_submodule(commands) async def _list_hooks(self, tr: Translation) -> str: """List all available hooks""" diff --git a/userbot/modules/shortcuts.py b/userbot/modules/shortcuts.py index 10ec63b..ccd1b81 100644 --- a/userbot/modules/shortcuts.py +++ b/userbot/modules/shortcuts.py @@ -1,81 +1,154 @@ __all__ = [ - "ShortcutTransformersModule", + "ShortcutsModule", ] import re -from dataclasses import dataclass -from typing import Awaitable, Callable +from typing import Any, Callable, Type, overload -from pyrogram import Client -from pyrogram import filters as flt -from pyrogram.enums import ParseMode +from pyrogram import ContinuePropagation, filters from pyrogram.handlers import EditedMessageHandler, MessageHandler +from pyrogram.handlers.handler import Handler from pyrogram.types import Message -_TransformHandlerT = Callable[[re.Match[str]], Awaitable[str]] +from .base import BaseHandler, BaseModule, HandlerT +_DEFAULT_TIMEOUT = 5 -@dataclass() -class _ShortcutHandler: - regex: re.Pattern[str] - handler: _TransformHandlerT - handle_edits: bool - async def __call__(self, client: Client, message: Message): +class ShortcutsHandler(BaseHandler): + def __init__( + self, + *, + pattern: re.Pattern[str], + handler: HandlerT, + handle_edits: bool, + waiting_message: str | None, + timeout: int | None, + ) -> None: + super().__init__( + handler=handler, + handle_edits=handle_edits, + waiting_message=waiting_message, + timeout=timeout, + ) + self.pattern = pattern + + def __repr__(self) -> str: + pattern = self.pattern + handle_edits = self.handle_edits + timeout = self.timeout + return f"<{self.__class__.__name__} {pattern=} {handle_edits=} {timeout=}>" + + async def _invoke_handler(self, data: dict[str, Any]) -> str | None: + message: Message = data["message"] raw_text = message.text or message.caption if raw_text is None: return text = raw_text.html - while match := self.regex.search(text): - if (result := await self.handler(match)) is not None: + while match := self.pattern.search(text): + data["match"] = match + if (result := await super()._invoke_handler(data)) is not None: text = f"{text[:match.start()]}{result}{text[match.end():]}" - await message.edit(text, parse_mode=ParseMode.HTML) - message.continue_propagation() # allow other shortcut handlers to run + return text + + async def _timed_out_handler(self, data: dict[str, Any]) -> str | None: + raise + async def _result_handler(self, result: str, data: dict[str, Any]) -> None: + await super()._result_handler(result, data) + raise ContinuePropagation # allow other shortcut handlers to run -class ShortcutTransformersModule: - def __init__(self): - self._handlers: list[_ShortcutHandler] = [] +class ShortcutsModule(BaseModule): + @overload def add( self, pattern: str | re.Pattern[str], + none: None = None, + /, *, handle_edits: bool = True, - ) -> Callable[[_TransformHandlerT], _TransformHandlerT]: - def _decorator(f: _TransformHandlerT) -> _TransformHandlerT: - self.add_handler( - handler=f, - pattern=pattern, - handle_edits=handle_edits, - ) - return f + waiting_message: str | None = None, + timeout: int | None = _DEFAULT_TIMEOUT, + ) -> Callable[[HandlerT], HandlerT]: + # usage as a decorator + pass - return _decorator - - def add_handler( + @overload + def add( self, - handler: _TransformHandlerT, + callable_: HandlerT, pattern: str | re.Pattern[str], + /, *, handle_edits: bool = True, + waiting_message: str | None = None, + timeout: int | None = _DEFAULT_TIMEOUT, ) -> None: - if isinstance(pattern, str): - pattern = re.compile(pattern) - self._handlers.append( - _ShortcutHandler( - regex=pattern, - handler=handler, - handle_edits=handle_edits, + # usage as a function + pass + + def add( + self, + callable_or_pattern: HandlerT | str | re.Pattern[str], + pattern: str | re.Pattern[str] | None = None, + /, + *, + handle_edits: bool = True, + waiting_message: str | None = None, + timeout: int | None = _DEFAULT_TIMEOUT, + ) -> Callable[[HandlerT], HandlerT] | None: + """Registers a shortcut handler. Can be used as a decorator or as a registration function. + + Example: + Example usage as a decorator:: + + shortcuts = ShortcutsModule() + @shortcuts.add(r"(hello|hi)") + async def hello_shortcut(message: Message, command: CommandObject) -> str: + return "Hello!" + + Example usage as a function:: + + shortcuts = ShortcutsModule() + def hello_shortcut(message: Message, command: CommandObject) -> str: + return "Hello!" + shortcuts.add(hello_shortcut, r"(hello|hi)") + + Returns: + The original handler function if used as a decorator, otherwise `None`. + """ + + def decorator(handler: HandlerT) -> HandlerT: + if isinstance(pattern, str): + p = re.compile(pattern) + else: + p = pattern + self.add_handler( + ShortcutsHandler( + pattern=p, + handler=handler, + handle_edits=handle_edits, + waiting_message=waiting_message, + timeout=timeout, + ) ) - ) + return handler - def add_submodule(self, module: "ShortcutTransformersModule") -> None: - self._handlers.extend(module._handlers) + if callable(callable_or_pattern): + if pattern is None: + raise ValueError("No pattern specified") + decorator(callable_or_pattern) + return + + pattern = callable_or_pattern + return decorator - def register(self, client: Client) -> None: - for handler in self._handlers: - f = flt.outgoing & ~flt.scheduled & flt.regex(handler.regex) - client.add_handler(MessageHandler(handler.__call__, f)) - if handler.handle_edits: - client.add_handler(EditedMessageHandler(handler.__call__, f)) + def _create_handlers_filters( + self, + handler: ShortcutsHandler, + ) -> tuple[list[Type[Handler]], filters.Filter]: + h: list[Type[Handler]] = [MessageHandler] + if handler.handle_edits: + h.append(EditedMessageHandler) + return h, filters.outgoing & ~filters.scheduled & filters.regex(handler.pattern) diff --git a/userbot/shortcuts.py b/userbot/shortcuts.py index c813ece..f316167 100644 --- a/userbot/shortcuts.py +++ b/userbot/shortcuts.py @@ -1,6 +1,4 @@ __all__ = [ - "get_note", - "github", "shortcuts", ] @@ -9,11 +7,12 @@ from dataclasses import dataclass from urllib.parse import quote_plus -from .modules import ShortcutTransformersModule +from .constants import GH_PATTERN +from .modules import ShortcutsModule from .storage import Storage from .utils import GitHubClient -shortcuts = ShortcutTransformersModule() +shortcuts = ShortcutsModule() @dataclass() @@ -39,7 +38,8 @@ async def mention(match: re.Match[str]) -> str: return f"{match[2] or match[1]}" -async def github(match: re.Match[str], *, github_client: GitHubClient) -> str: +@shortcuts.add(GH_PATTERN) +async def github(match: re.Match[str], github_client: GitHubClient) -> str: """Sends a link to a GitHub repository""" m = GitHubMatch(**match.groupdict()) url = f"https://github.com/{m.username}" @@ -99,6 +99,7 @@ async def shrug(_: re.Match[str]) -> str: return "¯\\_(ツ)_/¯" +@shortcuts.add(r"n://(.+?)/") async def get_note(match: re.Match[str], *, storage: Storage) -> str: """Sends a saved note""" note = await storage.get_note(match[1])