diff --git a/discord/__init__.py b/discord/__init__.py new file mode 100644 index 0000000..62531cd --- /dev/null +++ b/discord/__init__.py @@ -0,0 +1,82 @@ +""" +Discord API Wrapper +~~~~~~~~~~~~~~~~~~~ + +A basic wrapper for the Discord API. + +:copyright: (c) 2015-2021 Rapptz & (c) 2021-present Pycord Development +:license: MIT, see LICENSE for more details. +""" + +__title__ = "pycord" +__author__ = "Pycord Development" +__license__ = "MIT" +__copyright__ = "Copyright 2015-2021 Rapptz & Copyright 2021-present Pycord Development" +__version__ = "2.2.2" + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +import logging +from typing import Literal, NamedTuple + +from . import abc, opus, sinks, ui, utils +from .activity import * +from .appinfo import * +from .asset import * +from .audit_logs import * +from .automod import * +from .bot import * +from .channel import * +from .client import * +from .cog import * +from .colour import * +from .commands import * +from .components import * +from .embeds import * +from .emoji import * +from .enums import * +from .errors import * +from .file import * +from .flags import * +from .guild import * +from .http import * +from .integrations import * +from .interactions import * +from .invite import * +from .member import * +from .mentions import * +from .message import * +from .object import * +from .partial_emoji import * +from .permissions import * +from .player import * +from .raw_models import * +from .reaction import * +from .role import * +from .scheduled_events import * +from .shard import * +from .stage_instance import * +from .sticker import * +from .team import * +from .template import * +from .threads import * +from .user import * +from .voice_client import * +from .webhook import * +from .welcome_screen import * +from .widget import * + + +class VersionInfo(NamedTuple): + major: int + minor: int + micro: int + releaselevel: Literal["alpha", "beta", "candidate", "final"] + serial: int + + +version_info: VersionInfo = VersionInfo( + major=2, minor=2, micro=2, releaselevel="final", serial=0 +) + +logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/__main__.py b/discord/__main__.py new file mode 100644 index 0000000..12b5b5a --- /dev/null +++ b/discord/__main__.py @@ -0,0 +1,376 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import argparse +import platform +import sys +from pathlib import Path +from typing import Tuple + +import aiohttp +import pkg_resources + +import discord + + +def show_version() -> None: + entries = [ + "- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format( + sys.version_info + ) + ] + + version_info = discord.version_info + entries.append( + "- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info) + ) + if version_info.releaselevel != "final": + pkg = pkg_resources.get_distribution("py-cord") + if pkg: + entries.append(f" - py-cord pkg_resources: v{pkg.version}") + + entries.append(f"- aiohttp v{aiohttp.__version__}") + uname = platform.uname() + entries.append("- system info: {0.system} {0.release} {0.version}".format(uname)) + print("\n".join(entries)) + + +def core(parser, args) -> None: + if args.version: + show_version() + + +_bot_template = """#!/usr/bin/env python3 + +from discord.ext import commands +import discord +import config + +class Bot(commands.{base}): + def __init__(self, **kwargs): + super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs) + for cog in config.cogs: + try: + self.load_extension(cog) + except Exception as exc: + print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}') + + async def on_ready(self): + print(f'Logged on as {{self.user}} (ID: {{self.user.id}})') + + +bot = Bot() + +# write general commands here + +bot.run(config.token) +""" + +_gitignore_template = """# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Our configuration files +config.py +""" + +_cog_template = '''from discord.ext import commands +import discord + +class {name}(commands.Cog{attrs}): + """The description for {name} goes here.""" + + def __init__(self, bot): + self.bot = bot +{extra} +def setup(bot): + bot.add_cog({name}(bot)) +''' + +_cog_extras = """ + def cog_unload(self): + # clean up logic goes here + pass + + async def cog_check(self, ctx): + # checks that apply to every command in here + return True + + async def bot_check(self, ctx): + # checks that apply to every command to the bot + return True + + async def bot_check_once(self, ctx): + # check that apply to every command but is guaranteed to be called only once + return True + + async def cog_command_error(self, ctx, error): + # error handling to every command in here + pass + + async def cog_before_invoke(self, ctx): + # called before a command is called here + pass + + async def cog_after_invoke(self, ctx): + # called after a command is called here + pass + +""" + + +# certain file names and directory names are forbidden +# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx +# although some of this doesn't apply to Linux, we might as well be consistent +_base_table = { + "<": "-", + ">": "-", + ":": "-", + '"': "-", + # '/': '-', these are fine + # '\\': '-', + "|": "-", + "?": "-", + "*": "-", +} + +# NUL (0) and 1-31 are disallowed +_base_table.update((chr(i), None) for i in range(32)) + +_translation_table = str.maketrans(_base_table) + + +def to_path(parser, name, *, replace_spaces=False) -> Path: + if isinstance(name, Path): + return name + + if sys.platform == "win32": + forbidden = ( + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", + ) + if len(name) <= 4 and name.upper() in forbidden: + parser.error("invalid directory name given, use a different one") + + name = name.translate(_translation_table) + if replace_spaces: + name = name.replace(" ", "-") + return Path(name) + + +def newbot(parser, args) -> None: + new_directory = to_path(parser, args.directory) / to_path(parser, args.name) + + # as a note exist_ok for Path is a 3.5+ only feature + # since we already checked above that we're >3.5 + try: + new_directory.mkdir(exist_ok=True, parents=True) + except OSError as exc: + parser.error(f"could not create our bot directory ({exc})") + + cogs = new_directory / "cogs" + + try: + cogs.mkdir(exist_ok=True) + init = cogs / "__init__.py" + init.touch() + except OSError as exc: + print(f"warning: could not create cogs directory ({exc})") + + try: + with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp: + fp.write('token = "place your token here"\ncogs = []\n') + except OSError as exc: + parser.error(f"could not create config file ({exc})") + + try: + with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp: + base = "Bot" if not args.sharded else "AutoShardedBot" + fp.write(_bot_template.format(base=base, prefix=args.prefix)) + except OSError as exc: + parser.error(f"could not create bot file ({exc})") + + if not args.no_git: + try: + with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp: + fp.write(_gitignore_template) + except OSError as exc: + print(f"warning: could not create .gitignore file ({exc})") + + print("successfully made bot at", new_directory) + + +def newcog(parser, args) -> None: + cog_dir = to_path(parser, args.directory) + try: + cog_dir.mkdir(exist_ok=True) + except OSError as exc: + print(f"warning: could not create cogs directory ({exc})") + + directory = cog_dir / to_path(parser, args.name) + directory = directory.with_suffix(".py") + try: + with open(str(directory), "w", encoding="utf-8") as fp: + attrs = "" + extra = _cog_extras if args.full else "" + if args.class_name: + name = args.class_name + else: + name = str(directory.stem) + if "-" in name or "_" in name: + translation = str.maketrans("-_", " ") + name = name.translate(translation).title().replace(" ", "") + else: + name = name.title() + + if args.display_name: + attrs += f', name="{args.display_name}"' + if args.hide_commands: + attrs += ", command_attrs=dict(hidden=True)" + fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) + except OSError as exc: + parser.error(f"could not create cog file ({exc})") + else: + print("successfully made cog at", directory) + + +def add_newbot_args(subparser: argparse._SubParsersAction) -> None: + parser = subparser.add_parser( + "newbot", help="creates a command bot project quickly" + ) + parser.set_defaults(func=newbot) + + parser.add_argument("name", help="the bot project name") + parser.add_argument( + "directory", + help="the directory to place it in (default: .)", + nargs="?", + default=Path.cwd(), + ) + parser.add_argument( + "--prefix", help="the bot prefix (default: $)", default="$", metavar="" + ) + parser.add_argument( + "--sharded", help="whether to use AutoShardedBot", action="store_true" + ) + parser.add_argument( + "--no-git", + help="do not create a .gitignore file", + action="store_true", + dest="no_git", + ) + + +def add_newcog_args(subparser: argparse._SubParsersAction) -> None: + parser = subparser.add_parser("newcog", help="creates a new cog template quickly") + parser.set_defaults(func=newcog) + + parser.add_argument("name", help="the cog name") + parser.add_argument( + "directory", + help="the directory to place it in (default: cogs)", + nargs="?", + default=Path("cogs"), + ) + parser.add_argument( + "--class-name", + help="the class name of the cog (default: )", + dest="class_name", + ) + parser.add_argument("--display-name", help="the cog name (default: )") + parser.add_argument( + "--hide-commands", + help="whether to hide all commands in the cog", + action="store_true", + ) + parser.add_argument( + "--full", help="add all special methods as well", action="store_true" + ) + + +def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]: + parser = argparse.ArgumentParser( + prog="discord", description="Tools for helping with Pycord" + ) + parser.add_argument( + "-v", "--version", action="store_true", help="shows the library version" + ) + parser.set_defaults(func=core) + + subparser = parser.add_subparsers(dest="subcommand", title="subcommands") + add_newbot_args(subparser) + add_newcog_args(subparser) + return parser, parser.parse_args() + + +def main() -> None: + parser, args = parse_args() + args.func(parser, args) + + +if __name__ == "__main__": + main() diff --git a/discord/abc.py b/discord/abc.py new file mode 100644 index 0000000..ce0806b --- /dev/null +++ b/discord/abc.py @@ -0,0 +1,1920 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import copy +import time +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Protocol, + Sequence, + TypeVar, + Union, + overload, + runtime_checkable, +) + +from . import utils +from .context_managers import Typing +from .enums import ChannelType +from .errors import ClientException, InvalidArgument +from .file import File +from .flags import MessageFlags +from .invite import Invite +from .iterators import HistoryIterator +from .mentions import AllowedMentions +from .permissions import PermissionOverwrite, Permissions +from .role import Role +from .scheduled_events import ScheduledEvent +from .sticker import GuildSticker, StickerItem +from .voice_client import VoiceClient, VoiceProtocol + +__all__ = ( + "Snowflake", + "User", + "PrivateChannel", + "GuildChannel", + "Messageable", + "Connectable", + "Mentionable", +) + +T = TypeVar("T", bound=VoiceProtocol) + +if TYPE_CHECKING: + from datetime import datetime + + from .asset import Asset + from .channel import ( + CategoryChannel, + DMChannel, + GroupChannel, + PartialMessageable, + TextChannel, + VoiceChannel, + ) + from .client import Client + from .embeds import Embed + from .enums import InviteTarget + from .flags import ChannelFlags + from .guild import Guild + from .member import Member + from .message import Message, MessageReference, PartialMessage + from .state import ConnectionState + from .threads import Thread + from .types.channel import Channel as ChannelPayload + from .types.channel import GuildChannel as GuildChannelPayload + from .types.channel import OverwriteType + from .types.channel import PermissionOverwrite as PermissionOverwritePayload + from .ui.view import View + from .user import ClientUser + + PartialMessageableChannel = Union[ + TextChannel, VoiceChannel, Thread, DMChannel, PartialMessageable + ] + MessageableChannel = Union[PartialMessageableChannel, GroupChannel] + SnowflakeTime = Union["Snowflake", datetime] + +MISSING = utils.MISSING + + +async def _single_delete_strategy( + messages: Iterable[Message], *, reason: str | None = None +): + for m in messages: + await m.delete(reason=reason) + + +async def _purge_messages_helper( + channel: TextChannel | Thread | VoiceChannel, + *, + limit: int | None = 100, + check: Callable[[Message], bool] = MISSING, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = False, + bulk: bool = True, + reason: str | None = None, +) -> list[Message]: + if check is MISSING: + check = lambda m: True + + iterator = channel.history( + limit=limit, + before=before, + after=after, + oldest_first=oldest_first, + around=around, + ) + ret: list[Message] = [] + count = 0 + + minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 + strategy = channel.delete_messages if bulk else _single_delete_strategy + + async for message in iterator: + if count == 100: + to_delete = ret[-100:] + await strategy(to_delete, reason=reason) + count = 0 + await asyncio.sleep(1) + + if not check(message): + continue + + if message.id < minimum_time: + # older than 14 days old + if count == 1: + await ret[-1].delete(reason=reason) + elif count >= 2: + to_delete = ret[-count:] + await strategy(to_delete, reason=reason) + + count = 0 + strategy = _single_delete_strategy + + count += 1 + ret.append(message) + + # Some messages remaining to poll + if count >= 2: + # more than 2 messages -> bulk delete + to_delete = ret[-count:] + await strategy(to_delete, reason=reason) + elif count == 1: + # delete a single message + await ret[-1].delete(reason=reason) + + return ret + + +@runtime_checkable +class Snowflake(Protocol): + """An ABC that details the common operations on a Discord model. + + Almost all :ref:`Discord models ` meet this + abstract base class. + + If you want to create a snowflake on your own, consider using + :class:`.Object`. + + Attributes + ---------- + id: :class:`int` + The model's unique ID. + """ + + __slots__ = () + id: int + + +@runtime_checkable +class User(Snowflake, Protocol): + """An ABC that details the common operations on a Discord user. + + The following implement this ABC: + + - :class:`~discord.User` + - :class:`~discord.ClientUser` + - :class:`~discord.Member` + + This ABC must also implement :class:`~discord.abc.Snowflake`. + + Attributes + ---------- + name: :class:`str` + The user's username. + discriminator: :class:`str` + The user's discriminator. + avatar: :class:`~discord.Asset` + The avatar asset the user has. + bot: :class:`bool` + If the user is a bot account. + """ + + __slots__ = () + + name: str + discriminator: str + avatar: Asset + bot: bool + + @property + def display_name(self) -> str: + """:class:`str`: Returns the user's display name.""" + raise NotImplementedError + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given user.""" + raise NotImplementedError + + +@runtime_checkable +class PrivateChannel(Snowflake, Protocol): + """An ABC that details the common operations on a private Discord channel. + + The following implement this ABC: + + - :class:`~discord.DMChannel` + - :class:`~discord.GroupChannel` + + This ABC must also implement :class:`~discord.abc.Snowflake`. + + Attributes + ---------- + me: :class:`~discord.ClientUser` + The user presenting yourself. + """ + + __slots__ = () + + me: ClientUser + + +class _Overwrites: + __slots__ = ("id", "allow", "deny", "type") + + ROLE = 0 + MEMBER = 1 + + def __init__(self, data: PermissionOverwritePayload): + self.id: int = int(data["id"]) + self.allow: int = int(data.get("allow", 0)) + self.deny: int = int(data.get("deny", 0)) + self.type: OverwriteType = data["type"] + + def _asdict(self) -> PermissionOverwritePayload: + return { + "id": self.id, + "allow": str(self.allow), + "deny": str(self.deny), + "type": self.type, + } + + def is_role(self) -> bool: + return self.type == self.ROLE + + def is_member(self) -> bool: + return self.type == self.MEMBER + + +GCH = TypeVar("GCH", bound="GuildChannel") + + +class GuildChannel: + """An ABC that details the common operations on a Discord guild channel. + + The following implement this ABC: + + - :class:`~discord.TextChannel` + - :class:`~discord.VoiceChannel` + - :class:`~discord.CategoryChannel` + - :class:`~discord.StageChannel` + - :class:`~discord.ForumChannel` + + This ABC must also implement :class:`~discord.abc.Snowflake`. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`~discord.Guild` + The guild the channel belongs to. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__ = () + + id: int + name: str + guild: Guild + type: ChannelType + position: int + category_id: int | None + flags: ChannelFlags + _state: ConnectionState + _overwrites: list[_Overwrites] + + if TYPE_CHECKING: + + def __init__( + self, *, state: ConnectionState, guild: Guild, data: dict[str, Any] + ): + ... + + def __str__(self) -> str: + return self.name + + @property + def _sorting_bucket(self) -> int: + raise NotImplementedError + + def _update(self, guild: Guild, data: dict[str, Any]) -> None: + raise NotImplementedError + + async def _move( + self, + position: int, + parent_id: Any | None = None, + lock_permissions: bool = False, + *, + reason: str | None, + ) -> None: + if position < 0: + raise InvalidArgument("Channel position cannot be less than 0.") + + http = self._state.http + bucket = self._sorting_bucket + channels: list[GuildChannel] = [ + c for c in self.guild.channels if c._sorting_bucket == bucket + ] + + channels.sort(key=lambda c: c.position) + + try: + # remove ourselves from the channel list + channels.remove(self) + except ValueError: + # not there somehow lol + return + else: + index = next( + (i for i, c in enumerate(channels) if c.position >= position), + len(channels), + ) + # add ourselves at our designated position + channels.insert(index, self) + + payload = [] + for index, c in enumerate(channels): + d: dict[str, Any] = {"id": c.id, "position": index} + if parent_id is not MISSING and c.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await http.bulk_channel_update(self.guild.id, payload, reason=reason) + + async def _edit( + self, options: dict[str, Any], reason: str | None + ) -> ChannelPayload | None: + try: + parent = options.pop("category") + except KeyError: + parent_id = MISSING + else: + parent_id = parent and parent.id + + try: + options["rate_limit_per_user"] = options.pop("slowmode_delay") + except KeyError: + pass + + try: + rtc_region = options.pop("rtc_region") + except KeyError: + pass + else: + options["rtc_region"] = None if rtc_region is None else str(rtc_region) + + try: + video_quality_mode = options.pop("video_quality_mode") + except KeyError: + pass + else: + options["video_quality_mode"] = int(video_quality_mode) + + lock_permissions = options.pop("sync_permissions", False) + + try: + position = options.pop("position") + except KeyError: + if parent_id is not MISSING: + if lock_permissions: + category = self.guild.get_channel(parent_id) + if category: + options["permission_overwrites"] = [ + c._asdict() for c in category._overwrites + ] + options["parent_id"] = parent_id + elif lock_permissions and self.category_id is not None: + # if we're syncing permissions on a pre-existing channel category without changing it + # we need to update the permissions to point to the pre-existing category + category = self.guild.get_channel(self.category_id) + if category: + options["permission_overwrites"] = [ + c._asdict() for c in category._overwrites + ] + else: + await self._move( + position, + parent_id=parent_id, + lock_permissions=lock_permissions, + reason=reason, + ) + + overwrites = options.get("overwrites") + if overwrites is not None: + perms = [] + for target, perm in overwrites.items(): + if not isinstance(perm, PermissionOverwrite): + raise InvalidArgument( + f"Expected PermissionOverwrite received {perm.__class__.__name__}" + ) + + allow, deny = perm.pair() + payload = { + "allow": allow.value, + "deny": deny.value, + "id": target.id, + "type": _Overwrites.ROLE + if isinstance(target, Role) + else _Overwrites.MEMBER, + } + + perms.append(payload) + options["permission_overwrites"] = perms + + try: + ch_type = options["type"] + except KeyError: + pass + else: + if not isinstance(ch_type, ChannelType): + raise InvalidArgument("type field must be of type ChannelType") + options["type"] = ch_type.value + + if options: + return await self._state.http.edit_channel( + self.id, reason=reason, **options + ) + + def _fill_overwrites(self, data: GuildChannelPayload) -> None: + self._overwrites = [] + everyone_index = 0 + everyone_id = self.guild.id + + for index, overridden in enumerate(data.get("permission_overwrites", [])): + overwrite = _Overwrites(overridden) + self._overwrites.append(overwrite) + + if overwrite.type == _Overwrites.MEMBER: + continue + + if overwrite.id == everyone_id: + # the @everyone role is not guaranteed to be the first one + # in the list of permission overwrites, however the permission + # resolution code kind of requires that it is the first one in + # the list since it is special. So we need the index so we can + # swap it to be the first one. + everyone_index = index + + # do the swap + tmp = self._overwrites + if tmp: + tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] + + @property + def changed_roles(self) -> list[Role]: + """List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from + their default values in the :attr:`~discord.Guild.roles` attribute. + """ + ret = [] + g = self.guild + for overwrite in filter(lambda o: o.is_role(), self._overwrites): + role = g.get_role(overwrite.id) + if role is None: + continue + + role = copy.copy(role) + role.permissions.handle_overwrite(overwrite.allow, overwrite.deny) + ret.append(role) + return ret + + @property + def mention(self) -> str: + """:class:`str`: The string that allows you to mention the channel.""" + return f"<#{self.id}>" + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f"https://discord.com/channels/{self.guild.id}/{self.id}" + + @property + def created_at(self) -> datetime: + """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def overwrites_for(self, obj: Role | User) -> PermissionOverwrite: + """Returns the channel-specific overwrites for a member or a role. + + Parameters + ---------- + obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`] + The role or user denoting + whose overwrite to get. + + Returns + ------- + :class:`~discord.PermissionOverwrite` + The permission overwrites for this object. + """ + + if isinstance(obj, User): + predicate = lambda p: p.is_member() + elif isinstance(obj, Role): + predicate = lambda p: p.is_role() + else: + predicate = lambda p: True + + for overwrite in filter(predicate, self._overwrites): + if overwrite.id == obj.id: + allow = Permissions(overwrite.allow) + deny = Permissions(overwrite.deny) + return PermissionOverwrite.from_pair(allow, deny) + + return PermissionOverwrite() + + @property + def overwrites(self) -> dict[Role | Member, PermissionOverwrite]: + """Returns all of the channel's overwrites. + + This is returned as a dictionary where the key contains the target which + can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the value is the + overwrite as a :class:`~discord.PermissionOverwrite`. + + Returns + ------- + Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`] + The channel's permission overwrites. + """ + ret = {} + for ow in self._overwrites: + allow = Permissions(ow.allow) + deny = Permissions(ow.deny) + overwrite = PermissionOverwrite.from_pair(allow, deny) + target = None + + if ow.is_role(): + target = self.guild.get_role(ow.id) + elif ow.is_member(): + target = self.guild.get_member(ow.id) + + # TODO: There is potential data loss here in the non-chunked + # case, i.e. target is None because get_member returned nothing. + # This can be fixed with a slight breaking change to the return type, + # i.e. adding discord.Object to the list of it + # However, for now this is an acceptable compromise. + if target is not None: + ret[target] = overwrite + return ret + + @property + def category(self) -> CategoryChannel | None: + """Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to. + + If there is no category then this is ``None``. + """ + return self.guild.get_channel(self.category_id) # type: ignore + + @property + def permissions_synced(self) -> bool: + """:class:`bool`: Whether the permissions for this channel are synced with the + category it belongs to. + + If there is no category then this is ``False``. + + .. versionadded:: 1.3 + """ + if self.category_id is None: + return False + + category = self.guild.get_channel(self.category_id) + return bool(category and category.overwrites == self.overwrites) + + def permissions_for(self, obj: Member | Role, /) -> Permissions: + """Handles permission resolution for the :class:`~discord.Member` + or :class:`~discord.Role`. + + This function takes into consideration the following cases: + + - Guild owner + - Guild roles + - Channel overrides + - Member overrides + + If a :class:`~discord.Role` is passed, then it checks the permissions + someone with that role would have, which is essentially: + + - The default role permissions + - The permissions of the role used as a parameter + - The default role permission overwrites + - The permission overwrites of the role used as a parameter + + .. versionchanged:: 2.0 + The object passed in can now be a role object. + + Parameters + ---------- + obj: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The object to resolve permissions for. This could be either + a member or a role. If it's a role then member overwrites + are not computed. + + Returns + ------- + :class:`~discord.Permissions` + The resolved permissions for the member or role. + """ + + # The current cases can be explained as: + # Guild owner get all permissions -- no questions asked. Otherwise... + # The @everyone role gets the first application. + # After that, the applied roles that the user has in the channel + # (or otherwise) are then OR'd together. + # After the role permissions are resolved, the member permissions + # have to take into effect. + # After all that is done, you have to do the following: + + # If manage permissions is True, then all permissions are set to True. + + # The operation first takes into consideration the denied + # and then the allowed. + + if self.guild.owner_id == obj.id: + return Permissions.all() + + default = self.guild.default_role + base = Permissions(default.permissions.value) + + # Handle the role case first + if isinstance(obj, Role): + base.value |= obj._permissions + + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite( + allow=maybe_everyone.allow, deny=maybe_everyone.deny + ) + except IndexError: + pass + + if obj.is_default(): + return base + + overwrite = utils.get(self._overwrites, type=_Overwrites.ROLE, id=obj.id) + if overwrite is not None: + base.handle_overwrite(overwrite.allow, overwrite.deny) + + return base + + roles = obj._roles + get_role = self.guild.get_role + + # Apply guild roles that the member has. + for role_id in roles: + role = get_role(role_id) + if role is not None: + base.value |= role._permissions + + # Guild-wide Administrator -> True for everything + # Bypass all channel-specific overrides + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite( + allow=maybe_everyone.allow, deny=maybe_everyone.deny + ) + remaining_overwrites = self._overwrites[1:] + else: + remaining_overwrites = self._overwrites + except IndexError: + remaining_overwrites = self._overwrites + + denies = 0 + allows = 0 + + # Apply channel specific role permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_role() and roles.has(overwrite.id): + denies |= overwrite.deny + allows |= overwrite.allow + + base.handle_overwrite(allow=allows, deny=denies) + + # Apply member specific permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_member() and overwrite.id == obj.id: + base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) + break + + # if you can't send a message in a channel then you can't have certain + # permissions as well + if not base.send_messages: + base.send_tts_messages = False + base.mention_everyone = False + base.embed_links = False + base.attach_files = False + + # if you can't read a channel then you have no permissions there + if not base.read_messages: + denied = Permissions.all_channel() + base.value &= ~denied.value + + return base + + async def delete(self, *, reason: str | None = None) -> None: + """|coro| + + Deletes the channel. + + You must have :attr:`~discord.Permissions.manage_channels` permission to use this. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this channel. + Shows up on the audit log. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to delete the channel. + ~discord.NotFound + The channel was not found or was already deleted. + ~discord.HTTPException + Deleting the channel failed. + """ + await self._state.http.delete_channel(self.id, reason=reason) + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: PermissionOverwrite | None = ..., + reason: str | None = ..., + ) -> None: + ... + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + reason: str | None = ..., + **permissions: bool, + ) -> None: + ... + + async def set_permissions( + self, target, *, overwrite=MISSING, reason=None, **permissions + ): + r"""|coro| + + Sets the channel specific permission overwrites for a target in the + channel. + + The ``target`` parameter should either be a :class:`~discord.Member` or a + :class:`~discord.Role` that belongs to guild. + + The ``overwrite`` parameter, if given, must either be ``None`` or + :class:`~discord.PermissionOverwrite`. For convenience, you can pass in + keyword arguments denoting :class:`~discord.Permissions` attributes. If this is + done, then you cannot mix the keyword arguments with the ``overwrite`` + parameter. + + If the ``overwrite`` parameter is ``None``, then the permission + overwrites are deleted. + + You must have the :attr:`~discord.Permissions.manage_roles` permission to use this. + + .. note:: + + This method *replaces* the old overwrites with the ones given. + + Examples + ---------- + + Setting allow and deny: :: + + await message.channel.set_permissions(message.author, read_messages=True, + send_messages=False) + + Deleting overwrites :: + + await channel.set_permissions(member, overwrite=None) + + Using :class:`~discord.PermissionOverwrite` :: + + overwrite = discord.PermissionOverwrite() + overwrite.send_messages = False + overwrite.read_messages = True + await channel.set_permissions(member, overwrite=overwrite) + + Parameters + ----------- + target: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The member or role to overwrite permissions for. + overwrite: Optional[:class:`~discord.PermissionOverwrite`] + The permissions to allow and deny to the target, or ``None`` to + delete the overwrite. + \*\*permissions + A keyword argument list of permissions to set for ease of use. + Cannot be mixed with ``overwrite``. + reason: Optional[:class:`str`] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + ~discord.Forbidden + You do not have permissions to edit channel specific permissions. + ~discord.HTTPException + Editing channel specific permissions failed. + ~discord.NotFound + The role or member being edited is not part of the guild. + ~discord.InvalidArgument + The overwrite parameter invalid or the target type was not + :class:`~discord.Role` or :class:`~discord.Member`. + """ + + http = self._state.http + + if isinstance(target, User): + perm_type = _Overwrites.MEMBER + elif isinstance(target, Role): + perm_type = _Overwrites.ROLE + else: + raise InvalidArgument("target parameter must be either Member or Role") + + if overwrite is MISSING: + if len(permissions) == 0: + raise InvalidArgument("No overwrite provided.") + try: + overwrite = PermissionOverwrite(**permissions) + except (ValueError, TypeError): + raise InvalidArgument("Invalid permissions given to keyword arguments.") + elif len(permissions) > 0: + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") + + # TODO: wait for event + + if overwrite is None: + await http.delete_channel_permissions(self.id, target.id, reason=reason) + elif isinstance(overwrite, PermissionOverwrite): + (allow, deny) = overwrite.pair() + await http.edit_channel_permissions( + self.id, target.id, allow.value, deny.value, perm_type, reason=reason + ) + else: + raise InvalidArgument("Invalid overwrite type provided.") + + async def _clone_impl( + self: GCH, + base_attrs: dict[str, Any], + *, + name: str | None = None, + reason: str | None = None, + ) -> GCH: + base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] + base_attrs["parent_id"] = self.category_id + base_attrs["name"] = name or self.name + guild_id = self.guild.id + cls = self.__class__ + data = await self._state.http.create_channel( + guild_id, self.type.value, reason=reason, **base_attrs + ) + obj = cls(state=self._state, guild=self.guild, data=data) + + # temporarily add it to the cache + self.guild._channels[obj.id] = obj # type: ignore + return obj + + async def clone( + self: GCH, *, name: str | None = None, reason: str | None = None + ) -> GCH: + """|coro| + + Clones this channel. This creates a channel with the same properties + as this channel. + + You must have the :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. versionadded:: 1.1 + + Parameters + ---------- + name: Optional[:class:`str`] + The name of the new channel. If not provided, defaults to this + channel name. + reason: Optional[:class:`str`] + The reason for cloning this channel. Shows up on the audit log. + + Returns + ------- + :class:`.abc.GuildChannel` + The channel that was created. + + Raises + ------ + ~discord.Forbidden + You do not have the proper permissions to create this channel. + ~discord.HTTPException + Creating the channel failed. + """ + raise NotImplementedError + + @overload + async def move( + self, + *, + beginning: bool, + offset: int = MISSING, + category: Snowflake | None = MISSING, + sync_permissions: bool = MISSING, + reason: str | None = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + end: bool, + offset: int = MISSING, + category: Snowflake | None = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + before: Snowflake, + offset: int = MISSING, + category: Snowflake | None = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + after: Snowflake, + offset: int = MISSING, + category: Snowflake | None = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + async def move(self, **kwargs) -> None: + """|coro| + + A rich interface to help move a channel relative to other channels. + + If exact position movement is required, ``edit`` should be used instead. + + You must have the :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. note:: + + Voice channels will always be sorted below text channels. + This is a Discord limitation. + + .. versionadded:: 1.7 + + Parameters + ---------- + beginning: :class:`bool` + Whether to move the channel to the beginning of the + channel list (or category if given). + This is mutually exclusive with ``end``, ``before``, and ``after``. + end: :class:`bool` + Whether to move the channel to the end of the + channel list (or category if given). + This is mutually exclusive with ``beginning``, ``before``, and ``after``. + before: :class:`~discord.abc.Snowflake` + The channel that should be before our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``after``. + after: :class:`~discord.abc.Snowflake` + The channel that should be after our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``before``. + offset: :class:`int` + The number of channels to offset the move by. For example, + an offset of ``2`` with ``beginning=True`` would move + it 2 after the beginning. A positive number moves it below + while a negative number moves it above. Note that this + number is relative and computed after the ``beginning``, + ``end``, ``before``, and ``after`` parameters. + category: Optional[:class:`~discord.abc.Snowflake`] + The category to move this channel under. + If ``None`` is given then it moves it out of the category. + This parameter is ignored if moving a category channel. + sync_permissions: :class:`bool` + Whether to sync the permissions with the category (if given). + reason: :class:`str` + The reason for the move. + + Raises + ------ + InvalidArgument + An invalid position was given or a bad mix of arguments was passed. + Forbidden + You do not have permissions to move the channel. + HTTPException + Moving the channel failed. + """ + + if not kwargs: + return + + beginning, end = kwargs.get("beginning"), kwargs.get("end") + before, after = kwargs.get("before"), kwargs.get("after") + offset = kwargs.get("offset", 0) + if sum(bool(a) for a in (beginning, end, before, after)) > 1: + raise InvalidArgument( + "Only one of [before, after, end, beginning] can be used." + ) + + bucket = self._sorting_bucket + parent_id = kwargs.get("category", MISSING) + channels: list[GuildChannel] + if parent_id not in (MISSING, None): + parent_id = parent_id.id + channels = [ + ch + for ch in self.guild.channels + if ch._sorting_bucket == bucket and ch.category_id == parent_id + ] + else: + channels = [ + ch + for ch in self.guild.channels + if ch._sorting_bucket == bucket and ch.category_id == self.category_id + ] + + channels.sort(key=lambda c: (c.position, c.id)) + + try: + # Try to remove ourselves from the channel list + channels.remove(self) + except ValueError: + # If we're not there then it's probably due to not being in the category + pass + + index = None + if beginning: + index = 0 + elif end: + index = len(channels) + elif before: + index = next((i for i, c in enumerate(channels) if c.id == before.id), None) + elif after: + index = next( + (i + 1 for i, c in enumerate(channels) if c.id == after.id), None + ) + + if index is None: + raise InvalidArgument("Could not resolve appropriate move position") + + channels.insert(max((index + offset), 0), self) + payload = [] + lock_permissions = kwargs.get("sync_permissions", False) + reason = kwargs.get("reason") + for index, channel in enumerate(channels): + d = {"id": channel.id, "position": index} + if parent_id is not MISSING and channel.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update( + self.guild.id, payload, reason=reason + ) + + async def create_invite( + self, + *, + reason: str | None = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + target_event: ScheduledEvent | None = None, + target_type: InviteTarget | None = None, + target_user: User | None = None, + target_application_id: int | None = None, + ) -> Invite: + """|coro| + + Creates an instant invite from a text or voice channel. + + You must have the :attr:`~discord.Permissions.create_instant_invite` permission to + do this. + + Parameters + ---------- + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite + doesn't expire. Defaults to ``0``. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there + are unlimited uses. Defaults to ``0``. + temporary: :class:`bool` + Denotes that the invite grants temporary membership + (i.e. they get kicked after they disconnect). Defaults to ``False``. + unique: :class:`bool` + Indicates if a unique invite URL should be created. Defaults to True. + If this is set to ``False`` then it will return a previously created + invite. + reason: Optional[:class:`str`] + The reason for creating this invite. Shows up on the audit log. + target_type: Optional[:class:`.InviteTarget`] + The type of target for the voice channel invite, if any. + + .. versionadded:: 2.0 + + target_user: Optional[:class:`User`] + The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. + The user must be streaming in the channel. + + .. versionadded:: 2.0 + + target_application_id: Optional[:class:`int`] + The id of the embedded application for the invite, required if `target_type` is + `TargetType.embedded_application`. + + .. versionadded:: 2.0 + + target_event: Optional[:class:`.ScheduledEvent`] + The scheduled event object to link to the event. + Shortcut to :meth:`.Invite.set_scheduled_event` + + See :meth:`.Invite.set_scheduled_event` for more + info on event invite linking. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + ~discord.HTTPException + Invite creation failed. + + ~discord.NotFound + The channel that was passed is a category or an invalid channel. + """ + + data = await self._state.http.create_invite( + self.id, + reason=reason, + max_age=max_age, + max_uses=max_uses, + temporary=temporary, + unique=unique, + target_type=target_type.value if target_type else None, + target_user_id=target_user.id if target_user else None, + target_application_id=target_application_id, + ) + invite = Invite.from_incomplete(data=data, state=self._state) + if target_event: + invite.set_scheduled_event(target_event) + return invite + + async def invites(self) -> list[Invite]: + """|coro| + + Returns a list of all active instant invites from this channel. + + You must have :attr:`~discord.Permissions.manage_channels` to get this information. + + Returns + ------- + List[:class:`~discord.Invite`] + The list of invites that are currently active. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to get the information. + ~discord.HTTPException + An error occurred while fetching the information. + """ + + state = self._state + data = await state.http.invites_from_channel(self.id) + guild = self.guild + return [ + Invite(state=state, data=invite, channel=self, guild=guild) + for invite in data + ] + + +class Messageable: + """An ABC that details the common operations on a model that can send messages. + + The following implement this ABC: + + - :class:`~discord.TextChannel` + - :class:`~discord.DMChannel` + - :class:`~discord.GroupChannel` + - :class:`~discord.User` + - :class:`~discord.Member` + - :class:`~discord.ext.commands.Context` + - :class:`~discord.Thread` + - :class:`~discord.ApplicationContext` + """ + + __slots__ = () + _state: ConnectionState + + async def _get_channel(self) -> MessageableChannel: + raise NotImplementedError + + @overload + async def send( + self, + content: str | None = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[GuildSticker | StickerItem] = ..., + delete_after: float = ..., + nonce: str | int = ..., + allowed_mentions: AllowedMentions = ..., + reference: Message | MessageReference | PartialMessage = ..., + mention_author: bool = ..., + view: View = ..., + suppress: bool = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: str | None = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: list[File] = ..., + stickers: Sequence[GuildSticker | StickerItem] = ..., + delete_after: float = ..., + nonce: str | int = ..., + allowed_mentions: AllowedMentions = ..., + reference: Message | MessageReference | PartialMessage = ..., + mention_author: bool = ..., + view: View = ..., + suppress: bool = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: str | None = ..., + *, + tts: bool = ..., + embeds: list[Embed] = ..., + file: File = ..., + stickers: Sequence[GuildSticker | StickerItem] = ..., + delete_after: float = ..., + nonce: str | int = ..., + allowed_mentions: AllowedMentions = ..., + reference: Message | MessageReference | PartialMessage = ..., + mention_author: bool = ..., + view: View = ..., + suppress: bool = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: str | None = ..., + *, + tts: bool = ..., + embeds: list[Embed] = ..., + files: list[File] = ..., + stickers: Sequence[GuildSticker | StickerItem] = ..., + delete_after: float = ..., + nonce: str | int = ..., + allowed_mentions: AllowedMentions = ..., + reference: Message | MessageReference | PartialMessage = ..., + mention_author: bool = ..., + view: View = ..., + suppress: bool = ..., + ) -> Message: + ... + + async def send( + self, + content=None, + *, + tts=None, + embed=None, + embeds=None, + file=None, + files=None, + stickers=None, + delete_after=None, + nonce=None, + allowed_mentions=None, + reference=None, + mention_author=None, + view=None, + suppress=None, + ): + """|coro| + + Sends a message to the destination with the content given. + + The content must be a type that can convert to a string through ``str(content)``. + If the content is set to ``None`` (the default), then the ``embed`` parameter must + be provided. + + To upload a single file, the ``file`` parameter should be used with a + single :class:`~discord.File` object. To upload multiple files, the ``files`` + parameter should be used with a :class:`list` of :class:`~discord.File` objects. + **Specifying both parameters will lead to an exception**. + + To upload a single embed, the ``embed`` parameter should be used with a + single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds`` + parameter should be used with a :class:`list` of :class:`~discord.Embed` objects. + **Specifying both parameters will lead to an exception**. + + Parameters + ---------- + content: Optional[:class:`str`] + The content of the message to send. + tts: :class:`bool` + Indicates if the message should be sent using text-to-speech. + embed: :class:`~discord.Embed` + The rich embed for the content. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + nonce: :class:`int` + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + delete_after: :class:`float` + If provided, the number of seconds to wait in the background + before deleting the message we just sent. If the deletion fails, + then it is silently ignored. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + + .. versionadded:: 1.4 + + reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`] + A reference to the :class:`~discord.Message` to which you are replying, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control + whether this mentions the author of the referenced message using the + :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions`` or by + setting ``mention_author``. + + .. versionadded:: 1.6 + + mention_author: Optional[:class:`bool`] + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + + .. versionadded:: 1.6 + view: :class:`discord.ui.View` + A Discord UI View to add to the message. + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + + .. versionadded:: 2.0 + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. + + .. versionadded:: 2.0 + suppress: :class:`bool` + Whether to suppress embeds for the message. + + Returns + ------- + :class:`~discord.Message` + The message that was sent. + + Raises + ------ + ~discord.HTTPException + Sending the message failed. + ~discord.Forbidden + You do not have the proper permissions to send the message. + ~discord.InvalidArgument + The ``files`` list is not of the appropriate size, + you specified both ``file`` and ``files``, + or you specified both ``embed`` and ``embeds``, + or the ``reference`` object is not a :class:`~discord.Message`, + :class:`~discord.MessageReference` or :class:`~discord.PartialMessage`. + """ + + channel = await self._get_channel() + state = self._state + content = str(content) if content is not None else None + + if embed is not None and embeds is not None: + raise InvalidArgument( + "cannot pass both embed and embeds parameter to send()" + ) + + if embed is not None: + embed = embed.to_dict() + + elif embeds is not None: + if len(embeds) > 10: + raise InvalidArgument( + "embeds parameter must be a list of up to 10 elements" + ) + embeds = [embed.to_dict() for embed in embeds] + + flags = MessageFlags.suppress_embeds if suppress else MessageFlags.DEFAULT_VALUE + + if stickers is not None: + stickers = [sticker.id for sticker in stickers] + + if allowed_mentions is None: + allowed_mentions = ( + state.allowed_mentions and state.allowed_mentions.to_dict() + ) + elif state.allowed_mentions is not None: + allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + else: + allowed_mentions = allowed_mentions.to_dict() + + if mention_author is not None: + allowed_mentions = allowed_mentions or AllowedMentions().to_dict() + allowed_mentions["replied_user"] = bool(mention_author) + + if reference is not None: + try: + reference = reference.to_message_reference_dict() + except AttributeError: + raise InvalidArgument( + "reference parameter must be Message, MessageReference, or PartialMessage" + ) from None + + if view: + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument( + f"view parameter must be View not {view.__class__!r}" + ) + + components = view.to_components() + else: + components = None + + if file is not None and files is not None: + raise InvalidArgument("cannot pass both file and files parameter to send()") + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + + try: + data = await state.http.send_files( + channel.id, + files=[file], + allowed_mentions=allowed_mentions, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + message_reference=reference, + stickers=stickers, + components=components, + flags=flags, + ) + finally: + file.close() + + elif files is not None: + if len(files) > 10: + raise InvalidArgument( + "files parameter must be a list of up to 10 elements" + ) + elif not all(isinstance(file, File) for file in files): + raise InvalidArgument("files parameter must be a list of File") + + try: + data = await state.http.send_files( + channel.id, + files=files, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference, + stickers=stickers, + components=components, + flags=flags, + ) + finally: + for f in files: + f.close() + else: + data = await state.http.send_message( + channel.id, + content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference, + stickers=stickers, + components=components, + flags=flags, + ) + + ret = state.create_message(channel=channel, data=data) + if view: + state.store_view(view, ret.id) + view.message = ret + + if delete_after is not None: + await ret.delete(delay=delete_after) + return ret + + async def trigger_typing(self) -> None: + """|coro| + + Triggers a *typing* indicator to the destination. + + *Typing* indicator will go away after 10 seconds, or after a message is sent. + """ + + channel = await self._get_channel() + await self._state.http.send_typing(channel.id) + + def typing(self) -> Typing: + """Returns a context manager that allows you to type for an indefinite period of time. + + This is useful for denoting long computations in your bot. + + .. note:: + + This is both a regular context manager and an async context manager. + This means that both ``with`` and ``async with`` work with this. + + Example Usage: :: + + async with channel.typing(): + # simulate something heavy + await asyncio.sleep(10) + + await channel.send('done!') + """ + return Typing(self) + + async def fetch_message(self, id: int, /) -> Message: + """|coro| + + Retrieves a single :class:`~discord.Message` from the destination. + + Parameters + ---------- + id: :class:`int` + The message ID to look for. + + Returns + ------- + :class:`~discord.Message` + The message asked for. + + Raises + ------ + ~discord.NotFound + The specified message was not found. + ~discord.Forbidden + You do not have the permissions required to get a message. + ~discord.HTTPException + Retrieving the message failed. + """ + + channel = await self._get_channel() + data = await self._state.http.get_message(channel.id, id) + return self._state.create_message(channel=channel, data=data) + + async def pins(self) -> list[Message]: + """|coro| + + Retrieves all messages that are currently pinned in the channel. + + .. note:: + + Due to a limitation with the Discord API, the :class:`.Message` + objects returned by this method do not contain complete + :attr:`.Message.reactions` data. + + Returns + ------- + List[:class:`~discord.Message`] + The messages that are currently pinned. + + Raises + ------ + ~discord.HTTPException + Retrieving the pinned messages failed. + """ + + channel = await self._get_channel() + state = self._state + data = await state.http.pins_from(channel.id) + return [state.create_message(channel=channel, data=m) for m in data] + + def can_send(self, *objects) -> bool: + """Returns a :class:`bool` indicating whether you have the permissions to send the object(s). + + Returns + ------- + :class:`bool` + Indicates whether you have the permissions to send the object(s). + + Raises + ------ + TypeError + An invalid type has been passed. + """ + mapping = { + "Message": "send_messages", + "Embed": "embed_links", + "File": "attach_files", + "Emoji": "use_external_emojis", + "GuildSticker": "use_external_stickers", + } + # Can't use channel = await self._get_channel() since its async + if hasattr(self, "permissions_for"): + channel = self + elif hasattr(self, "channel") and type(self.channel).__name__ not in ( + "DMChannel", + "GroupChannel", + ): + channel = self.channel + else: + return True # Permissions don't exist for User DMs + + objects = (None,) + objects # Makes sure we check for send_messages first + + for obj in objects: + try: + if obj is None: + permission = mapping["Message"] + else: + permission = ( + mapping.get(type(obj).__name__) or mapping[obj.__name__] + ) + + if type(obj).__name__ == "Emoji": + if ( + obj._to_partial().is_unicode_emoji + or obj.guild_id == channel.guild.id + ): + continue + elif type(obj).__name__ == "GuildSticker": + if obj.guild_id == channel.guild.id: + continue + + except (KeyError, AttributeError): + raise TypeError(f"The object {obj} is of an invalid type.") + + if not getattr(channel.permissions_for(channel.guild.me), permission): + return False + + return True + + def history( + self, + *, + limit: int | None = 100, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = None, + ) -> HistoryIterator: + """Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history. + + You must have :attr:`~discord.Permissions.read_message_history` permissions to use this. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of messages to retrieve. + If ``None``, retrieves every message in the channel. Note, however, + that this would make it a slow operation. + before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve messages before this date or message. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve messages after this date or message. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + around: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve messages around this date or message. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + When using this argument, the maximum limit is 101. Note that if the limit is an + even number, then this will return at most limit + 1 messages. + oldest_first: Optional[:class:`bool`] + If set to ``True``, return messages in oldest->newest order. Defaults to ``True`` if + ``after`` is specified, otherwise ``False``. + + Yields + ------ + :class:`~discord.Message` + The message with the message data parsed. + + Raises + ------ + ~discord.Forbidden + You do not have permissions to get channel message history. + ~discord.HTTPException + The request to get message history failed. + + Examples + -------- + + Usage :: + + counter = 0 + async for message in channel.history(limit=200): + if message.author == client.user: + counter += 1 + + Flattening into a list: :: + + messages = await channel.history(limit=123).flatten() + # messages is now a list of Message... + + All parameters are optional. + """ + return HistoryIterator( + self, + limit=limit, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + ) + + +class Connectable(Protocol): + """An ABC that details the common operations on a channel that can + connect to a voice server. + + The following implement this ABC: + + - :class:`~discord.VoiceChannel` + - :class:`~discord.StageChannel` + + Note + ---- + This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` + checks. + """ + + __slots__ = () + _state: ConnectionState + + def _get_voice_client_key(self) -> tuple[int, str]: + raise NotImplementedError + + def _get_voice_state_pair(self) -> tuple[int, int]: + raise NotImplementedError + + async def connect( + self, + *, + timeout: float = 60.0, + reconnect: bool = True, + cls: Callable[[Client, Connectable], T] = VoiceClient, + ) -> T: + """|coro| + + Connects to voice and creates a :class:`VoiceClient` to establish + your connection to the voice server. + + This requires :attr:`Intents.voice_states`. + + Parameters + ---------- + timeout: :class:`float` + The timeout in seconds to wait for the voice endpoint. + reconnect: :class:`bool` + Whether the bot should automatically attempt + a reconnect if a part of the handshake fails + or the gateway goes down. + cls: Type[:class:`VoiceProtocol`] + A type that subclasses :class:`~discord.VoiceProtocol` to connect with. + Defaults to :class:`~discord.VoiceClient`. + + Returns + ------- + :class:`~discord.VoiceProtocol` + A voice client that is fully connected to the voice server. + + Raises + ------ + asyncio.TimeoutError + Could not connect to the voice channel in time. + ~discord.ClientException + You are already connected to a voice channel. + ~discord.opus.OpusNotLoaded + The opus library has not been loaded. + """ + + key_id, _ = self._get_voice_client_key() + state = self._state + + if state._get_voice_client(key_id): + raise ClientException("Already connected to a voice channel.") + + client = state._get_client() + voice = cls(client, self) + + if not isinstance(voice, VoiceProtocol): + raise TypeError("Type must meet VoiceProtocol abstract base class.") + + state._add_voice_client(key_id, voice) + + try: + await voice.connect(timeout=timeout, reconnect=reconnect) + except asyncio.TimeoutError: + try: + await voice.disconnect(force=True) + except Exception: + # we don't care if disconnect failed because connection failed + pass + raise # re-raise + + return voice + + +class Mentionable: + # TODO: documentation, methods if needed + pass diff --git a/discord/activity.py b/discord/activity.py new file mode 100644 index 0000000..e6499c5 --- /dev/null +++ b/discord/activity.py @@ -0,0 +1,886 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Union, overload + +from .asset import Asset +from .colour import Colour +from .enums import ActivityType, try_enum +from .partial_emoji import PartialEmoji +from .utils import _get_as_snowflake + +__all__ = ( + "BaseActivity", + "Activity", + "Streaming", + "Game", + "Spotify", + "CustomActivity", +) + +"""If you're curious, this is the current schema for an activity. + +It's fairly long so I will document it here: + +All keys are optional. + +state: str (max: 128), +details: str (max: 128) +timestamps: dict + start: int (min: 1) + end: int (min: 1) +assets: dict + large_image: str (max: 32) + large_text: str (max: 128) + small_image: str (max: 32) + small_text: str (max: 128) +party: dict + id: str (max: 128), + size: List[int] (max-length: 2) + elem: int (min: 1) +secrets: dict + match: str (max: 128) + join: str (max: 128) + spectate: str (max: 128) +instance: bool +application_id: str +name: str (max: 128) +url: str +type: int +sync_id: str +session_id: str +flags: int +buttons: list[dict] + label: str (max: 32) + url: str (max: 512) +NOTE: Bots cannot access a user's activity button URLs. When received through the +gateway, the type of the buttons field will be list[str]. + +There are also activity flags which are mostly uninteresting for the library atm. + +t.ActivityFlags = { + INSTANCE: 1, + JOIN: 2, + SPECTATE: 4, + JOIN_REQUEST: 8, + SYNC: 16, + PLAY: 32 +} +""" + +if TYPE_CHECKING: + from .types.activity import Activity as ActivityPayload + from .types.activity import ActivityAssets, ActivityParty, ActivityTimestamps + + +class BaseActivity: + """The base activity that all user-settable activities inherit from. + A user-settable activity is one that can be used in :meth:`Client.change_presence`. + + The following types currently count as user-settable: + + - :class:`Activity` + - :class:`Game` + - :class:`Streaming` + - :class:`CustomActivity` + + Note that although these types are considered user-settable by the library, + Discord typically ignores certain combinations of activity depending on + what is currently set. This behaviour may change in the future so there are + no guarantees on whether Discord will actually let you set these types. + + .. versionadded:: 1.3 + """ + + __slots__ = ("_created_at",) + + def __init__(self, **kwargs): + self._created_at: float | None = kwargs.pop("created_at", None) + + @property + def created_at(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC. + + .. versionadded:: 1.3 + """ + if self._created_at is not None: + return datetime.datetime.fromtimestamp( + self._created_at / 1000, tz=datetime.timezone.utc + ) + + def to_dict(self) -> ActivityPayload: + raise NotImplementedError + + +class Activity(BaseActivity): + """Represents an activity in Discord. + + This could be an activity such as streaming, playing, listening + or watching. + + For memory optimisation purposes, some activities are offered in slimmed + down versions: + + - :class:`Game` + - :class:`Streaming` + + Attributes + ---------- + application_id: Optional[:class:`int`] + The application ID of the game. + name: Optional[:class:`str`] + The name of the activity. + url: Optional[:class:`str`] + A stream URL that the activity could be doing. + type: :class:`ActivityType` + The type of activity currently being done. + state: Optional[:class:`str`] + The user's current state. For example, "In Game". + details: Optional[:class:`str`] + The detail of the user's current activity. + timestamps: Dict[:class:`str`, :class:`int`] + A dictionary of timestamps. It contains the following optional keys: + + - ``start``: Corresponds to when the user started doing the + activity in milliseconds since Unix epoch. + - ``end``: Corresponds to when the user will finish doing the + activity in milliseconds since Unix epoch. + + assets: Dict[:class:`str`, :class:`str`] + A dictionary representing the images and their hover text of an activity. + It contains the following optional keys: + + - ``large_image``: A string representing the ID for the large image asset. + - ``large_text``: A string representing the text when hovering over the large image asset. + - ``small_image``: A string representing the ID for the small image asset. + - ``small_text``: A string representing the text when hovering over the small image asset. + + party: Dict[:class:`str`, Union[:class:`str`, List[:class:`int`]]] + A dictionary representing the activity party. It contains the following optional keys: + + - ``id``: A string representing the party ID. + - ``size``: A list of up to two integer elements denoting (current_size, maximum_size). + buttons: Union[List[Dict[:class:`str`, :class:`str`]], List[:class:`str`]] + A list of dictionaries representing custom buttons shown in a rich presence. + Each dictionary contains the following keys: + + - ``label``: A string representing the text shown on the button. + - ``url``: A string representing the URL opened upon clicking the button. + + .. note:: + + Bots cannot access a user's activity button URLs. Therefore, the type of this attribute + will be List[:class:`str`] when received through the gateway. + + .. versionadded:: 2.0 + + emoji: Optional[:class:`PartialEmoji`] + The emoji that belongs to this activity. + """ + + __slots__ = ( + "state", + "details", + "_created_at", + "timestamps", + "assets", + "party", + "flags", + "sync_id", + "session_id", + "type", + "name", + "url", + "application_id", + "emoji", + "buttons", + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.state: str | None = kwargs.pop("state", None) + self.details: str | None = kwargs.pop("details", None) + self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {}) + self.assets: ActivityAssets = kwargs.pop("assets", {}) + self.party: ActivityParty = kwargs.pop("party", {}) + self.application_id: int | None = _get_as_snowflake(kwargs, "application_id") + self.name: str | None = kwargs.pop("name", None) + self.url: str | None = kwargs.pop("url", None) + self.flags: int = kwargs.pop("flags", 0) + self.sync_id: str | None = kwargs.pop("sync_id", None) + self.session_id: str | None = kwargs.pop("session_id", None) + self.buttons: list[str] = kwargs.pop("buttons", []) + + activity_type = kwargs.pop("type", -1) + self.type: ActivityType = ( + activity_type + if isinstance(activity_type, ActivityType) + else try_enum(ActivityType, activity_type) + ) + + emoji = kwargs.pop("emoji", None) + self.emoji: PartialEmoji | None = ( + PartialEmoji.from_dict(emoji) if emoji is not None else None + ) + + def __repr__(self) -> str: + attrs = ( + ("type", self.type), + ("name", self.name), + ("url", self.url), + ("details", self.details), + ("application_id", self.application_id), + ("session_id", self.session_id), + ("emoji", self.emoji), + ) + inner = " ".join("%s=%r" % t for t in attrs) + return f"" + + def to_dict(self) -> dict[str, Any]: + ret: dict[str, Any] = {} + for attr in self.__slots__: + value = getattr(self, attr, None) + if value is None: + continue + + if isinstance(value, dict) and len(value) == 0: + continue + + ret[attr] = value + ret["type"] = int(self.type) + if self.emoji: + ret["emoji"] = self.emoji.to_dict() + return ret + + @property + def start(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" + try: + timestamp = self.timestamps["start"] / 1000 + except KeyError: + return None + else: + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + + @property + def end(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" + try: + timestamp = self.timestamps["end"] / 1000 + except KeyError: + return None + else: + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + + @property + def large_image_url(self) -> str | None: + """Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable.""" + if self.application_id is None: + return None + + try: + large_image = self.assets["large_image"] + except KeyError: + return None + else: + return f"{Asset.BASE}/app-assets/{self.application_id}/{large_image}.png" + + @property + def small_image_url(self) -> str | None: + """Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable.""" + if self.application_id is None: + return None + + try: + small_image = self.assets["small_image"] + except KeyError: + return None + else: + return f"{Asset.BASE}/app-assets/{self.application_id}/{small_image}.png" + + @property + def large_image_text(self) -> str | None: + """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" + return self.assets.get("large_text", None) + + @property + def small_image_text(self) -> str | None: + """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" + return self.assets.get("small_text", None) + + +class Game(BaseActivity): + """A slimmed down version of :class:`Activity` that represents a Discord game. + + This is typically displayed via **Playing** on the official Discord client. + + .. container:: operations + + .. describe:: x == y + + Checks if two games are equal. + + .. describe:: x != y + + Checks if two games are not equal. + + .. describe:: hash(x) + + Returns the game's hash. + + .. describe:: str(x) + + Returns the game's name. + + Parameters + ---------- + name: :class:`str` + The game's name. + + Attributes + ---------- + name: :class:`str` + The game's name. + """ + + __slots__ = ("name", "_end", "_start") + + def __init__(self, name: str, **extra): + super().__init__(**extra) + self.name: str = name + + try: + timestamps: ActivityTimestamps = extra["timestamps"] + except KeyError: + self._start = 0 + self._end = 0 + else: + self._start = timestamps.get("start", 0) + self._end = timestamps.get("end", 0) + + @property + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.playing`. + """ + return ActivityType.playing + + @property + def start(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" + if self._start: + return datetime.datetime.fromtimestamp( + self._start / 1000, tz=datetime.timezone.utc + ) + return None + + @property + def end(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" + if self._end: + return datetime.datetime.fromtimestamp( + self._end / 1000, tz=datetime.timezone.utc + ) + return None + + def __str__(self) -> str: + return str(self.name) + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> dict[str, Any]: + timestamps: dict[str, Any] = {} + if self._start: + timestamps["start"] = self._start + + if self._end: + timestamps["end"] = self._end + + return { + "type": ActivityType.playing.value, + "name": str(self.name), + "timestamps": timestamps, + } + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Game) and other.name == self.name + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.name) + + +class Streaming(BaseActivity): + """A slimmed down version of :class:`Activity` that represents a Discord streaming status. + + This is typically displayed via **Streaming** on the official Discord client. + + .. container:: operations + + .. describe:: x == y + + Checks if two streams are equal. + + .. describe:: x != y + + Checks if two streams are not equal. + + .. describe:: hash(x) + + Returns the stream's hash. + + .. describe:: str(x) + + Returns the stream's name. + + Attributes + ---------- + platform: Optional[:class:`str`] + Where the user is streaming from (ie. YouTube, Twitch). + + .. versionadded:: 1.3 + + name: Optional[:class:`str`] + The stream's name. + details: Optional[:class:`str`] + An alias for :attr:`name` + game: Optional[:class:`str`] + The game being streamed. + + .. versionadded:: 1.3 + + url: :class:`str` + The stream's URL. + assets: Dict[:class:`str`, :class:`str`] + A dictionary comprised of similar keys than those in :attr:`Activity.assets`. + """ + + __slots__ = ("platform", "name", "game", "url", "details", "assets") + + def __init__(self, *, name: str | None, url: str, **extra: Any): + super().__init__(**extra) + self.platform: str | None = name + self.name: str | None = extra.pop("details", name) + self.game: str | None = extra.pop("state", None) + self.url: str = url + self.details: str | None = extra.pop("details", self.name) # compatibility + self.assets: ActivityAssets = extra.pop("assets", {}) + + @property + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.streaming`. + """ + return ActivityType.streaming + + def __str__(self) -> str: + return str(self.name) + + def __repr__(self) -> str: + return f"" + + @property + def twitch_name(self): + """Optional[:class:`str`]: If provided, the twitch name of the user streaming. + + This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` + dictionary if it starts with ``twitch:``. Typically this is set by the Discord client. + """ + + try: + name = self.assets["large_image"] + except KeyError: + return None + else: + return name[7:] if name[:7] == "twitch:" else None + + def to_dict(self) -> dict[str, Any]: + ret: dict[str, Any] = { + "type": ActivityType.streaming.value, + "name": str(self.name), + "url": str(self.url), + "assets": self.assets, + } + if self.details: + ret["details"] = self.details + return ret + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, Streaming) + and other.name == self.name + and other.url == self.url + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.name) + + +class Spotify: + """Represents a Spotify listening activity from Discord. This is a special case of + :class:`Activity` that makes it easier to work with the Spotify integration. + + .. container:: operations + + .. describe:: x == y + + Checks if two activities are equal. + + .. describe:: x != y + + Checks if two activities are not equal. + + .. describe:: hash(x) + + Returns the activity's hash. + + .. describe:: str(x) + + Returns the string 'Spotify'. + """ + + __slots__ = ( + "_state", + "_details", + "_timestamps", + "_assets", + "_party", + "_sync_id", + "_session_id", + "_created_at", + ) + + def __init__(self, **data): + self._state: str = data.pop("state", "") + self._details: str = data.pop("details", "") + self._timestamps: dict[str, int] = data.pop("timestamps", {}) + self._assets: ActivityAssets = data.pop("assets", {}) + self._party: ActivityParty = data.pop("party", {}) + self._sync_id: str = data.pop("sync_id") + self._session_id: str = data.pop("session_id") + self._created_at: float | None = data.pop("created_at", None) + + @property + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.listening`. + """ + return ActivityType.listening + + @property + def created_at(self) -> datetime.datetime | None: + """Optional[:class:`datetime.datetime`]: When the user started listening in UTC. + + .. versionadded:: 1.3 + """ + if self._created_at is not None: + return datetime.datetime.fromtimestamp( + self._created_at / 1000, tz=datetime.timezone.utc + ) + + @property + def colour(self) -> Colour: + """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. + + There is an alias for this named :attr:`color` + """ + return Colour(0x1DB954) + + @property + def color(self) -> Colour: + """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. + + There is an alias for this named :attr:`colour` + """ + return self.colour + + def to_dict(self) -> dict[str, Any]: + return { + "flags": 48, # SYNC | PLAY + "name": "Spotify", + "assets": self._assets, + "party": self._party, + "sync_id": self._sync_id, + "session_id": self._session_id, + "timestamps": self._timestamps, + "details": self._details, + "state": self._state, + } + + @property + def name(self) -> str: + """:class:`str`: The activity's name. This will always return "Spotify".""" + return "Spotify" + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, Spotify) + and other._session_id == self._session_id + and other._sync_id == self._sync_id + and other.start == self.start + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self._session_id) + + def __str__(self) -> str: + return "Spotify" + + def __repr__(self) -> str: + return f"" + + @property + def title(self) -> str: + """:class:`str`: The title of the song being played.""" + return self._details + + @property + def artists(self) -> list[str]: + """List[:class:`str`]: The artists of the song being played.""" + return self._state.split("; ") + + @property + def artist(self) -> str: + """:class:`str`: The artist of the song being played. + + This does not attempt to split the artist information into + multiple artists. Useful if there's only a single artist. + """ + return self._state + + @property + def album(self) -> str: + """:class:`str`: The album that the song being played belongs to.""" + return self._assets.get("large_text", "") + + @property + def album_cover_url(self) -> str: + """:class:`str`: The album cover image URL from Spotify's CDN.""" + large_image = self._assets.get("large_image", "") + if large_image[:8] != "spotify:": + return "" + album_image_id = large_image[8:] + return f"https://i.scdn.co/image/{album_image_id}" + + @property + def track_id(self) -> str: + """:class:`str`: The track ID used by Spotify to identify this song.""" + return self._sync_id + + @property + def track_url(self) -> str: + """:class:`str`: The track URL to listen on Spotify. + + .. versionadded:: 2.0 + """ + return f"https://open.spotify.com/track/{self.track_id}" + + @property + def start(self) -> datetime.datetime: + """:class:`datetime.datetime`: When the user started playing this song in UTC.""" + return datetime.datetime.fromtimestamp( + self._timestamps["start"] / 1000, tz=datetime.timezone.utc + ) + + @property + def end(self) -> datetime.datetime: + """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" + return datetime.datetime.fromtimestamp( + self._timestamps["end"] / 1000, tz=datetime.timezone.utc + ) + + @property + def duration(self) -> datetime.timedelta: + """:class:`datetime.timedelta`: The duration of the song being played.""" + return self.end - self.start + + @property + def party_id(self) -> str: + """:class:`str`: The party ID of the listening party.""" + return self._party.get("id", "") + + +class CustomActivity(BaseActivity): + """Represents a Custom activity from Discord. + + .. container:: operations + + .. describe:: x == y + + Checks if two activities are equal. + + .. describe:: x != y + + Checks if two activities are not equal. + + .. describe:: hash(x) + + Returns the activity's hash. + + .. describe:: str(x) + + Returns the custom status text. + + .. versionadded:: 1.3 + + Attributes + ---------- + name: Optional[:class:`str`] + The custom activity's name. + emoji: Optional[:class:`PartialEmoji`] + The emoji to pass to the activity, if any. + """ + + __slots__ = ("name", "emoji", "state") + + def __init__( + self, name: str | None, *, emoji: PartialEmoji | None = None, **extra: Any + ): + super().__init__(**extra) + self.name: str | None = name + self.state: str | None = extra.pop("state", None) + if self.name == "Custom Status": + self.name = self.state + + self.emoji: PartialEmoji | None + if emoji is None: + self.emoji = emoji + elif isinstance(emoji, dict): + self.emoji = PartialEmoji.from_dict(emoji) + elif isinstance(emoji, str): + self.emoji = PartialEmoji(name=emoji) + elif isinstance(emoji, PartialEmoji): + self.emoji = emoji + else: + raise TypeError( + f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead." + ) + + @property + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.custom`. + """ + return ActivityType.custom + + def to_dict(self) -> dict[str, Any]: + if self.name == self.state: + o = { + "type": ActivityType.custom.value, + "state": self.name, + "name": "Custom Status", + } + else: + o = { + "type": ActivityType.custom.value, + "name": self.name, + } + + if self.emoji: + o["emoji"] = self.emoji.to_dict() + return o + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, CustomActivity) + and other.name == self.name + and other.emoji == self.emoji + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash((self.name, str(self.emoji))) + + def __str__(self) -> str: + if not self.emoji: + return str(self.name) + if self.name: + return f"{self.emoji} {self.name}" + return str(self.emoji) + + def __repr__(self) -> str: + return f"" + + +ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + + +@overload +def create_activity(data: ActivityPayload) -> ActivityTypes: + ... + + +@overload +def create_activity(data: None) -> None: + ... + + +def create_activity(data: ActivityPayload | None) -> ActivityTypes | None: + if not data: + return None + + game_type = try_enum(ActivityType, data.get("type", -1)) + if game_type is ActivityType.playing: + if "application_id" in data or "session_id" in data: + return Activity(**data) + return Game(**data) + elif game_type is ActivityType.custom: + try: + name = data.pop("name") + except KeyError: + return Activity(**data) + else: + # we removed the name key from data already + return CustomActivity(name=name, **data) # type: ignore + elif game_type is ActivityType.streaming: + if "url" in data: + # the url won't be None here + return Streaming(**data) # type: ignore + return Activity(**data) + elif ( + game_type is ActivityType.listening + and "sync_id" in data + and "session_id" in data + ): + return Spotify(**data) + return Activity(**data) diff --git a/discord/appinfo.py b/discord/appinfo.py new file mode 100644 index 0000000..9024b5e --- /dev/null +++ b/discord/appinfo.py @@ -0,0 +1,259 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from . import utils +from .asset import Asset + +if TYPE_CHECKING: + from .guild import Guild + from .state import ConnectionState + from .types.appinfo import AppInfo as AppInfoPayload + from .types.appinfo import PartialAppInfo as PartialAppInfoPayload + from .types.appinfo import Team as TeamPayload + from .user import User + +__all__ = ( + "AppInfo", + "PartialAppInfo", +) + + +class AppInfo: + """Represents the application info for the bot provided by Discord. + + Attributes + ---------- + id: :class:`int` + The application ID. + name: :class:`str` + The application name. + owner: :class:`User` + The application owner. + team: Optional[:class:`Team`] + The application's team. + + .. versionadded:: 1.3 + + description: :class:`str` + The application description. + bot_public: :class:`bool` + Whether the bot can be invited by anyone or if it is locked + to the application owner. + bot_require_code_grant: :class:`bool` + Whether the bot requires the completion of the full OAuth2 code + grant flow to join. + rpc_origins: Optional[List[:class:`str`]] + A list of RPC origin URLs, if RPC is enabled. + summary: :class:`str` + If this application is a game sold on Discord, + this field will be the summary field for the store page of its primary SKU. + + .. versionadded:: 1.3 + + verify_key: :class:`str` + The hex encoded key for verification in interactions and the + GameSDK's `GetTicket `_. + + .. versionadded:: 1.3 + + guild_id: Optional[:class:`int`] + If this application is a game sold on Discord, + this field will be the guild to which it has been linked to. + + .. versionadded:: 1.3 + + primary_sku_id: Optional[:class:`int`] + If this application is a game sold on Discord, + this field will be the id of the "Game SKU" that is created, + if it exists. + + .. versionadded:: 1.3 + + slug: Optional[:class:`str`] + If this application is a game sold on Discord, + this field will be the URL slug that links to the store page. + + .. versionadded:: 1.3 + + terms_of_service_url: Optional[:class:`str`] + The application's terms of service URL, if set. + + .. versionadded:: 2.0 + + privacy_policy_url: Optional[:class:`str`] + The application's privacy policy URL, if set. + + .. versionadded:: 2.0 + """ + + __slots__ = ( + "_state", + "description", + "id", + "name", + "rpc_origins", + "bot_public", + "bot_require_code_grant", + "owner", + "_icon", + "summary", + "verify_key", + "team", + "guild_id", + "primary_sku_id", + "slug", + "_cover_image", + "terms_of_service_url", + "privacy_policy_url", + ) + + def __init__(self, state: ConnectionState, data: AppInfoPayload): + from .team import Team + + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.description: str = data["description"] + self._icon: str | None = data["icon"] + self.rpc_origins: list[str] = data["rpc_origins"] + self.bot_public: bool = data["bot_public"] + self.bot_require_code_grant: bool = data["bot_require_code_grant"] + self.owner: User = state.create_user(data["owner"]) + + team: TeamPayload | None = data.get("team") + self.team: Team | None = Team(state, team) if team else None + + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] + + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + + self.primary_sku_id: int | None = utils._get_as_snowflake( + data, "primary_sku_id" + ) + self.slug: str | None = data.get("slug") + self._cover_image: str | None = data.get("cover_image") + self.terms_of_service_url: str | None = data.get("terms_of_service_url") + self.privacy_policy_url: str | None = data.get("privacy_policy_url") + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} id={self.id} name={self.name!r} " + f"description={self.description!r} public={self.bot_public} " + f"owner={self.owner!r}>" + ) + + @property + def icon(self) -> Asset | None: + """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path="app") + + @property + def cover_image(self) -> Asset | None: + """Optional[:class:`.Asset`]: Retrieves the cover image on a store embed, if any. + + This is only available if the application is a game sold on Discord. + """ + if self._cover_image is None: + return None + return Asset._from_cover_image(self._state, self.id, self._cover_image) + + @property + def guild(self) -> Guild | None: + """Optional[:class:`Guild`]: If this application is a game sold on Discord, + this field will be the guild to which it has been linked. + + .. versionadded:: 1.3 + """ + return self._state._get_guild(self.guild_id) + + +class PartialAppInfo: + """Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite` + + .. versionadded:: 2.0 + + Attributes + ---------- + id: :class:`int` + The application ID. + name: :class:`str` + The application name. + description: :class:`str` + The application description. + rpc_origins: Optional[List[:class:`str`]] + A list of RPC origin URLs, if RPC is enabled. + summary: :class:`str` + If this application is a game sold on Discord, + this field will be the summary field for the store page of its primary SKU. + verify_key: :class:`str` + The hex encoded key for verification in interactions and the + GameSDK's `GetTicket `_. + terms_of_service_url: Optional[:class:`str`] + The application's terms of service URL, if set. + privacy_policy_url: Optional[:class:`str`] + The application's privacy policy URL, if set. + """ + + __slots__ = ( + "_state", + "id", + "name", + "description", + "rpc_origins", + "summary", + "verify_key", + "terms_of_service_url", + "privacy_policy_url", + "_icon", + ) + + def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self.name: str = data["name"] + self._icon: str | None = data.get("icon") + self.description: str = data["description"] + self.rpc_origins: list[str] | None = data.get("rpc_origins") + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] + self.terms_of_service_url: str | None = data.get("terms_of_service_url") + self.privacy_policy_url: str | None = data.get("privacy_policy_url") + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>" + + @property + def icon(self) -> Asset | None: + """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path="app") diff --git a/discord/asset.py b/discord/asset.py new file mode 100644 index 0000000..9cfcb92 --- /dev/null +++ b/discord/asset.py @@ -0,0 +1,445 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import io +import os +from typing import TYPE_CHECKING, Any, Literal + +import yarl + +from . import utils +from .errors import DiscordException, InvalidArgument + +__all__ = ("Asset",) + +if TYPE_CHECKING: + ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] + ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] + +VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) +VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} + + +MISSING = utils.MISSING + + +class AssetMixin: + url: str + _state: Any | None + + async def read(self) -> bytes: + """|coro| + + Retrieves the content of this asset as a :class:`bytes` object. + + Returns + ------- + :class:`bytes` + The content of the asset. + + Raises + ------ + DiscordException + There was no internal connection state. + HTTPException + Downloading the asset failed. + NotFound + The asset was deleted. + """ + if self._state is None: + raise DiscordException("Invalid state (no ConnectionState provided)") + + return await self._state.http.get_from_cdn(self.url) + + async def save( + self, + fp: str | bytes | os.PathLike | io.BufferedIOBase, + *, + seek_begin: bool = True, + ) -> int: + """|coro| + + Saves this asset into a file-like object. + + Parameters + ---------- + fp: Union[:class:`io.BufferedIOBase`, :class:`os.PathLike`] + The file-like object to save this attachment to or the filename + to use. If a filename is passed then a file is created with that + filename and used instead. + seek_begin: :class:`bool` + Whether to seek to the beginning of the file after saving is + successfully done. + + Returns + ------- + :class:`int` + The number of bytes written. + + Raises + ------ + DiscordException + There was no internal connection state. + HTTPException + Downloading the asset failed. + NotFound + The asset was deleted. + """ + + data = await self.read() + if isinstance(fp, io.BufferedIOBase): + written = fp.write(data) + if seek_begin: + fp.seek(0) + return written + else: + with open(fp, "wb") as f: + return f.write(data) + + +class Asset(AssetMixin): + """Represents a CDN asset on Discord. + + .. container:: operations + + .. describe:: str(x) + + Returns the URL of the CDN asset. + + .. describe:: len(x) + + Returns the length of the CDN asset's URL. + + .. describe:: x == y + + Checks if the asset is equal to another asset. + + .. describe:: x != y + + Checks if the asset is not equal to another asset. + + .. describe:: hash(x) + + Returns the hash of the asset. + """ + + __slots__: tuple[str, ...] = ( + "_state", + "_url", + "_animated", + "_key", + ) + + BASE = "https://cdn.discordapp.com" + + def __init__(self, state, *, url: str, key: str, animated: bool = False): + self._state = state + self._url = url + self._animated = animated + self._key = key + + @classmethod + def _from_default_avatar(cls, state, index: int) -> Asset: + return cls( + state, + url=f"{cls.BASE}/embed/avatars/{index}.png", + key=str(index), + animated=False, + ) + + @classmethod + def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: + animated = avatar.startswith("a_") + format = "gif" if animated else "png" + return cls( + state, + url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024", + key=avatar, + animated=animated, + ) + + @classmethod + def _from_guild_avatar( + cls, state, guild_id: int, member_id: int, avatar: str + ) -> Asset: + animated = avatar.startswith("a_") + format = "gif" if animated else "png" + return cls( + state, + url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", + key=avatar, + animated=animated, + ) + + @classmethod + def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: + return cls( + state, + url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", + key=icon_hash, + animated=False, + ) + + @classmethod + def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: + return cls( + state, + url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", + key=cover_image_hash, + animated=False, + ) + + @classmethod + def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: + animated = False + format = "png" + if path == "banners": + animated = image.startswith("a_") + format = "gif" if animated else "png" + + return cls( + state, + url=f"{cls.BASE}/{path}/{guild_id}/{image}.{format}?size=1024", + key=image, + animated=animated, + ) + + @classmethod + def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: + animated = icon_hash.startswith("a_") + format = "gif" if animated else "png" + return cls( + state, + url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024", + key=icon_hash, + animated=animated, + ) + + @classmethod + def _from_sticker_banner(cls, state, banner: int) -> Asset: + return cls( + state, + url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", + key=str(banner), + animated=False, + ) + + @classmethod + def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: + animated = banner_hash.startswith("a_") + format = "gif" if animated else "png" + return cls( + state, + url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512", + key=banner_hash, + animated=animated, + ) + + @classmethod + def _from_scheduled_event_cover( + cls, state, event_id: int, cover_hash: str + ) -> Asset: + return cls( + state, + url=f"{cls.BASE}/guild-events/{event_id}/{cover_hash}.png", + key=cover_hash, + animated=False, + ) + + def __str__(self) -> str: + return self._url + + def __len__(self) -> int: + return len(self._url) + + def __repr__(self): + shorten = self._url.replace(self.BASE, "") + return f"" + + def __eq__(self, other): + return isinstance(other, Asset) and self._url == other._url + + def __hash__(self): + return hash(self._url) + + @property + def url(self) -> str: + """:class:`str`: Returns the underlying URL of the asset.""" + return self._url + + @property + def key(self) -> str: + """:class:`str`: Returns the identifying key of the asset.""" + return self._key + + def is_animated(self) -> bool: + """:class:`bool`: Returns whether the asset is animated.""" + return self._animated + + def replace( + self, + *, + size: int = MISSING, + format: ValidAssetFormatTypes = MISSING, + static_format: ValidStaticFormatTypes = MISSING, + ) -> Asset: + """Returns a new asset with the passed components replaced. + + Parameters + ---------- + size: :class:`int` + The new size of the asset. + format: :class:`str` + The new format to change it to. Must be either + 'webp', 'jpeg', 'jpg', 'png', or 'gif' if it's animated. + static_format: :class:`str` + The new format to change it to if the asset isn't animated. + Must be either 'webp', 'jpeg', 'jpg', or 'png'. + + Returns + ------- + :class:`Asset` + The newly updated asset. + + Raises + ------ + InvalidArgument + An invalid size or format was passed. + """ + url = yarl.URL(self._url) + path, _ = os.path.splitext(url.path) + + if format is not MISSING: + if self._animated: + if format not in VALID_ASSET_FORMATS: + raise InvalidArgument( + f"format must be one of {VALID_ASSET_FORMATS}" + ) + url = url.with_path(f"{path}.{format}") + elif static_format is MISSING: + if format not in VALID_STATIC_FORMATS: + raise InvalidArgument( + f"format must be one of {VALID_STATIC_FORMATS}" + ) + url = url.with_path(f"{path}.{format}") + + if static_format is not MISSING and not self._animated: + if static_format not in VALID_STATIC_FORMATS: + raise InvalidArgument( + f"static_format must be one of {VALID_STATIC_FORMATS}" + ) + url = url.with_path(f"{path}.{static_format}") + + if size is not MISSING: + if not utils.valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 4096") + url = url.with_query(size=size) + else: + url = url.with_query(url.raw_query_string) + + url = str(url) + return Asset(state=self._state, url=url, key=self._key, animated=self._animated) + + def with_size(self, size: int, /) -> Asset: + """Returns a new asset with the specified size. + + Parameters + ---------- + size: :class:`int` + The new size of the asset. + + Returns + ------- + :class:`Asset` + The new updated asset. + + Raises + ------ + InvalidArgument + The asset had an invalid size. + """ + if not utils.valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 4096") + + url = str(yarl.URL(self._url).with_query(size=size)) + return Asset(state=self._state, url=url, key=self._key, animated=self._animated) + + def with_format(self, format: ValidAssetFormatTypes, /) -> Asset: + """Returns a new asset with the specified format. + + Parameters + ---------- + format: :class:`str` + The new format of the asset. + + Returns + ------- + :class:`Asset` + The new updated asset. + + Raises + ------ + InvalidArgument + The asset has an invalid format. + """ + + if self._animated: + if format not in VALID_ASSET_FORMATS: + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") + elif format not in VALID_STATIC_FORMATS: + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") + + url = yarl.URL(self._url) + path, _ = os.path.splitext(url.path) + url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string)) + return Asset(state=self._state, url=url, key=self._key, animated=self._animated) + + def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: + """Returns a new asset with the specified static format. + + This only changes the format if the underlying asset is + not animated. Otherwise, the asset is not changed. + + Parameters + ---------- + format: :class:`str` + The new static format of the asset. + + Returns + ------- + :class:`Asset` + The new updated asset. + + Raises + ------ + InvalidArgument + The asset had an invalid format. + """ + + if self._animated: + return self + return self.with_format(format) diff --git a/discord/audit_logs.py b/discord/audit_logs.py new file mode 100644 index 0000000..892c7e8 --- /dev/null +++ b/discord/audit_logs.py @@ -0,0 +1,634 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generator, TypeVar + +from . import enums, utils +from .asset import Asset +from .colour import Colour +from .invite import Invite +from .mixins import Hashable +from .object import Object +from .permissions import PermissionOverwrite, Permissions + +__all__ = ( + "AuditLogDiff", + "AuditLogChanges", + "AuditLogEntry", +) + + +if TYPE_CHECKING: + import datetime + + from . import abc + from .emoji import Emoji + from .guild import Guild + from .member import Member + from .role import Role + from .scheduled_events import ScheduledEvent + from .stage_instance import StageInstance + from .state import ConnectionState + from .sticker import GuildSticker + from .threads import Thread + from .types.audit_log import AuditLogChange as AuditLogChangePayload + from .types.audit_log import AuditLogEntry as AuditLogEntryPayload + from .types.channel import PermissionOverwrite as PermissionOverwritePayload + from .types.role import Role as RolePayload + from .types.snowflake import Snowflake + from .user import User + + +def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions: + return Permissions(int(data)) + + +def _transform_color(entry: AuditLogEntry, data: int) -> Colour: + return Colour(data) + + +def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: + return int(data) + + +def _transform_channel( + entry: AuditLogEntry, data: Snowflake | None +) -> abc.GuildChannel | Object | None: + if data is None: + return None + return entry.guild.get_channel(int(data)) or Object(id=data) + + +def _transform_member_id( + entry: AuditLogEntry, data: Snowflake | None +) -> Member | User | None: + if data is None: + return None + return entry._get_member(int(data)) + + +def _transform_guild_id(entry: AuditLogEntry, data: Snowflake | None) -> Guild | None: + if data is None: + return None + return entry._state._get_guild(data) + + +def _transform_overwrites( + entry: AuditLogEntry, data: list[PermissionOverwritePayload] +) -> list[tuple[Object, PermissionOverwrite]]: + overwrites = [] + for elem in data: + allow = Permissions(int(elem["allow"])) + deny = Permissions(int(elem["deny"])) + ow = PermissionOverwrite.from_pair(allow, deny) + + ow_type = elem["type"] + ow_id = int(elem["id"]) + target = None + if ow_type == "0": + target = entry.guild.get_role(ow_id) + elif ow_type == "1": + target = entry._get_member(ow_id) + + if target is None: + target = Object(id=ow_id) + + overwrites.append((target, ow)) + + return overwrites + + +def _transform_icon(entry: AuditLogEntry, data: str | None) -> Asset | None: + if data is None: + return None + return Asset._from_guild_icon(entry._state, entry.guild.id, data) + + +def _transform_avatar(entry: AuditLogEntry, data: str | None) -> Asset | None: + if data is None: + return None + return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore + + +def _transform_scheduled_event_cover( + entry: AuditLogEntry, data: str | None +) -> Asset | None: + if data is None: + return None + return Asset._from_scheduled_event_cover(entry._state, entry._target_id, data) + + +def _guild_hash_transformer( + path: str, +) -> Callable[[AuditLogEntry, str | None], Asset | None]: + def _transform(entry: AuditLogEntry, data: str | None) -> Asset | None: + if data is None: + return None + return Asset._from_guild_image(entry._state, entry.guild.id, data, path=path) + + return _transform + + +T = TypeVar("T", bound=enums.Enum) + + +def _enum_transformer(enum: type[T]) -> Callable[[AuditLogEntry, int], T]: + def _transform(entry: AuditLogEntry, data: int) -> T: + return enums.try_enum(enum, data) + + return _transform + + +def _transform_type( + entry: AuditLogEntry, data: int +) -> enums.ChannelType | enums.StickerType: + if entry.action.name.startswith("sticker_"): + return enums.try_enum(enums.StickerType, data) + else: + return enums.try_enum(enums.ChannelType, data) + + +class AuditLogDiff: + def __len__(self) -> int: + return len(self.__dict__) + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + yield from self.__dict__.items() + + def __repr__(self) -> str: + values = " ".join("%s=%r" % item for item in self.__dict__.items()) + return f"" + + if TYPE_CHECKING: + + def __getattr__(self, item: str) -> Any: + ... + + def __setattr__(self, key: str, value: Any) -> Any: + ... + + +Transformer = Callable[["AuditLogEntry", Any], Any] + + +class AuditLogChanges: + TRANSFORMERS: ClassVar[dict[str, tuple[str | None, Transformer | None]]] = { + "verification_level": (None, _enum_transformer(enums.VerificationLevel)), + "explicit_content_filter": (None, _enum_transformer(enums.ContentFilter)), + "allow": (None, _transform_permissions), + "deny": (None, _transform_permissions), + "permissions": (None, _transform_permissions), + "id": (None, _transform_snowflake), + "color": ("colour", _transform_color), + "owner_id": ("owner", _transform_member_id), + "inviter_id": ("inviter", _transform_member_id), + "channel_id": ("channel", _transform_channel), + "afk_channel_id": ("afk_channel", _transform_channel), + "system_channel_id": ("system_channel", _transform_channel), + "widget_channel_id": ("widget_channel", _transform_channel), + "rules_channel_id": ("rules_channel", _transform_channel), + "public_updates_channel_id": ("public_updates_channel", _transform_channel), + "permission_overwrites": ("overwrites", _transform_overwrites), + "splash_hash": ("splash", _guild_hash_transformer("splashes")), + "banner_hash": ("banner", _guild_hash_transformer("banners")), + "discovery_splash_hash": ( + "discovery_splash", + _guild_hash_transformer("discovery-splashes"), + ), + "icon_hash": ("icon", _transform_icon), + "avatar_hash": ("avatar", _transform_avatar), + "rate_limit_per_user": ("slowmode_delay", None), + "guild_id": ("guild", _transform_guild_id), + "tags": ("emoji", None), + "default_message_notifications": ( + "default_notifications", + _enum_transformer(enums.NotificationLevel), + ), + "rtc_region": (None, _enum_transformer(enums.VoiceRegion)), + "video_quality_mode": (None, _enum_transformer(enums.VideoQualityMode)), + "privacy_level": (None, _enum_transformer(enums.StagePrivacyLevel)), + "format_type": (None, _enum_transformer(enums.StickerFormatType)), + "type": (None, _transform_type), + "status": (None, _enum_transformer(enums.ScheduledEventStatus)), + "entity_type": ( + "location_type", + _enum_transformer(enums.ScheduledEventLocationType), + ), + "command_id": ("command_id", _transform_snowflake), + "image_hash": ("cover", _transform_scheduled_event_cover), + } + + def __init__( + self, + entry: AuditLogEntry, + data: list[AuditLogChangePayload], + *, + state: ConnectionState, + ): + self.before = AuditLogDiff() + self.after = AuditLogDiff() + + for elem in sorted(data, key=lambda i: i["key"]): + attr = elem["key"] + + # special cases for role add/remove + if attr == "$add": + self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore + continue + elif attr == "$remove": + self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore + continue + + try: + key, transformer = self.TRANSFORMERS[attr] + except (ValueError, KeyError): + transformer = None + else: + if key: + attr = key + + transformer: Transformer | None + + try: + before = elem["old_value"] + except KeyError: + before = None + else: + if transformer: + before = transformer(entry, before) + + if attr == "location" and hasattr(self.before, "location_type"): + from .scheduled_events import ScheduledEventLocation + + if ( + self.before.location_type + is enums.ScheduledEventLocationType.external + ): + before = ScheduledEventLocation(state=state, value=before) + elif hasattr(self.before, "channel"): + before = ScheduledEventLocation( + state=state, value=self.before.channel + ) + + setattr(self.before, attr, before) + + try: + after = elem["new_value"] + except KeyError: + after = None + else: + if transformer: + after = transformer(entry, after) + + if attr == "location" and hasattr(self.after, "location_type"): + from .scheduled_events import ScheduledEventLocation + + if ( + self.after.location_type + is enums.ScheduledEventLocationType.external + ): + after = ScheduledEventLocation(state=state, value=after) + elif hasattr(self.after, "channel"): + after = ScheduledEventLocation( + state=state, value=self.after.channel + ) + + setattr(self.after, attr, after) + + # add an alias + if hasattr(self.after, "colour"): + self.after.color = self.after.colour + self.before.color = self.before.colour + if hasattr(self.after, "expire_behavior"): + self.after.expire_behaviour = self.after.expire_behavior + self.before.expire_behaviour = self.before.expire_behavior + + def __repr__(self) -> str: + return f"" + + def _handle_role( + self, + first: AuditLogDiff, + second: AuditLogDiff, + entry: AuditLogEntry, + elem: list[RolePayload], + ) -> None: + if not hasattr(first, "roles"): + setattr(first, "roles", []) + + data = [] + g: Guild = entry.guild # type: ignore + + for e in elem: + role_id = int(e["id"]) + role = g.get_role(role_id) + + if role is None: + role = Object(id=role_id) + role.name = e["name"] # type: ignore + + data.append(role) + + setattr(second, "roles", data) + + +class _AuditLogProxyMemberPrune: + delete_member_days: int + members_removed: int + + +class _AuditLogProxyMemberMoveOrMessageDelete: + channel: abc.GuildChannel + count: int + + +class _AuditLogProxyMemberDisconnect: + count: int + + +class _AuditLogProxyPinAction: + channel: abc.GuildChannel + message_id: int + + +class _AuditLogProxyStageInstanceAction: + channel: abc.GuildChannel + + +class AuditLogEntry(Hashable): + r"""Represents an Audit Log entry. + + You retrieve these via :meth:`Guild.audit_logs`. + + .. container:: operations + + .. describe:: x == y + + Checks if two entries are equal. + + .. describe:: x != y + + Checks if two entries are not equal. + + .. describe:: hash(x) + + Returns the entry's hash. + + .. versionchanged:: 1.7 + Audit log entries are now comparable and hashable. + + Attributes + ----------- + action: :class:`AuditLogAction` + The action that was done. + user: :class:`abc.User` + The user who initiated this action. Usually a :class:`Member`\, unless gone + then it's a :class:`User`. + id: :class:`int` + The entry ID. + target: Any + The target that got changed. The exact type of this depends on + the action being done. + reason: Optional[:class:`str`] + The reason this action was done. + extra: Any + Extra information that this entry has that might be useful. + For most actions, this is ``None``. However, in some cases it + contains extra information. See :class:`AuditLogAction` for + which actions have this field filled out. + """ + + def __init__( + self, *, users: dict[int, User], data: AuditLogEntryPayload, guild: Guild + ): + self._state = guild._state + self.guild = guild + self._users = users + self._from_data(data) + + def _from_data(self, data: AuditLogEntryPayload) -> None: + self.action = enums.try_enum(enums.AuditLogAction, data["action_type"]) + self.id = int(data["id"]) + + # this key is technically not usually present + self.reason = data.get("reason") + self.extra = data.get("options") + + if isinstance(self.action, enums.AuditLogAction) and self.extra: + if self.action is enums.AuditLogAction.member_prune: + # member prune has two keys with useful information + self.extra: _AuditLogProxyMemberPrune = type( + "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} + )() + elif ( + self.action is enums.AuditLogAction.member_move + or self.action is enums.AuditLogAction.message_delete + ): + channel_id = int(self.extra["channel_id"]) + elems = { + "count": int(self.extra["count"]), + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id), + } + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type( + "_AuditLogProxy", (), elems + )() + elif self.action is enums.AuditLogAction.member_disconnect: + # The member disconnect action has a dict with some information + elems = { + "count": int(self.extra["count"]), + } + self.extra: _AuditLogProxyMemberDisconnect = type( + "_AuditLogProxy", (), elems + )() + elif self.action.name.endswith("pin"): + # the pin actions have a dict with some information + channel_id = int(self.extra["channel_id"]) + elems = { + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id), + "message_id": int(self.extra["message_id"]), + } + self.extra: _AuditLogProxyPinAction = type( + "_AuditLogProxy", (), elems + )() + elif self.action.name.startswith("overwrite_"): + # the overwrite_ actions have a dict with some information + instance_id = int(self.extra["id"]) + the_type = self.extra.get("type") + if the_type == "1": + self.extra = self._get_member(instance_id) + elif the_type == "0": + role = self.guild.get_role(instance_id) + if role is None: + role = Object(id=instance_id) + role.name = self.extra.get("role_name") # type: ignore + self.extra: Role = role + elif self.action.name.startswith("stage_instance"): + channel_id = int(self.extra["channel_id"]) + elems = { + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id) + } + self.extra: _AuditLogProxyStageInstanceAction = type( + "_AuditLogProxy", (), elems + )() + + self.extra: ( + _AuditLogProxyMemberPrune + | _AuditLogProxyMemberMoveOrMessageDelete + | _AuditLogProxyMemberDisconnect + | _AuditLogProxyPinAction + | _AuditLogProxyStageInstanceAction + | Member + | User + | None + | Role + ) + + # this key is not present when the above is present, typically. + # It's a list of { new_value: a, old_value: b, key: c } + # where new_value and old_value are not guaranteed to be there depending + # on the action type, so let's just fetch it for now and only turn it + # into meaningful data when requested + self._changes = data.get("changes", []) + + self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore + self._target_id = utils._get_as_snowflake(data, "target_id") + + def _get_member(self, user_id: int) -> Member | User | None: + return self.guild.get_member(user_id) or self._users.get(user_id) + + def __repr__(self) -> str: + return f"" + + @utils.cached_property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the entry's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @utils.cached_property + def target( + self, + ) -> ( + Guild + | abc.GuildChannel + | Member + | User + | Role + | Invite + | Emoji + | StageInstance + | GuildSticker + | Thread + | Object + | None + ): + try: + converter = getattr(self, f"_convert_target_{self.action.target_type}") + except AttributeError: + return Object(id=self._target_id) + else: + return converter(self._target_id) + + @utils.cached_property + def category(self) -> enums.AuditLogActionCategory: + """Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable.""" + return self.action.category + + @utils.cached_property + def changes(self) -> AuditLogChanges: + """:class:`AuditLogChanges`: The list of changes this entry has.""" + obj = AuditLogChanges(self, self._changes, state=self._state) + del self._changes + return obj + + @utils.cached_property + def before(self) -> AuditLogDiff: + """:class:`AuditLogDiff`: The target's prior state.""" + return self.changes.before + + @utils.cached_property + def after(self) -> AuditLogDiff: + """:class:`AuditLogDiff`: The target's subsequent state.""" + return self.changes.after + + def _convert_target_guild(self, target_id: int) -> Guild: + return self.guild + + def _convert_target_channel(self, target_id: int) -> abc.GuildChannel | Object: + return self.guild.get_channel(target_id) or Object(id=target_id) + + def _convert_target_user(self, target_id: int) -> Member | User | None: + return self._get_member(target_id) + + def _convert_target_role(self, target_id: int) -> Role | Object: + return self.guild.get_role(target_id) or Object(id=target_id) + + def _convert_target_invite(self, target_id: int) -> Invite: + # invites have target_id set to null + # so figure out which change has the full invite data + changeset = ( + self.before + if self.action is enums.AuditLogAction.invite_delete + else self.after + ) + + fake_payload = { + "max_age": changeset.max_age, + "max_uses": changeset.max_uses, + "code": changeset.code, + "temporary": changeset.temporary, + "uses": changeset.uses, + } + + obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore + try: + obj.inviter = changeset.inviter + except AttributeError: + pass + return obj + + def _convert_target_emoji(self, target_id: int) -> Emoji | Object: + return self._state.get_emoji(target_id) or Object(id=target_id) + + def _convert_target_message(self, target_id: int) -> Member | User | None: + return self._get_member(target_id) + + def _convert_target_stage_instance(self, target_id: int) -> StageInstance | Object: + return self.guild.get_stage_instance(target_id) or Object(id=target_id) + + def _convert_target_sticker(self, target_id: int) -> GuildSticker | Object: + return self._state.get_sticker(target_id) or Object(id=target_id) + + def _convert_target_thread(self, target_id: int) -> Thread | Object: + return self.guild.get_thread(target_id) or Object(id=target_id) + + def _convert_target_scheduled_event( + self, target_id: int + ) -> ScheduledEvent | Object: + return self.guild.get_scheduled_event(target_id) or Object(id=target_id) diff --git a/discord/automod.py b/discord/automod.py new file mode 100644 index 0000000..02fbd4a --- /dev/null +++ b/discord/automod.py @@ -0,0 +1,477 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from datetime import timedelta +from functools import cached_property +from typing import TYPE_CHECKING + +from . import utils +from .enums import ( + AutoModActionType, + AutoModEventType, + AutoModKeywordPresetType, + AutoModTriggerType, + try_enum, +) +from .mixins import Hashable +from .object import Object + +__all__ = ("AutoModRule",) + +if TYPE_CHECKING: + from .abc import Snowflake + from .channel import ForumChannel, TextChannel, VoiceChannel + from .guild import Guild + from .member import Member + from .role import Role + from .state import ConnectionState + from .types.automod import AutoModAction as AutoModActionPayload + from .types.automod import AutoModActionMetadata as AutoModActionMetadataPayload + from .types.automod import AutoModRule as AutoModRulePayload + from .types.automod import AutoModTriggerMetadata as AutoModTriggerMetadataPayload + +MISSING = utils.MISSING + + +class AutoModActionMetadata: + """Represents an action's metadata. + + Depending on the action's type, different attributes will be used. + + .. versionadded:: 2.0 + + Attributes + ---------- + channel_id: :class:`int` + The ID of the channel to send the message to. + Only for actions of type :attr:`AutoModActionType.send_alert_message`. + timeout_duration: :class:`datetime.timedelta` + How long the member that triggered the action should be timed out for. + Only for actions of type :attr:`AutoModActionType.timeout`. + """ + + # maybe add a table of action types and attributes? + + __slots__ = ( + "channel_id", + "timeout_duration", + ) + + def __init__( + self, channel_id: int = MISSING, timeout_duration: timedelta = MISSING + ): + self.channel_id: int = channel_id + self.timeout_duration: timedelta = timeout_duration + + def to_dict(self) -> dict: + data = {} + + if self.channel_id is not MISSING: + data["channel_id"] = self.channel_id + + if self.timeout_duration is not MISSING: + data["duration_seconds"] = self.timeout_duration.total_seconds() + + return data + + @classmethod + def from_dict(cls, data: AutoModActionMetadataPayload): + kwargs = {} + + if (channel_id := data.get("channel_id")) is not None: + kwargs["channel_id"] = int(channel_id) + + if (duration_seconds := data.get("duration_seconds")) is not None: + # might need an explicit int cast + kwargs["timeout_duration"] = timedelta(seconds=duration_seconds) + + return cls(**kwargs) + + def __repr__(self) -> str: + repr_attrs = ( + "channel_id", + "timeout_duration", + ) + inner = [] + + for attr in repr_attrs: + if (value := getattr(self, attr)) is not MISSING: + inner.append(f"{attr}={value}") + inner = " ".join(inner) + + return f"" + + +class AutoModAction: + """Represents an action for a guild's auto moderation rule. + + .. versionadded:: 2.0 + + Attributes + ---------- + type: :class:`AutoModActionType` + The action's type. + metadata: :class:`AutoModActionMetadata` + The action's metadata. + """ + + # note that AutoModActionType.timeout is only valid for trigger type 1? + + __slots__ = ( + "type", + "metadata", + ) + + def __init__(self, action_type: AutoModActionType, metadata: AutoModActionMetadata): + self.type: AutoModActionType = action_type + self.metadata: AutoModActionMetadata = metadata + + def to_dict(self) -> dict: + return { + "type": self.type.value, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: AutoModActionPayload): + return cls( + try_enum(AutoModActionType, data["type"]), + AutoModActionMetadata.from_dict(data["metadata"]), + ) + + def __repr__(self) -> str: + return f"" + + +class AutoModTriggerMetadata: + """Represents a rule's trigger metadata. + + Depending on the trigger type, different attributes will be used. + + .. versionadded:: 2.0 + + Attributes + ---------- + keyword_filter: List[:class:`str`] + A list of substrings to filter. Only for triggers of type :attr:`AutoModTriggerType.keyword`. + presets: List[:class:`AutoModKeywordPresetType`] + A list of keyword presets to filter. Only for triggers of type :attr:`AutoModTriggerType.keyword_preset`. + """ + + # maybe add a table of action types and attributes? + # wording for presets could change + + __slots__ = ( + "keyword_filter", + "presets", + ) + + def __init__( + self, + keyword_filter: list[str] = MISSING, + presets: list[AutoModKeywordPresetType] = MISSING, + ): + self.keyword_filter = keyword_filter + self.presets = presets + + def to_dict(self) -> dict: + data = {} + + if self.keyword_filter is not MISSING: + data["keyword_filter"] = self.keyword_filter + + if self.presets is not MISSING: + data["presets"] = [wordset.value for wordset in self.presets] + + return data + + @classmethod + def from_dict(cls, data: AutoModTriggerMetadataPayload): + kwargs = {} + + if (keyword_filter := data.get("keyword_filter")) is not None: + kwargs["keyword_filter"] = keyword_filter + + if (presets := data.get("presets")) is not None: + kwargs["presets"] = [ + try_enum(AutoModKeywordPresetType, wordset) for wordset in presets + ] + + return cls(**kwargs) + + def __repr__(self) -> str: + repr_attrs = ( + "keyword_filter", + "presets", + ) + inner = [] + + for attr in repr_attrs: + if (value := getattr(self, attr)) is not MISSING: + inner.append(f"{attr}={value}") + inner = " ".join(inner) + + return f"" + + +class AutoModRule(Hashable): + """Represents a guild's auto moderation rule. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two rules are equal. + + .. describe:: x != y + + Checks if two rules are not equal. + + .. describe:: hash(x) + + Returns the rule's hash. + + .. describe:: str(x) + + Returns the rule's name. + + Attributes + ---------- + id: :class:`int` + The rule's ID. + name: :class:`str` + The rule's name. + creator_id: :class:`int` + The ID of the user who created this rule. + event_type: :class:`AutoModEventType` + Indicates in what context the rule is checked. + trigger_type: :class:`AutoModTriggerType` + Indicates what type of information is checked to determine whether the rule is triggered. + trigger_metadata: :class:`AutoModTriggerMetadata` + The rule's trigger metadata. + actions: List[:class:`AutoModAction`] + The actions to perform when the rule is triggered. + enabled: :class:`bool` + Whether this rule is enabled. + exempt_role_ids: List[:class:`int`] + The IDs of the roles that are exempt from this rule. + exempt_channel_ids: List[:class:`int`] + The IDs of the channels that are exempt from this rule. + """ + + __slots__ = ( + "_state", + "id", + "guild_id", + "name", + "creator_id", + "event_type", + "trigger_type", + "trigger_metadata", + "actions", + "enabled", + "exempt_role_ids", + "exempt_channel_ids", + ) + + def __init__( + self, + *, + state: ConnectionState, + data: AutoModRulePayload, + ): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self.guild_id: int = int(data["guild_id"]) + self.name: str = data["name"] + self.creator_id: int = int(data["creator_id"]) + self.event_type: AutoModEventType = try_enum( + AutoModEventType, data["event_type"] + ) + self.trigger_type: AutoModTriggerType = try_enum( + AutoModTriggerType, data["trigger_type"] + ) + self.trigger_metadata: AutoModTriggerMetadata = ( + AutoModTriggerMetadata.from_dict(data["trigger_metadata"]) + ) + self.actions: list[AutoModAction] = [ + AutoModAction.from_dict(d) for d in data["actions"] + ] + self.enabled: bool = data["enabled"] + self.exempt_role_ids: list[int] = [int(r) for r in data["exempt_roles"]] + self.exempt_channel_ids: list[int] = [int(c) for c in data["exempt_channels"]] + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name + + @cached_property + def guild(self) -> Guild | None: + """Optional[:class:`Guild`]: The guild this rule belongs to.""" + return self._state._get_guild(self.guild_id) + + @cached_property + def creator(self) -> Member | None: + """Optional[:class:`Member`]: The member who created this rule.""" + if self.guild is None: + return None + return self.guild.get_member(self.creator_id) + + @cached_property + def exempt_roles(self) -> list[Role | Object]: + """List[Union[:class:`Role`, :class:`Object`]]: The roles that are exempt + from this rule. + + If a role is not found in the guild's cache, + then it will be returned as an :class:`Object`. + """ + if self.guild is None: + return [Object(role_id) for role_id in self.exempt_role_ids] + return [ + self.guild.get_role(role_id) or Object(role_id) + for role_id in self.exempt_role_ids + ] + + @cached_property + def exempt_channels( + self, + ) -> list[TextChannel | ForumChannel | VoiceChannel | Object]: + """List[Union[Union[:class:`TextChannel`, :class:`ForumChannel`, :class:`VoiceChannel`], :class:`Object`]]: The + channels that are exempt from this rule. + + If a channel is not found in the guild's cache, + then it will be returned as an :class:`Object`. + """ + if self.guild is None: + return [Object(channel_id) for channel_id in self.exempt_channel_ids] + return [ + self.guild.get_channel(channel_id) or Object(channel_id) + for channel_id in self.exempt_channel_ids + ] + + async def delete(self, reason: str | None = None) -> None: + """|coro| + + Deletes this rule. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this rule. Shows up in the audit log. + + Raises + ------ + Forbidden + You do not have the Manage Guild permission. + HTTPException + The operation failed. + """ + await self._state.http.delete_auto_moderation_rule( + self.guild_id, self.id, reason=reason + ) + + async def edit( + self, + *, + name: str = MISSING, + event_type: AutoModEventType = MISSING, + trigger_metadata: AutoModTriggerMetadata = MISSING, + actions: list[AutoModAction] = MISSING, + enabled: bool = MISSING, + exempt_roles: list[Snowflake] = MISSING, + exempt_channels: list[Snowflake] = MISSING, + reason: str | None = None, + ) -> AutoModRule | None: + """|coro| + + Edits this rule. + + Parameters + ---------- + name: :class:`str` + The rule's new name. + event_type: :class:`AutoModEventType` + The new context in which the rule is checked. + trigger_metadata: :class:`AutoModTriggerMetadata` + The new trigger metadata. + actions: List[:class:`AutoModAction`] + The new actions to perform when the rule is triggered. + enabled: :class:`bool` + Whether this rule is enabled. + exempt_roles: List[:class:`Snowflake`] + The roles that will be exempt from this rule. + exempt_channels: List[:class:`Snowflake`] + The channels that will be exempt from this rule. + reason: Optional[:class:`str`] + The reason for editing this rule. Shows up in the audit log. + + Returns + ------- + Optional[:class:`.AutoModRule`] + The newly updated rule, if applicable. This is only returned + when fields are updated. + + Raises + ------ + Forbidden + You do not have the Manage Guild permission. + HTTPException + The operation failed. + """ + http = self._state.http + payload = {} + + if name is not MISSING: + payload["name"] = name + + if event_type is not MISSING: + payload["event_type"] = event_type.value + + if trigger_metadata is not MISSING: + payload["trigger_metadata"] = trigger_metadata.to_dict() + + if actions is not MISSING: + payload["actions"] = [a.to_dict() for a in actions] + + if enabled is not MISSING: + payload["enabled"] = enabled + + # Maybe consider enforcing limits on the number of exempt roles/channels? + if exempt_roles is not MISSING: + payload["exempt_roles"] = [r.id for r in exempt_roles] + + if exempt_channels is not MISSING: + payload["exempt_channels"] = [c.id for c in exempt_channels] + + if payload: + data = await http.edit_auto_moderation_rule( + self.guild_id, self.id, payload, reason=reason + ) + return AutoModRule(state=self._state, data=data) diff --git a/discord/backoff.py b/discord/backoff.py new file mode 100644 index 0000000..009df69 --- /dev/null +++ b/discord/backoff.py @@ -0,0 +1,104 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import random +import time +from typing import Callable, Generic, Literal, TypeVar, overload + +T = TypeVar("T", bool, Literal[True], Literal[False]) + +__all__ = ("ExponentialBackoff",) + + +class ExponentialBackoff(Generic[T]): + """An implementation of the exponential backoff algorithm + + Provides a convenient interface to implement an exponential backoff + for reconnecting or retrying transmissions in a distributed network. + + Once instantiated, the delay method will return the next interval to + wait for when retrying a connection or transmission. The maximum + delay increases exponentially with each retry up to a maximum of + 2^10 * base, and is reset if no more attempts are needed in a period + of 2^11 * base seconds. + + Parameters + ---------- + base: :class:`int` + The base delay in seconds. The first retry-delay will be up to + this many seconds. + integral: :class:`bool` + Set to ``True`` if whole periods of base is desirable, otherwise any + number in between may be returned. + """ + + def __init__(self, base: int = 1, *, integral: T = False): + self._base: int = base + + self._exp: int = 0 + self._max: int = 10 + self._reset_time: int = base * 2**11 + self._last_invocation: float = time.monotonic() + + # Use our own random instance to avoid messing with global one + rand = random.Random() + rand.seed() + + self._randfunc: Callable[..., int | float] = rand.randrange if integral else rand.uniform # type: ignore + + @overload + def delay(self: ExponentialBackoff[Literal[False]]) -> float: + ... + + @overload + def delay(self: ExponentialBackoff[Literal[True]]) -> int: + ... + + @overload + def delay(self: ExponentialBackoff[bool]) -> int | float: + ... + + def delay(self) -> int | float: + """Compute the next delay + + Returns the next delay to wait according to the exponential + backoff algorithm. This is a value between 0 and base * 2^exp + where exponent starts off at 1 and is incremented at every + invocation of this method up to a maximum of 10. + + If a period of more than base * 2^11 has passed since the last + retry, the exponent is reset to 1. + """ + invocation = time.monotonic() + interval = invocation - self._last_invocation + self._last_invocation = invocation + + if interval > self._reset_time: + self._exp = 0 + + self._exp = min(self._exp + 1, self._max) + return self._randfunc(0, self._base * 2**self._exp) diff --git a/discord/bin/COPYING b/discord/bin/COPYING new file mode 100644 index 0000000..7b53d66 --- /dev/null +++ b/discord/bin/COPYING @@ -0,0 +1,28 @@ +Copyright (c) 1994-2013 Xiph.Org Foundation and contributors + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +- Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +- Neither the name of the Xiph.Org Foundation nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION +OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/discord/bin/libopus-0.x64.dll b/discord/bin/libopus-0.x64.dll new file mode 100644 index 0000000..74a8e35 Binary files /dev/null and b/discord/bin/libopus-0.x64.dll differ diff --git a/discord/bin/libopus-0.x86.dll b/discord/bin/libopus-0.x86.dll new file mode 100644 index 0000000..ee71317 Binary files /dev/null and b/discord/bin/libopus-0.x86.dll differ diff --git a/discord/bot.py b/discord/bot.py new file mode 100644 index 0000000..697a9be --- /dev/null +++ b/discord/bot.py @@ -0,0 +1,1523 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import collections +import collections.abc +import copy +import inspect +import logging +import sys +import traceback +from abc import ABC, abstractmethod +from typing import Any, Callable, Coroutine, Generator, Literal, Mapping, TypeVar + +from .client import Client +from .cog import CogMixin +from .commands import ( + ApplicationCommand, + ApplicationContext, + AutocompleteContext, + MessageCommand, + SlashCommand, + SlashCommandGroup, + UserCommand, + command, +) +from .enums import InteractionType +from .errors import CheckFailure, DiscordException +from .interactions import Interaction +from .shard import AutoShardedClient +from .types import interactions +from .user import User +from .utils import MISSING, async_all, find, get + +CoroFunc = Callable[..., Coroutine[Any, Any, Any]] +CFT = TypeVar("CFT", bound=CoroFunc) + +__all__ = ( + "ApplicationCommandMixin", + "Bot", + "AutoShardedBot", +) + +_log = logging.getLogger(__name__) + + +class ApplicationCommandMixin(ABC): + """A mixin that implements common functionality for classes that need + application command compatibility. + + Attributes + ---------- + application_commands: :class:`dict` + A mapping of command id string to :class:`.ApplicationCommand` objects. + pending_application_commands: :class:`list` + A list of commands that have been added but not yet registered. This is read-only and is modified via other + methods. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._pending_application_commands = [] + self._application_commands = {} + + @property + def all_commands(self): + return self._application_commands + + @property + def pending_application_commands(self): + return self._pending_application_commands + + @property + def commands(self) -> list[ApplicationCommand | Any]: + commands = self.application_commands + if self._bot._supports_prefixed_commands and hasattr( + self._bot, "prefixed_commands" + ): + commands += getattr(self._bot, "prefixed_commands") + return commands + + @property + def application_commands(self) -> list[ApplicationCommand]: + return list(self._application_commands.values()) + + def add_application_command(self, command: ApplicationCommand) -> None: + """Adds a :class:`.ApplicationCommand` into the internal list of commands. + + This is usually not called, instead the :meth:`command` or + other shortcut decorators are used instead. + + .. versionadded:: 2.0 + + Parameters + ---------- + command: :class:`.ApplicationCommand` + The command to add. + """ + if isinstance(command, SlashCommand) and command.is_subcommand: + raise TypeError("The provided command is a sub-command of group") + + if command.cog is MISSING: + command._set_cog(None) + + if self._bot.debug_guilds and command.guild_ids is None: + command.guild_ids = self._bot.debug_guilds + + for cmd in self.pending_application_commands: + if cmd == command: + command.id = cmd.id + self._application_commands[command.id] = command + break + self._pending_application_commands.append(command) + + def remove_application_command( + self, command: ApplicationCommand + ) -> ApplicationCommand | None: + """Remove a :class:`.ApplicationCommand` from the internal list + of commands. + + .. versionadded:: 2.0 + + Parameters + ---------- + command: :class:`.ApplicationCommand` + The command to remove. + + Returns + ------- + Optional[:class:`.ApplicationCommand`] + The command that was removed. If the name is not valid then + ``None`` is returned instead. + """ + if command.id is None: + try: + index = self._pending_application_commands.index(command) + except ValueError: + return None + return self._pending_application_commands.pop(index) + return self._application_commands.pop(command.id, None) + + @property + def get_command(self): + """Shortcut for :meth:`.get_application_command`. + + .. note:: + Overridden in :class:`ext.commands.Bot`. + + .. versionadded:: 2.0 + """ + # TODO: Do something like we did in self.commands for this + return self.get_application_command + + def get_application_command( + self, + name: str, + guild_ids: list[int] | None = None, + type: type[ApplicationCommand] = SlashCommand, + ) -> ApplicationCommand | None: + """Get a :class:`.ApplicationCommand` from the internal list + of commands. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the command to get. + guild_ids: List[:class:`int`] + The guild ids associated to the command to get. + type: Type[:class:`.ApplicationCommand`] + The type of the command to get. Defaults to :class:`.SlashCommand`. + + Returns + ------- + Optional[:class:`.ApplicationCommand`] + The command that was requested. If not found, returns ``None``. + """ + + for command in self._application_commands.values(): + if command.name == name and isinstance(command, type): + if guild_ids is not None and command.guild_ids != guild_ids: + return + return command + + async def get_desynced_commands( + self, + guild_id: int | None = None, + prefetched: list[interactions.ApplicationCommand] | None = None, + ) -> list[dict[str, Any]]: + """|coro| + + Gets the list of commands that are desynced from discord. If ``guild_id`` is specified, it will only return + guild commands that are desynced from said guild, else it will return global commands. + + .. note:: + This function is meant to be used internally, and should only be used if you want to override the default + command registration behavior. + + .. versionadded:: 2.0 + + Parameters + ---------- + guild_id: Optional[:class:`int`] + The guild id to get the desynced commands for, else global commands if unspecified. + prefetched: Optional[List[:class:`.ApplicationCommand`]] + If you already fetched the commands, you can pass them here to be used. Not recommended for typical usage. + + Returns + ------- + List[Dict[:class:`str`, Any]] + A list of the desynced commands. Each will come with at least the ``cmd`` and ``action`` keys, which + respectively contain the command and the action to perform. Other keys may also be present depending on + the action, including ``id``. + """ + + # We can suggest the user to upsert, edit, delete, or bulk upsert the commands + + def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: + if isinstance(cmd, SlashCommandGroup): + if len(cmd.subcommands) != len(match.get("options", [])): + return True + for i, subcommand in enumerate(cmd.subcommands): + match_ = next( + ( + data + for data in match["options"] + if data["name"] == subcommand.name + ), + MISSING, + ) + if match_ is not MISSING and _check_command(subcommand, match_): + return True + else: + as_dict = cmd.to_dict() + to_check = { + "dm_permission": None, + "default_member_permissions": None, + "name": None, + "description": None, + "name_localizations": None, + "description_localizations": None, + "options": [ + "type", + "name", + "description", + "autocomplete", + "choices", + "name_localizations", + "description_localizations", + ], + } + for check, value in to_check.items(): + if type(to_check[check]) == list: + # We need to do some falsy conversion here + # The API considers False (autocomplete) and [] (choices) to be falsy values + falsy_vals = (False, []) + for opt in value: + cmd_vals = ( + [val.get(opt, MISSING) for val in as_dict[check]] + if check in as_dict + else [] + ) + for i, val in enumerate(cmd_vals): + if val in falsy_vals: + cmd_vals[i] = MISSING + if match.get( + check, MISSING + ) is not MISSING and cmd_vals != [ + val.get(opt, MISSING) for val in match[check] + ]: + # We have a difference + return True + elif getattr(cmd, check, None) != match.get(check): + # We have a difference + if ( + check == "default_permission" + and getattr(cmd, check) is True + and match.get(check) is None + ): + # This is a special case + # TODO: Remove for perms v2 + continue + return True + return False + + return_value = [] + cmds = self.pending_application_commands.copy() + + if guild_id is None: + pending = [cmd for cmd in cmds if cmd.guild_ids is None] + else: + pending = [ + cmd + for cmd in cmds + if cmd.guild_ids is not None and guild_id in cmd.guild_ids + ] + + registered_commands: list[interactions.ApplicationCommand] = [] + if prefetched is not None: + registered_commands = prefetched + elif self._bot.user: + if guild_id is None: + registered_commands = await self._bot.http.get_global_commands( + self._bot.user.id + ) + else: + registered_commands = await self._bot.http.get_guild_commands( + self._bot.user.id, guild_id + ) + + registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands} + # First let's check if the commands we have locally are the same as the ones on discord + for cmd in pending: + match = registered_commands_dict.get(cmd.name) + if match is None: + # We don't have this command registered + return_value.append({"command": cmd, "action": "upsert"}) + elif _check_command(cmd, match): + return_value.append( + { + "command": cmd, + "action": "edit", + "id": int(registered_commands_dict[cmd.name]["id"]), + } + ) + else: + # We have this command registered but it's the same + return_value.append( + {"command": cmd, "action": None, "id": int(match["id"])} + ) + + # Now let's see if there are any commands on discord that we need to delete + for cmd, value_ in registered_commands_dict.items(): + match = get(pending, name=registered_commands_dict[cmd]["name"]) + if match is None: + # We have this command registered but not in our list + return_value.append( + { + "command": registered_commands_dict[cmd]["name"], + "id": int(value_["id"]), + "action": "delete", + } + ) + + continue + + return return_value + + async def register_command( + self, + command: ApplicationCommand, + force: bool = True, + guild_ids: list[int] | None = None, + ) -> None: + """|coro| + + Registers a command. If the command has ``guild_ids`` set, or if the ``guild_ids`` parameter is passed, + the command will be registered as a guild command for those guilds. + + Parameters + ---------- + command: :class:`~.ApplicationCommand` + The command to register. + force: :class:`bool` + Whether to force the command to be registered. If this is set to False, the command will only be registered + if it seems to already be registered and up to date with our internal cache. Defaults to True. + guild_ids: :class:`list` + A list of guild ids to register the command for. If this is not set, the command's + :attr:`ApplicationCommand.guild_ids` attribute will be used. + + Returns + ------- + :class:`~.ApplicationCommand` + The command that was registered + """ + # TODO: Write this + raise NotImplementedError + + async def register_commands( + self, + commands: list[ApplicationCommand] | None = None, + guild_id: int | None = None, + method: Literal["individual", "bulk", "auto"] = "bulk", + force: bool = False, + delete_existing: bool = True, + ) -> list[interactions.ApplicationCommand]: + """|coro| + + Register a list of commands. + + .. versionadded:: 2.0 + + Parameters + ---------- + commands: Optional[List[:class:`~.ApplicationCommand`]] + A list of commands to register. If this is not set (``None``), then all commands will be registered. + guild_id: Optional[int] + If this is set, the commands will be registered as a guild command for the respective guild. If it is not + set, the commands will be registered according to their :attr:`ApplicationCommand.guild_ids` attribute. + method: Literal['individual', 'bulk', 'auto'] + The method to use when registering the commands. If this is set to "individual", then each command will be + registered individually. If this is set to "bulk", then all commands will be registered in bulk. If this is + set to "auto", then the method will be determined automatically. Defaults to "bulk". + force: :class:`bool` + Registers the commands regardless of the state of the command on Discord. This uses one less API call, but + can result in hitting rate limits more often. Defaults to False. + delete_existing: :class:`bool` + Whether to delete existing commands that are not in the list of commands to register. Defaults to True. + """ + if commands is None: + commands = self.pending_application_commands + + commands = [copy.copy(cmd) for cmd in commands] + + if guild_id is not None: + for cmd in commands: + to_rep_with = [guild_id] + cmd.guild_ids = to_rep_with + + is_global = guild_id is None + + registered = [] + + if is_global: + pending = list(filter(lambda c: c.guild_ids is None, commands)) + registration_methods = { + "bulk": self._bot.http.bulk_upsert_global_commands, + "upsert": self._bot.http.upsert_global_command, + "delete": self._bot.http.delete_global_command, + "edit": self._bot.http.edit_global_command, + } + + def _register( + method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs + ): + return registration_methods[method]( + self._bot.user and self._bot.user.id, *args, **kwargs + ) + + else: + pending = list( + filter( + lambda c: c.guild_ids is not None and guild_id in c.guild_ids, + commands, + ) + ) + registration_methods = { + "bulk": self._bot.http.bulk_upsert_guild_commands, + "upsert": self._bot.http.upsert_guild_command, + "delete": self._bot.http.delete_guild_command, + "edit": self._bot.http.edit_guild_command, + } + + def _register( + method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs + ): + return registration_methods[method]( + self._bot.user and self._bot.user.id, guild_id, *args, **kwargs + ) + + def register( + method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs + ): + if kwargs.pop("_log", True): + if method == "bulk": + _log.debug( + f"Bulk updating commands {[c['name'] for c in args[0]]} for guild {guild_id}" + ) + # TODO: Find where "cmd" is defined + elif method == "upsert": + _log.debug(f"Creating command {cmd['name']} for guild {guild_id}") # type: ignore + elif method == "edit": + _log.debug(f"Editing command {cmd['name']} for guild {guild_id}") # type: ignore + elif method == "delete": + _log.debug(f"Deleting command {cmd['name']} for guild {guild_id}") # type: ignore + return _register(method, *args, **kwargs) + + pending_actions = [] + + if not force: + prefetched_commands: list[interactions.ApplicationCommand] = [] + if self._bot.user: + if guild_id is None: + prefetched_commands = await self._bot.http.get_global_commands( + self._bot.user.id + ) + else: + prefetched_commands = await self._bot.http.get_guild_commands( + self._bot.user.id, guild_id + ) + desynced = await self.get_desynced_commands( + guild_id=guild_id, prefetched=prefetched_commands + ) + + for cmd in desynced: + if cmd["action"] == "delete": + pending_actions.append( + { + "action": "delete" if delete_existing else None, + "command": collections.namedtuple("Command", ["name"])( + name=cmd["command"] + ), + "id": cmd["id"], + } + ) + continue + # We can assume the command item is a command, since it's only a string if action is delete + match = get(pending, name=cmd["command"].name, type=cmd["command"].type) + if match is None: + continue + if cmd["action"] == "edit": + pending_actions.append( + { + "action": "edit", + "command": match, + "id": cmd["id"], + } + ) + elif cmd["action"] == "upsert": + pending_actions.append( + { + "action": "upsert", + "command": match, + } + ) + elif cmd["action"] is None: + pending_actions.append( + { + "action": None, + "command": match, + } + ) + else: + raise ValueError(f"Unknown action: {cmd['action']}") + filtered_no_action = list( + filter(lambda c: c["action"] is not None, pending_actions) + ) + filtered_deleted = list( + filter(lambda a: a["action"] != "delete", pending_actions) + ) + if method == "bulk" or ( + method == "auto" and len(filtered_deleted) == len(pending) + ): + # Either the method is bulk or all the commands need to be modified, so we can just do a bulk upsert + data = [cmd["command"].to_dict() for cmd in filtered_deleted] + # If there's nothing to update, don't bother + if len(filtered_no_action) == 0: + _log.debug("Skipping bulk command update: Commands are up to date") + registered = prefetched_commands + else: + _log.debug( + "Bulk updating commands %s for guild %s", + {c["command"].name: c["action"] for c in pending_actions}, + guild_id, + ) + registered = await register("bulk", data, _log=False) + else: + if not filtered_no_action: + registered = [] + for cmd in filtered_no_action: + if cmd["action"] == "delete": + await register("delete", cmd["command"]) + continue + if cmd["action"] == "edit": + registered.append( + await register("edit", cmd["id"], cmd["command"].to_dict()) + ) + elif cmd["action"] == "upsert": + registered.append( + await register("upsert", cmd["command"].to_dict()) + ) + else: + raise ValueError(f"Unknown action: {cmd['action']}") + + # TODO: Our lists dont work sometimes, see if that can be fixed so we can avoid this second API call + if method != "bulk": + if self._bot.user: + if guild_id is None: + registered = await self._bot.http.get_global_commands( + self._bot.user.id + ) + else: + registered = await self._bot.http.get_guild_commands( + self._bot.user.id, guild_id + ) + else: + data = [cmd.to_dict() for cmd in pending] + registered = await register("bulk", data) + + for i in registered: + cmd = get( + self.pending_application_commands, + name=i["name"], + type=i.get("type"), + ) + if not cmd: + raise ValueError( + f"Registered command {i['name']}, type {i.get('type')} not found in pending commands" + ) + cmd.id = i["id"] + self._application_commands[cmd.id] = cmd + + return registered + + async def sync_commands( + self, + commands: list[ApplicationCommand] | None = None, + method: Literal["individual", "bulk", "auto"] = "bulk", + force: bool = False, + guild_ids: list[int] | None = None, + register_guild_commands: bool = True, + check_guilds: list[int] | None = [], + delete_existing: bool = True, + ) -> None: + """|coro| + + Registers all commands that have been added through :meth:`.add_application_command`. This method cleans up all + commands over the API and should sync them with the internal cache of commands. It attempts to register the + commands in the most efficient way possible, unless ``force`` is set to ``True``, in which case it will always + register all commands. + + By default, this coroutine is called inside the :func:`.on_connect` event. If you choose to override the + :func:`.on_connect` event, then you should invoke this coroutine as well. + + .. note:: + If you remove all guild commands from a particular guild, the library may not be able to detect and update + the commands accordingly, as it would have to individually check for each guild. To force the library to + unregister a guild's commands, call this function with ``commands=[]`` and ``guild_ids=[guild_id]``. + + .. versionadded:: 2.0 + + Parameters + ---------- + commands: Optional[List[:class:`~.ApplicationCommand`]] + A list of commands to register. If this is not set (None), then all commands will be registered. + method: Literal['individual', 'bulk', 'auto'] + The method to use when registering the commands. If this is set to "individual", then each command will be + registered individually. If this is set to "bulk", then all commands will be registered in bulk. If this is + set to "auto", then the method will be determined automatically. Defaults to "bulk". + force: :class:`bool` + Registers the commands regardless of the state of the command on Discord. This uses one less API call, but + can result in hitting rate limits more often. Defaults to False. + guild_ids: Optional[List[:class:`int`]] + A list of guild ids to register the commands for. If this is not set, the commands' + :attr:`~.ApplicationCommand.guild_ids` attribute will be used. + register_guild_commands: :class:`bool` + Whether to register guild commands. Defaults to True. + check_guilds: Optional[List[:class:`int`]] + A list of guilds ids to check for commands to unregister, since the bot would otherwise have to check all + guilds. Unlike ``guild_ids``, this does not alter the commands' :attr:`~.ApplicationCommand.guild_ids` + attribute, instead it adds the guild ids to a list of guilds to sync commands for. If + ``register_guild_commands`` is set to False, then this parameter is ignored. + delete_existing: :class:`bool` + Whether to delete existing commands that are not in the list of commands to register. Defaults to True. + """ + + check_guilds = list(set((check_guilds or []) + (self._bot.debug_guilds or []))) + + if commands is None: + commands = self.pending_application_commands + + if guild_ids is not None: + for cmd in commands: + cmd.guild_ids = guild_ids + + global_commands = [cmd for cmd in commands if cmd.guild_ids is None] + registered_commands = await self.register_commands( + global_commands, method=method, force=force, delete_existing=delete_existing + ) + + registered_guild_commands: dict[int, list[interactions.ApplicationCommand]] = {} + + if register_guild_commands: + cmd_guild_ids: list[int] = [] + for cmd in commands: + if cmd.guild_ids is not None: + cmd_guild_ids.extend(cmd.guild_ids) + if check_guilds is not None: + cmd_guild_ids.extend(check_guilds) + for guild_id in set(cmd_guild_ids): + guild_commands = [ + cmd + for cmd in commands + if cmd.guild_ids is not None and guild_id in cmd.guild_ids + ] + app_cmds = await self.register_commands( + guild_commands, + guild_id=guild_id, + method=method, + force=force, + delete_existing=delete_existing, + ) + registered_guild_commands[guild_id] = app_cmds + + for i in registered_commands: + cmd = get( + self.pending_application_commands, + name=i["name"], + guild_ids=None, + type=i.get("type"), + ) + if cmd: + cmd.id = i["id"] + self._application_commands[cmd.id] = cmd + + if register_guild_commands and registered_guild_commands: + for guild_id, guild_cmds in registered_guild_commands.items(): + for i in guild_cmds: + cmd = find( + lambda cmd: cmd.name == i["name"] + and cmd.type == i.get("type") + and cmd.guild_ids is not None + # TODO: fix this type error (guild_id is not defined in ApplicationCommand Typed Dict) + and int(i["guild_id"]) in cmd.guild_ids, # type: ignore + self.pending_application_commands, + ) + if not cmd: + # command has not been added yet + continue + cmd.id = i["id"] + self._application_commands[cmd.id] = cmd + + async def process_application_commands( + self, interaction: Interaction, auto_sync: bool | None = None + ) -> None: + """|coro| + + This function processes the commands that have been registered + to the bot and other groups. Without this coroutine, none of the + commands will be triggered. + + By default, this coroutine is called inside the :func:`.on_interaction` + event. If you choose to override the :func:`.on_interaction` event, then + you should invoke this coroutine as well. + + This function finds a registered command matching the interaction id from + application commands and invokes it. If no matching command was + found, it replies to the interaction with a default message. + + .. versionadded:: 2.0 + + Parameters + ---------- + interaction: :class:`discord.Interaction` + The interaction to process + auto_sync: Optional[:class:`bool`] + Whether to automatically sync and unregister the command if it is not found in the internal cache. This will + invoke the :meth:`~.Bot.sync_commands` method on the context of the command, either globally or per-guild, + based on the type of the command, respectively. Defaults to :attr:`.Bot.auto_sync_commands`. + """ + if auto_sync is None: + auto_sync = self._bot.auto_sync_commands + # TODO: find out why the isinstance check below doesn't stop the type errors below + if interaction.type not in ( + InteractionType.application_command, + InteractionType.auto_complete, + ): + return + + command: ApplicationCommand | None = None + try: + if interaction.data: + command = self._application_commands[interaction.data["id"]] # type: ignore + except KeyError: + for cmd in self.application_commands + self.pending_application_commands: + if interaction.data: + guild_id = interaction.data.get("guild_id") + if guild_id: + guild_id = int(guild_id) + if cmd.name == interaction.data["name"] and ( # type: ignore + guild_id == cmd.guild_ids + or ( + isinstance(cmd.guild_ids, list) + and guild_id in cmd.guild_ids + ) + ): + command = cmd + break + else: + if auto_sync and interaction.data: + guild_id = interaction.data.get("guild_id") + if guild_id is None: + await self.sync_commands() + else: + + await self.sync_commands(check_guilds=[guild_id]) + return self._bot.dispatch("unknown_application_command", interaction) + + if interaction.type is InteractionType.auto_complete: + return self._bot.dispatch( + "application_command_auto_complete", interaction, command + ) + + ctx = await self.get_application_context(interaction) + if command: + ctx.command = command + await self.invoke_application_command(ctx) + + async def on_application_command_auto_complete( + self, interaction: Interaction, command: ApplicationCommand + ) -> None: + async def callback() -> None: + ctx = await self.get_autocomplete_context(interaction) + ctx.command = command + return await command.invoke_autocomplete_callback(ctx) + + autocomplete_task = self._bot.loop.create_task(callback()) + try: + await self._bot.wait_for( + "application_command_auto_complete", + check=lambda i, c: c == command, + timeout=3, + ) + except asyncio.TimeoutError: + return + else: + if not autocomplete_task.done(): + autocomplete_task.cancel() + + def slash_command(self, **kwargs): + """A shortcut decorator that invokes :func:`command` and adds it to + the internal command list via :meth:`add_application_command`. + This shortcut is made specifically for :class:`.SlashCommand`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`SlashCommand`] + A decorator that converts the provided method into a :class:`.SlashCommand`, adds it to the bot, + then returns it. + """ + return self.application_command(cls=SlashCommand, **kwargs) + + def user_command(self, **kwargs): + """A shortcut decorator that invokes :func:`command` and adds it to + the internal command list via :meth:`add_application_command`. + This shortcut is made specifically for :class:`.UserCommand`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`UserCommand`] + A decorator that converts the provided method into a :class:`.UserCommand`, adds it to the bot, + then returns it. + """ + return self.application_command(cls=UserCommand, **kwargs) + + def message_command(self, **kwargs): + """A shortcut decorator that invokes :func:`command` and adds it to + the internal command list via :meth:`add_application_command`. + This shortcut is made specifically for :class:`.MessageCommand`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`MessageCommand`] + A decorator that converts the provided method into a :class:`.MessageCommand`, adds it to the bot, + then returns it. + """ + return self.application_command(cls=MessageCommand, **kwargs) + + def application_command(self, **kwargs): + """A shortcut decorator that invokes :func:`command` and adds it to + the internal command list via :meth:`~.Bot.add_application_command`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`ApplicationCommand`] + A decorator that converts the provided method into an :class:`.ApplicationCommand`, adds it to the bot, + then returns it. + """ + + def decorator(func) -> ApplicationCommand: + result = command(**kwargs)(func) + self.add_application_command(result) + return result + + return decorator + + def command(self, **kwargs): + """An alias for :meth:`application_command`. + + .. note:: + + This decorator is overridden by :class:`discord.ext.commands.Bot`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`ApplicationCommand`] + A decorator that converts the provided method into an :class:`.ApplicationCommand`, adds it to the bot, + then returns it. + """ + return self.application_command(**kwargs) + + def create_group( + self, + name: str, + description: str | None = None, + guild_ids: list[int] | None = None, + **kwargs, + ) -> SlashCommandGroup: + """A shortcut method that creates a slash command group with no subcommands and adds it to the internal + command list via :meth:`add_application_command`. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the group to create. + description: Optional[:class:`str`] + The description of the group to create. + guild_ids: Optional[List[:class:`int`]] + A list of the IDs of each guild this group should be added to, making it a guild command. + This will be a global command if ``None`` is passed. + kwargs: + Any additional keyword arguments to pass to :class:`.SlashCommandGroup`. + + Returns + ------- + SlashCommandGroup + The slash command group that was created. + """ + description = description or "No description provided." + group = SlashCommandGroup(name, description, guild_ids, **kwargs) + self.add_application_command(group) + return group + + def group( + self, + name: str | None = None, + description: str | None = None, + guild_ids: list[int] | None = None, + ) -> Callable[[type[SlashCommandGroup]], SlashCommandGroup]: + """A shortcut decorator that initializes the provided subclass of :class:`.SlashCommandGroup` + and adds it to the internal command list via :meth:`add_application_command`. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: Optional[:class:`str`] + The name of the group to create. This will resolve to the name of the decorated class if ``None`` is passed. + description: Optional[:class:`str`] + The description of the group to create. + guild_ids: Optional[List[:class:`int`]] + A list of the IDs of each guild this group should be added to, making it a guild command. + This will be a global command if ``None`` is passed. + + Returns + ------- + Callable[[Type[SlashCommandGroup]], SlashCommandGroup] + The slash command group that was created. + """ + + def inner(cls: type[SlashCommandGroup]) -> SlashCommandGroup: + group = cls( + name or cls.__name__, + ( + description or inspect.cleandoc(cls.__doc__).splitlines()[0] + if cls.__doc__ is not None + else "No description provided" + ), + guild_ids=guild_ids, + ) + self.add_application_command(group) + return group + + return inner + + slash_group = group + + def walk_application_commands(self) -> Generator[ApplicationCommand, None, None]: + """An iterator that recursively walks through all application commands and subcommands. + + Yields + ------ + :class:`.ApplicationCommand` + An application command from the internal list of application commands. + """ + for command in self.application_commands: + if isinstance(command, SlashCommandGroup): + yield from command.walk_commands() + yield command + + async def get_application_context( + self, interaction: Interaction, cls: Any = ApplicationContext + ) -> ApplicationContext: + r"""|coro| + + Returns the invocation context from the interaction. + + This is a more low-level counter-part for :meth:`.process_application_commands` + to allow users more fine-grained control over the processing. + + Parameters + ----------- + interaction: :class:`discord.Interaction` + The interaction to get the invocation context from. + cls + The factory class that will be used to create the context. + By default, this is :class:`.ApplicationContext`. Should a custom + class be provided, it must be similar enough to + :class:`.ApplicationContext`\'s interface. + + Returns + -------- + :class:`.ApplicationContext` + The invocation context. The type of this can change via the + ``cls`` parameter. + """ + return cls(self, interaction) + + async def get_autocomplete_context( + self, interaction: Interaction, cls: Any = AutocompleteContext + ) -> AutocompleteContext: + r"""|coro| + + Returns the autocomplete context from the interaction. + + This is a more low-level counter-part for :meth:`.process_application_commands` + to allow users more fine-grained control over the processing. + + Parameters + ----------- + interaction: :class:`discord.Interaction` + The interaction to get the invocation context from. + cls + The factory class that will be used to create the context. + By default, this is :class:`.AutocompleteContext`. Should a custom + class be provided, it must be similar enough to + :class:`.AutocompleteContext`\'s interface. + + Returns + -------- + :class:`.AutocompleteContext` + The autocomplete context. The type of this can change via the + ``cls`` parameter. + """ + return cls(self, interaction) + + async def invoke_application_command(self, ctx: ApplicationContext) -> None: + """|coro| + + Invokes the application command given under the invocation + context and handles all the internal event dispatch mechanisms. + + Parameters + ---------- + ctx: :class:`.ApplicationCommand` + The invocation context to invoke. + """ + self._bot.dispatch("application_command", ctx) + try: + if await self._bot.can_run(ctx, call_once=True): + await ctx.command.invoke(ctx) + else: + raise CheckFailure("The global check once functions failed.") + except DiscordException as exc: + await ctx.command.dispatch_error(ctx, exc) + else: + self._bot.dispatch("application_command_completion", ctx) + + @property + @abstractmethod + def _bot(self) -> Bot | AutoShardedBot: + ... + + +class BotBase(ApplicationCommandMixin, CogMixin, ABC): + _supports_prefixed_commands = False + + def __init__(self, description=None, *args, **options): + super().__init__(*args, **options) + self.extra_events = {} # TYPE: Dict[str, List[CoroFunc]] + self.__cogs = {} # TYPE: Dict[str, Cog] + self.__extensions = {} # TYPE: Dict[str, types.ModuleType] + self._checks = [] # TYPE: List[Check] + self._check_once = [] + self._before_invoke = None + self._after_invoke = None + self.description = inspect.cleandoc(description) if description else "" + self.owner_id = options.get("owner_id") + self.owner_ids = options.get("owner_ids", set()) + self.auto_sync_commands = options.get("auto_sync_commands", True) + + self.debug_guilds = options.pop("debug_guilds", None) + + if self.owner_id and self.owner_ids: + raise TypeError("Both owner_id and owner_ids are set.") + + if self.owner_ids and not isinstance( + self.owner_ids, collections.abc.Collection + ): + raise TypeError( + f"owner_ids must be a collection not {self.owner_ids.__class__!r}" + ) + + self._checks = [] + self._check_once = [] + self._before_invoke = None + self._after_invoke = None + + async def on_connect(self): + if self.auto_sync_commands: + await self.sync_commands() + + async def on_interaction(self, interaction): + await self.process_application_commands(interaction) + + async def on_application_command_error( + self, context: ApplicationContext, exception: DiscordException + ) -> None: + """|coro| + + The default command error handler provided by the bot. + + By default, this prints to :data:`sys.stderr` however it could be + overridden to have a different implementation. + + This only fires if you do not specify any listeners for command error. + """ + if self.extra_events.get("on_application_command_error", None): + return + + command = context.command + if command and command.has_error_handler(): + return + + cog = context.cog + if cog and cog.has_error_handler(): + return + + print(f"Ignoring exception in command {context.command}:", file=sys.stderr) + traceback.print_exception( + type(exception), exception, exception.__traceback__, file=sys.stderr + ) + + # global check registration + # TODO: Remove these from commands.Bot + + def check(self, func): + """A decorator that adds a global check to the bot. A global check is similar to a :func:`.check` that is + applied on a per-command basis except it is run before any command checks have been verified and applies to + every command the bot has. + + .. note:: + + This function can either be a regular function or a coroutine. Similar to a command :func:`.check`, this + takes a single parameter of type :class:`.Context` and can only raise exceptions inherited from + :exc:`.ApplicationCommandError`. + + Example + ------- + .. code-block:: python3 + + @bot.check + def check_commands(ctx): + return ctx.command.qualified_name in allowed_commands + """ + # T was used instead of Check to ensure the type matches on return + self.add_check(func) # type: ignore + return func + + def add_check(self, func, *, call_once: bool = False) -> None: + """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` and + :meth:`.check_once`. + + Parameters + ---------- + func + The function that was used as a global check. + call_once: :class:`bool` + If the function should only be called once per :meth:`.Bot.invoke` call. + """ + + if call_once: + self._check_once.append(func) + else: + self._checks.append(func) + + def remove_check(self, func, *, call_once: bool = False) -> None: + """Removes a global check from the bot. + This function is idempotent and will not raise an exception + if the function is not in the global checks. + + Parameters + ---------- + func + The function to remove from the global checks. + call_once: :class:`bool` + If the function was added with ``call_once=True`` in + the :meth:`.Bot.add_check` call or using :meth:`.check_once`. + """ + checks = self._check_once if call_once else self._checks + + try: + checks.remove(func) + except ValueError: + pass + + def check_once(self, func): + """A decorator that adds a "call once" global check to the bot. Unlike regular global checks, this one is called + only once per :meth:`.Bot.invoke` call. Regular global checks are called whenever a command is called or + :meth:`.Command.can_run` is called. This type of check bypasses that and ensures that it's called only once, + even inside the default help command. + + .. note:: + + When using this function the :class:`.Context` sent to a group subcommand may only parse the parent command + and not the subcommands due to it being invoked once per :meth:`.Bot.invoke` call. + + .. note:: + + This function can either be a regular function or a coroutine. Similar to a command :func:`.check`, + this takes a single parameter of type :class:`.Context` and can only raise exceptions inherited from + :exc:`.ApplicationCommandError`. + + Example + ------- + .. code-block:: python3 + + @bot.check_once + def whitelist(ctx): + return ctx.message.author.id in my_whitelist + """ + self.add_check(func, call_once=True) + return func + + async def can_run( + self, ctx: ApplicationContext, *, call_once: bool = False + ) -> bool: + data = self._check_once if call_once else self._checks + + if not data: + return True + + # type-checker doesn't distinguish between functions and methods + return await async_all(f(ctx) for f in data) # type: ignore + + # listener registration + + def add_listener(self, func: CoroFunc, name: str = MISSING) -> None: + """The non decorator alternative to :meth:`.listen`. + + Parameters + ---------- + func: :ref:`coroutine ` + The function to call. + name: :class:`str` + The name of the event to listen for. Defaults to ``func.__name__``. + + Example + ------- + + .. code-block:: python3 + + async def on_ready(): pass + async def my_message(message): pass + + bot.add_listener(on_ready) + bot.add_listener(my_message, 'on_message') + """ + name = func.__name__ if name is MISSING else name + + if not asyncio.iscoroutinefunction(func): + raise TypeError("Listeners must be coroutines") + + if name in self.extra_events: + self.extra_events[name].append(func) + else: + self.extra_events[name] = [func] + + def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None: + """Removes a listener from the pool of listeners. + + Parameters + ---------- + func + The function that was used as a listener to remove. + name: :class:`str` + The name of the event we want to remove. Defaults to + ``func.__name__``. + """ + + name = func.__name__ if name is MISSING else name + + if name in self.extra_events: + try: + self.extra_events[name].remove(func) + except ValueError: + pass + + def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]: + """A decorator that registers another function as an external + event listener. Basically this allows you to listen to multiple + events from different places e.g. such as :func:`.on_ready` + + The functions being listened to must be a :ref:`coroutine `. + + Raises + ------ + TypeError + The function being listened to is not a coroutine. + + Example + ------- + + .. code-block:: python3 + + @bot.listen() + async def on_message(message): + print('one') + + # in some other file... + + @bot.listen('on_message') + async def my_message(message): + print('two') + + Would print one and two in an unspecified order. + """ + + def decorator(func: CFT) -> CFT: + self.add_listener(func, name) + return func + + return decorator + + def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: + # super() will resolve to Client + super().dispatch(event_name, *args, **kwargs) # type: ignore + ev = f"on_{event_name}" + for event in self.extra_events.get(ev, []): + self._schedule_event(event, ev, *args, **kwargs) # type: ignore + + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + .. note:: + + The :meth:`~.Bot.before_invoke` and :meth:`~.Bot.after_invoke` hooks are + only called if all checks and argument parsing procedures pass + without error. If any check or argument parsing procedures fail + then the hooks are not called. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + r"""A decorator that registers a coroutine as a post-invoke hook. + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + .. note:: + + Similar to :meth:`~.Bot.before_invoke`\, this is not called unless + checks and argument parsing procedures succeed. This hook is, + however, **always** called regardless of the internal command + callback raising an error (i.e. :exc:`.CommandInvokeError`\). + This makes it ideal for clean-up scenarios. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + async def is_owner(self, user: User) -> bool: + """|coro| + + Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of + this bot. + + If an :attr:`owner_id` is not set, it is fetched automatically + through the use of :meth:`~.Bot.application_info`. + + .. versionchanged:: 1.3 + The function also checks if the application is team-owned if + :attr:`owner_ids` is not set. + + Parameters + ---------- + user: :class:`.abc.User` + The user to check for. + + Returns + ------- + :class:`bool` + Whether the user is the owner. + """ + + if self.owner_id: + return user.id == self.owner_id + elif self.owner_ids: + return user.id in self.owner_ids + else: + app = await self.application_info() # type: ignore + if app.team: + self.owner_ids = ids = {m.id for m in app.team.members} + return user.id in ids + else: + self.owner_id = owner_id = app.owner.id + return user.id == owner_id + + +class Bot(BotBase, Client): + """Represents a discord bot. + + This class is a subclass of :class:`discord.Client` and as a result + anything that you can do with a :class:`discord.Client` you can do with + this bot. + + This class also subclasses ``ApplicationCommandMixin`` to provide the functionality + to manage commands. + + .. versionadded:: 2.0 + + Attributes + ---------- + description: :class:`str` + The content prefixed into the default help message. + owner_id: Optional[:class:`int`] + The user ID that owns the bot. If this is not set and is then queried via + :meth:`.is_owner` then it is fetched automatically using + :meth:`~.Bot.application_info`. + owner_ids: Optional[Collection[:class:`int`]] + The user IDs that owns the bot. This is similar to :attr:`owner_id`. + If this is not set and the application is team based, then it is + fetched automatically using :meth:`~.Bot.application_info`. + For performance reasons it is recommended to use a :class:`set` + for the collection. You cannot set both ``owner_id`` and ``owner_ids``. + + .. versionadded:: 1.3 + debug_guilds: Optional[List[:class:`int`]] + Guild IDs of guilds to use for testing commands. + The bot will not create any global commands if debug guild IDs are passed. + + .. versionadded:: 2.0 + auto_sync_commands: :class:`bool` + Whether to automatically sync slash commands. This will call sync_commands in on_connect, and in + :attr:`.process_application_commands` if the command is not found. Defaults to ``True``. + + .. versionadded:: 2.0 + """ + + @property + def _bot(self) -> Bot: + return self + + +class AutoShardedBot(BotBase, AutoShardedClient): + """This is similar to :class:`.Bot` except that it is inherited from + :class:`discord.AutoShardedClient` instead. + + .. versionadded:: 2.0 + """ + + @property + def _bot(self) -> AutoShardedBot: + return self diff --git a/discord/channel.py b/discord/channel.py new file mode 100644 index 0000000..73ca1fe --- /dev/null +++ b/discord/channel.py @@ -0,0 +1,2699 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload + +import discord.abc + +from . import utils +from .asset import Asset +from .enums import ( + ChannelType, + EmbeddedActivity, + InviteTarget, + StagePrivacyLevel, + VideoQualityMode, + VoiceRegion, + try_enum, +) +from .errors import ClientException, InvalidArgument +from .file import File +from .flags import ChannelFlags +from .invite import Invite +from .iterators import ArchivedThreadIterator +from .mixins import Hashable +from .object import Object +from .permissions import PermissionOverwrite, Permissions +from .stage_instance import StageInstance +from .threads import Thread +from .utils import MISSING + +__all__ = ( + "TextChannel", + "VoiceChannel", + "StageChannel", + "DMChannel", + "CategoryChannel", + "GroupChannel", + "PartialMessageable", + "ForumChannel", +) + +if TYPE_CHECKING: + from .abc import Snowflake, SnowflakeTime + from .guild import Guild + from .guild import GuildChannel as GuildChannelType + from .member import Member, VoiceState + from .message import Message, PartialMessage + from .role import Role + from .state import ConnectionState + from .types.channel import CategoryChannel as CategoryChannelPayload + from .types.channel import DMChannel as DMChannelPayload + from .types.channel import ForumChannel as ForumChannelPayload + from .types.channel import GroupDMChannel as GroupChannelPayload + from .types.channel import StageChannel as StageChannelPayload + from .types.channel import TextChannel as TextChannelPayload + from .types.channel import VoiceChannel as VoiceChannelPayload + from .types.snowflake import SnowflakeList + from .types.threads import ThreadArchiveDuration + from .user import BaseUser, ClientUser, User + from .webhook import Webhook + + +class _TextChannel(discord.abc.GuildChannel, Hashable): + __slots__ = ( + "name", + "id", + "guild", + "topic", + "_state", + "nsfw", + "category_id", + "position", + "slowmode_delay", + "_overwrites", + "_type", + "last_message_id", + "default_auto_archive_duration", + "flags", + ) + + def __init__( + self, + *, + state: ConnectionState, + guild: Guild, + data: TextChannelPayload | ForumChannelPayload, + ): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self._update(guild, data) + + @property + def _repr_attrs(self) -> tuple[str, ...]: + return "id", "name", "position", "nsfw", "category_id" + + def __repr__(self) -> str: + attrs = [(val, getattr(self, val)) for val in self._repr_attrs] + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" + + def _update( + self, guild: Guild, data: TextChannelPayload | ForumChannelPayload + ) -> None: + # This data will always exist + self.guild: Guild = guild + self.name: str = data["name"] + self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self._type: int = data["type"] + + # This data may be missing depending on how this object is being created/updated + if not data.pop("_invoke_flag", False): + self.topic: str | None = data.get("topic") + self.position: int = data.get("position") + self.nsfw: bool = data.get("nsfw", False) + # Does this need coercion into `int`? No idea yet. + self.slowmode_delay: int = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get( + "default_auto_archive_duration", 1440 + ) + self.last_message_id: int | None = utils._get_as_snowflake( + data, "last_message_id" + ) + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) + self._fill_overwrites(data) + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return try_enum(ChannelType, self._type) + + @property + def _sorting_bucket(self) -> int: + return ChannelType.text.value + + @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + def permissions_for(self, obj: Member | Role, /) -> Permissions: + base = super().permissions_for(obj) + + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + @property + def members(self) -> list[Member]: + """List[:class:`Member`]: Returns all members that can see this channel.""" + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + + @property + def threads(self) -> list[Thread]: + """List[:class:`Thread`]: Returns all the threads that you can see. + + .. versionadded:: 2.0 + """ + return [ + thread + for thread in self.guild._threads.values() + if thread.parent_id == self.id + ] + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the channel is NSFW.""" + return self.nsfw + + @property + def last_message(self) -> Message | None: + """Fetches the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + ------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return ( + self._state._get_message(self.last_message_id) + if self.last_message_id + else None + ) + + @overload + async def edit( + self, + *, + reason: str | None = ..., + name: str = ..., + topic: str = ..., + position: int = ..., + nsfw: bool = ..., + sync_permissions: bool = ..., + category: CategoryChannel | None = ..., + slowmode_delay: int = ..., + default_auto_archive_duration: ThreadArchiveDuration = ..., + type: ChannelType = ..., + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] = ..., + ) -> TextChannel | None: + ... + + @overload + async def edit(self) -> TextChannel | None: + ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 1.4 + The ``type`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + To mark the channel as NSFW or not. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + A value of `0` disables slowmode. The maximum value possible is `21600`. + type: :class:`ChannelType` + Change the type of this text channel. Currently, only conversion between + :attr:`ChannelType.text` and :attr:`ChannelType.news` is supported. This + is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + + Returns + ------- + Optional[:class:`.TextChannel`] + The newly edited text channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of channels, or if + the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: str | None = None, reason: str | None = None + ) -> TextChannel: + return await self._clone_impl( + { + "topic": self.topic, + "nsfw": self.nsfw, + "rate_limit_per_user": self.slowmode_delay, + }, + name=name, + reason=reason, + ) + + async def delete_messages( + self, messages: Iterable[Snowflake], *, reason: str | None = None + ) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this. + + Parameters + ---------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id: int = messages[0].id + await self._state.http.delete_message(self.id, message_id, reason=reason) + return + + if len(messages) > 100: + raise ClientException("Can only bulk delete messages up to 100 messages") + + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: int | None = 100, + check: Callable[[Message], bool] = MISSING, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = False, + bulk: bool = True, + reason: str | None = None, + ) -> list[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have the :attr:`~Permissions.manage_messages` permission to + delete messages even if they are your own. + The :attr:`~Permissions.read_message_history` permission is + also needed to retrieve message history. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + bulk: :class:`bool` + If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting + a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Returns + ------- + List[:class:`.Message`] + The list of messages that were deleted. + + Raises + ------ + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Examples + -------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f'Deleted {len(deleted)} message(s)') + """ + return await discord.abc._purge_messages_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> list[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Returns + ------- + List[:class:`Webhook`] + The webhooks for this channel. + + Raises + ------ + Forbidden + You don't have permissions to get the webhooks. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook( + self, *, name: str, avatar: bytes | None = None, reason: str | None = None + ) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionchanged:: 1.1 + Added the ``reason`` keyword-only parameter. + + Parameters + ---------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) # type: ignore + + data = await self._state.http.create_webhook( + self.id, name=str(name), avatar=avatar, reason=reason + ) + return Webhook.from_state(data, state=self._state) + + async def follow( + self, *, destination: TextChannel, reason: str | None = None + ) -> Webhook: + """ + Follows a channel using a webhook. + + Only news channels can be followed. + + .. note:: + + The webhook returned will not provide a token to do webhook + actions, as Discord does not provide it. + + .. versionadded:: 1.3 + + Parameters + ---------- + destination: :class:`TextChannel` + The channel you would like to follow from. + reason: Optional[:class:`str`] + The reason for following the channel. Shows up on the destination guild's audit log. + + .. versionadded:: 1.4 + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Following the channel failed. + Forbidden + You do not have the permissions to create a webhook. + """ + + if not self.is_news(): + raise ClientException("The channel must be a news channel.") + + if not isinstance(destination, TextChannel): + raise InvalidArgument( + f"Expected TextChannel received {destination.__class__.__name__}" + ) + + from .webhook import Webhook + + data = await self._state.http.follow_webhook( + self.id, webhook_channel_id=destination.id, reason=reason + ) + return Webhook._as_follower(data, channel=destination, user=self._state.user) + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + def get_thread(self, thread_id: int, /) -> Thread | None: + """Returns a thread with the given ID. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`Thread`] + The returned thread or ``None`` if not found. + """ + return self.guild.get_thread(thread_id) + + def archived_threads( + self, + *, + private: bool = False, + joined: bool = False, + limit: int | None = 50, + before: Snowflake | datetime.datetime | None = None, + ) -> ArchivedThreadIterator: + """Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild. + + You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads + then :attr:`~Permissions.manage_threads` is also required. + + .. versionadded:: 2.0 + + Parameters + ---------- + limit: Optional[:class:`bool`] + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve archived channels before the given date or ID. + private: :class:`bool` + Whether to retrieve private archived threads. + joined: :class:`bool` + Whether to retrieve private archived threads that you've joined. + You cannot set ``joined`` to ``True`` and ``private`` to ``False``. + + Yields + ------ + :class:`Thread` + The archived threads. + + Raises + ------ + Forbidden + You do not have permissions to get archived threads. + HTTPException + The request to get the archived threads failed. + """ + return ArchivedThreadIterator( + self.id, + self.guild, + limit=limit, + joined=joined, + private=private, + before=before, + ) + + +class TextChannel(discord.abc.Messageable, _TextChannel): + """Represents a Discord text channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + topic: Optional[:class:`str`] + The channel's topic. ``None`` if it doesn't exist. + position: Optional[:class:`int`] + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. Can be ``None`` if the channel was received in an interaction. + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + nsfw: :class:`bool` + If the channel is marked as "not safe for work". + + .. note:: + + To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + def __init__( + self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload + ): + super().__init__(state=state, guild=guild, data=data) + + @property + def _repr_attrs(self) -> tuple[str, ...]: + return super()._repr_attrs + ("news",) + + def _update(self, guild: Guild, data: TextChannelPayload) -> None: + super()._update(guild, data) + + async def _get_channel(self) -> "TextChannel": + return self + + def is_news(self) -> bool: + """:class:`bool`: Checks if the channel is a news/announcements channel.""" + return self._type == ChannelType.news.value + + @property + def news(self) -> bool: + """Equivalent to :meth:`is_news`.""" + return self.is_news() + + async def create_thread( + self, + *, + name: str, + message: Snowflake | None = None, + auto_archive_duration: ThreadArchiveDuration = MISSING, + type: ChannelType | None = None, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this text channel. + + To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`. + For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the thread. + message: Optional[:class:`abc.Snowflake`] + A snowflake representing the message to create the thread with. + If ``None`` is passed then a private thread is created. + Defaults to ``None``. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + type: Optional[:class:`ChannelType`] + The type of thread to create. If a ``message`` is passed then this parameter + is ignored, as a thread created with a message is always a public thread. + By default, this creates a private thread if this is ``None``. + reason: :class:`str` + The reason for creating a new thread. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created thread + + Raises + ------ + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. + """ + + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration + or self.default_auto_archive_duration, + type=type.value, + reason=reason, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration + or self.default_auto_archive_duration, + reason=reason, + ) + + return Thread(guild=self.guild, state=self._state, data=data) + + +class ForumChannel(_TextChannel): + """Represents a Discord forum channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + topic: Optional[:class:`str`] + The channel's topic. ``None`` if it doesn't exist. + + .. note:: + + :attr:`guidelines` exists as an alternative to this attribute. + position: Optional[:class:`int`] + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. Can be ``None`` if the channel was received in an interaction. + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + nsfw: :class:`bool` + If the channel is marked as "not safe for work". + + .. note:: + + To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + def __init__( + self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload + ): + super().__init__(state=state, guild=guild, data=data) + + def _update(self, guild: Guild, data: ForumChannelPayload) -> None: + super()._update(guild, data) + + @property + def guidelines(self) -> str | None: + """Optional[:class:`str`]: The channel's guidelines. An alias of :attr:`topic`.""" + return self.topic + + async def create_thread( + self, + name: str, + content=None, + *, + embed=None, + embeds=None, + file=None, + files=None, + stickers=None, + delete_message_after=None, + nonce=None, + allowed_mentions=None, + view=None, + auto_archive_duration: ThreadArchiveDuration = MISSING, + slowmode_delay: int = MISSING, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this forum channel. + + To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`. + For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the thread. + content: :class:`str` + The content of the message to send. + embed: :class:`~discord.Embed` + The rich embed for the content. + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. + delete_message_after: :class:`int` + The time to wait before deleting the thread. + nonce: :class:`int` + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + view: :class:`discord.ui.View` + A Discord UI View to add to the message. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in the new thread. A value of `0` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + If not provided, the forum channel's default slowmode is used. + reason: :class:`str` + The reason for creating a new thread. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created thread + + Raises + ------ + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. + """ + state = self._state + message_content = str(content) if content is not None else None + + if embed is not None and embeds is not None: + raise InvalidArgument( + "cannot pass both embed and embeds parameter to create_thread()" + ) + + if embed is not None: + embed = embed.to_dict() + + elif embeds is not None: + if len(embeds) > 10: + raise InvalidArgument( + "embeds parameter must be a list of up to 10 elements" + ) + embeds = [embed.to_dict() for embed in embeds] + + if stickers is not None: + stickers = [sticker.id for sticker in stickers] + + if allowed_mentions is None: + allowed_mentions = ( + state.allowed_mentions and state.allowed_mentions.to_dict() + ) + elif state.allowed_mentions is not None: + allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + else: + allowed_mentions = allowed_mentions.to_dict() + + if view: + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument( + f"view parameter must be View not {view.__class__!r}" + ) + + components = view.to_components() + else: + components = None + + if file is not None and files is not None: + raise InvalidArgument("cannot pass both file and files parameter to send()") + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + + try: + data = await state.http.send_files( + self.id, + files=[file], + allowed_mentions=allowed_mentions, + content=message_content, + embed=embed, + embeds=embeds, + nonce=nonce, + stickers=stickers, + components=components, + ) + finally: + file.close() + + elif files is not None: + if len(files) > 10: + raise InvalidArgument( + "files parameter must be a list of up to 10 elements" + ) + elif not all(isinstance(file, File) for file in files): + raise InvalidArgument("files parameter must be a list of File") + + try: + data = await state.http.send_files( + self.id, + files=files, + content=message_content, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + stickers=stickers, + components=components, + ) + finally: + for f in files: + f.close() + else: + data = await state.http.start_forum_thread( + self.id, + content=message_content, + name=name, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + stickers=stickers, + components=components, + auto_archive_duration=auto_archive_duration + or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or self.slowmode_delay, + reason=reason, + ) + ret = Thread(guild=self.guild, state=self._state, data=data) + msg = ret.get_partial_message(data["last_message_id"]) + if view: + state.store_view(view, msg.id) + + if delete_message_after is not None: + await msg.delete(delay=delete_message_after) + return ret + + +class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): + __slots__ = ( + "name", + "id", + "guild", + "bitrate", + "user_limit", + "_state", + "position", + "_overwrites", + "category_id", + "rtc_region", + "video_quality_mode", + "last_message_id", + "flags", + ) + + def __init__( + self, + *, + state: ConnectionState, + guild: Guild, + data: VoiceChannelPayload | StageChannelPayload, + ): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self._update(guild, data) + + def _get_voice_client_key(self) -> tuple[int, str]: + return self.guild.id, "guild_id" + + def _get_voice_state_pair(self) -> tuple[int, int]: + return self.guild.id, self.id + + def _update( + self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload + ) -> None: + # This data will always exist + self.guild = guild + self.name: str = data["name"] + self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + + # This data may be missing depending on how this object is being created/updated + if not data.pop("_invoke_flag", False): + rtc = data.get("rtc_region") + self.rtc_region: VoiceRegion | None = ( + try_enum(VoiceRegion, rtc) if rtc is not None else None + ) + self.video_quality_mode: VideoQualityMode = try_enum( + VideoQualityMode, data.get("video_quality_mode", 1) + ) + self.last_message_id: int | None = utils._get_as_snowflake( + data, "last_message_id" + ) + self.position: int = data.get("position") + self.bitrate: int = data.get("bitrate") + self.user_limit: int = data.get("user_limit") + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) + self._fill_overwrites(data) + + @property + def _sorting_bucket(self) -> int: + return ChannelType.voice.value + + @property + def members(self) -> list[Member]: + """List[:class:`Member`]: Returns all members that are currently inside this voice channel.""" + ret = [] + for user_id, state in self.guild._voice_states.items(): + if state.channel and state.channel.id == self.id: + member = self.guild.get_member(user_id) + if member is not None: + ret.append(member) + return ret + + @property + def voice_states(self) -> dict[int, VoiceState]: + """Returns a mapping of member IDs who have voice states in this channel. + + .. versionadded:: 1.3 + + .. note:: + + This function is intentionally low level to replace :attr:`members` + when the member cache is unavailable. + + Returns + ------- + Mapping[:class:`int`, :class:`VoiceState`] + The mapping of member ID to a voice state. + """ + return { + key: value + for key, value in self.guild._voice_states.items() + if value.channel and value.channel.id == self.id + } + + @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + def permissions_for(self, obj: Member | Role, /) -> Permissions: + base = super().permissions_for(obj) + + # Voice channels cannot be edited by people who can't connect to them. + # It also implicitly denies all other voice perms + if not base.connect: + denied = Permissions.voice() + denied.update(manage_channels=True, manage_roles=True) + base.value &= ~denied.value + return base + + +class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): + """Represents a Discord guild voice channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + position: Optional[:class:`int`] + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. Can be ``None`` if the channel was received in an interaction. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a voice channel. + rtc_region: Optional[:class:`VoiceRegion`] + The region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + + .. versionadded:: 1.7 + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + + .. versionadded:: 2.0 + last_message_id: Optional[:class:`int`] + The ID of the last message sent to this channel. It may not always point to an existing or valid message. + + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + __slots__ = "nsfw" + + def _update(self, guild: Guild, data: VoiceChannelPayload): + super()._update(guild, data) + self.nsfw: bool = data.get("nsfw", False) + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" + + async def _get_channel(self): + return self + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the channel is NSFW.""" + return self.nsfw + + @property + def last_message(self) -> Message | None: + """Fetches the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + ------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return ( + self._state._get_message(self.last_message_id) + if self.last_message_id + else None + ) + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + async def delete_messages( + self, messages: Iterable[Snowflake], *, reason: str | None = None + ) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this. + + Parameters + ---------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id: int = messages[0].id + await self._state.http.delete_message(self.id, message_id, reason=reason) + return + + if len(messages) > 100: + raise ClientException("Can only bulk delete messages up to 100 messages") + + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: int | None = 100, + check: Callable[[Message], bool] = MISSING, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = False, + bulk: bool = True, + reason: str | None = None, + ) -> list[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have the :attr:`~Permissions.manage_messages` permission to + delete messages even if they are your own. + The :attr:`~Permissions.read_message_history` permission is + also needed to retrieve message history. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + bulk: :class:`bool` + If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting + a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Returns + ------- + List[:class:`.Message`] + The list of messages that were deleted. + + Raises + ------ + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Examples + -------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f'Deleted {len(deleted)} message(s)') + """ + return await discord.abc._purge_messages_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> list[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Returns + ------- + List[:class:`Webhook`] + The webhooks for this channel. + + Raises + ------ + Forbidden + You don't have permissions to get the webhooks. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook( + self, *, name: str, avatar: bytes | None = None, reason: str | None = None + ) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionchanged:: 1.1 + Added the ``reason`` keyword-only parameter. + + Parameters + ---------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) # type: ignore + + data = await self._state.http.create_webhook( + self.id, name=str(name), avatar=avatar, reason=reason + ) + return Webhook.from_state(data, state=self._state) + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.voice + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: str | None = None, reason: str | None = None + ) -> VoiceChannel: + return await self._clone_impl( + {"bitrate": self.bitrate, "user_limit": self.user_limit}, + name=name, + reason=reason, + ) + + @overload + async def edit( + self, + *, + name: str = ..., + bitrate: int = ..., + user_limit: int = ..., + position: int = ..., + sync_permissions: int = ..., + category: CategoryChannel | None = ..., + overwrites: Mapping[Role | Member, PermissionOverwrite] = ..., + rtc_region: VoiceRegion | None = ..., + video_quality_mode: VideoQualityMode = ..., + reason: str | None = ..., + ) -> VoiceChannel | None: + ... + + @overload + async def edit(self) -> VoiceChannel | None: + ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + bitrate: :class:`int` + The new channel's bitrate. + user_limit: :class:`int` + The new channel's user limit. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + rtc_region: Optional[:class:`VoiceRegion`] + The new region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + + .. versionadded:: 1.7 + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + + .. versionadded:: 2.0 + + Returns + ------- + Optional[:class:`.VoiceChannel`] + The newly edited voice channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + async def create_activity_invite( + self, activity: EmbeddedActivity | int, **kwargs + ) -> Invite: + """|coro| + + A shortcut method that creates an instant activity invite. + + You must have the :attr:`~discord.Permissions.start_embedded_activities` permission to + do this. + + Parameters + ---------- + activity: Union[:class:`discord.EmbeddedActivity`, :class:`int`] + The activity to create an invite for which can be an application id as well. + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite + doesn't expire. Defaults to ``0``. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there + are unlimited uses. Defaults to ``0``. + temporary: :class:`bool` + Denotes that the invite grants temporary membership + (i.e. they get kicked after they disconnect). Defaults to ``False``. + unique: :class:`bool` + Indicates if a unique invite URL should be created. Defaults to True. + If this is set to ``False`` then it will return a previously created + invite. + reason: Optional[:class:`str`] + The reason for creating this invite. Shows up on the audit log. + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + TypeError + If the activity is not a valid activity or application id. + ~discord.HTTPException + Invite creation failed. + """ + + if isinstance(activity, EmbeddedActivity): + activity = activity.value + + elif not isinstance(activity, int): + raise TypeError("Invalid type provided for the activity.") + + return await self.create_invite( + target_type=InviteTarget.embedded_application, + target_application_id=activity, + **kwargs, + ) + + +class StageChannel(VocalGuildChannel): + """Represents a Discord guild stage channel. + + .. versionadded:: 1.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + topic: Optional[:class:`str`] + The channel's topic. ``None`` if it isn't set. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + position: Optional[:class:`int`] + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. Can be ``None`` if the channel was received in an interaction. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a stage channel. + rtc_region: Optional[:class:`VoiceRegion`] + The region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + + .. versionadded:: 2.0 + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + __slots__ = ("topic",) + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("topic", self.topic), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" + + def _update(self, guild: Guild, data: StageChannelPayload) -> None: + super()._update(guild, data) + self.topic = data.get("topic") + + @property + def requesting_to_speak(self) -> list[Member]: + """List[:class:`Member`]: A list of members who are requesting to speak in the stage channel.""" + return [ + member + for member in self.members + if member.voice and member.voice.requested_to_speak_at is not None + ] + + @property + def speakers(self) -> list[Member]: + """List[:class:`Member`]: A list of members who have been permitted to speak in the stage channel. + + .. versionadded:: 2.0 + """ + return [ + member + for member in self.members + if member.voice + and not member.voice.suppress + and member.voice.requested_to_speak_at is None + ] + + @property + def listeners(self) -> list[Member]: + """List[:class:`Member`]: A list of members who are listening in the stage channel. + + .. versionadded:: 2.0 + """ + return [ + member for member in self.members if member.voice and member.voice.suppress + ] + + @property + def moderators(self) -> list[Member]: + """List[:class:`Member`]: A list of members who are moderating the stage channel. + + .. versionadded:: 2.0 + """ + required_permissions = Permissions.stage_moderator() + return [ + member + for member in self.members + if self.permissions_for(member) >= required_permissions + ] + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.stage_voice + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: str | None = None, reason: str | None = None + ) -> StageChannel: + return await self._clone_impl({}, name=name, reason=reason) + + @property + def instance(self) -> StageInstance | None: + """Optional[:class:`StageInstance`]: The running stage instance of the stage channel. + + .. versionadded:: 2.0 + """ + return utils.get(self.guild.stage_instances, channel_id=self.id) + + async def create_instance( + self, + *, + topic: str, + privacy_level: StagePrivacyLevel = MISSING, + reason: str | None = None, + send_notification: bool | None = False, + ) -> StageInstance: + """|coro| + + Create a stage instance. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionadded:: 2.0 + + Parameters + ---------- + topic: :class:`str` + The stage instance's topic. + privacy_level: :class:`StagePrivacyLevel` + The stage instance's privacy level. Defaults to :attr:`StagePrivacyLevel.guild_only`. + reason: :class:`str` + The reason the stage instance was created. Shows up on the audit log. + send_notification: :class:`bool` + Send a notification to everyone in the server that the stage instance has started. + Defaults to ``False``. Requires the ``mention_everyone`` permission. + + Returns + ------- + :class:`StageInstance` + The newly created stage instance. + + Raises + ------ + InvalidArgument + If the ``privacy_level`` parameter is not the proper type. + Forbidden + You do not have permissions to create a stage instance. + HTTPException + Creating a stage instance failed. + """ + + payload: dict[str, Any] = { + "channel_id": self.id, + "topic": topic, + "send_start_notification": send_notification, + } + + if privacy_level is not MISSING: + if not isinstance(privacy_level, StagePrivacyLevel): + raise InvalidArgument( + "privacy_level field must be of type PrivacyLevel" + ) + + payload["privacy_level"] = privacy_level.value + + data = await self._state.http.create_stage_instance(**payload, reason=reason) + return StageInstance(guild=self.guild, state=self._state, data=data) + + async def fetch_instance(self) -> StageInstance: + """|coro| + + Gets the running :class:`StageInstance`. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`StageInstance` + The stage instance. + + Raises + ------ + NotFound + The stage instance or channel could not be found. + HTTPException + Getting the stage instance failed. + """ + data = await self._state.http.get_stage_instance(self.id) + return StageInstance(guild=self.guild, state=self._state, data=data) + + @overload + async def edit( + self, + *, + name: str = ..., + topic: str | None = ..., + position: int = ..., + sync_permissions: int = ..., + category: CategoryChannel | None = ..., + overwrites: Mapping[Role | Member, PermissionOverwrite] = ..., + rtc_region: VoiceRegion | None = ..., + video_quality_mode: VideoQualityMode = ..., + reason: str | None = ..., + ) -> StageChannel | None: + ... + + @overload + async def edit(self) -> StageChannel | None: + ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 2.0 + The ``topic`` parameter must now be set via :attr:`create_instance`. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + rtc_region: Optional[:class:`VoiceRegion`] + The new region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + + .. versionadded:: 2.0 + + Returns + ------- + Optional[:class:`.StageChannel`] + The newly edited stage channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + +class CategoryChannel(discord.abc.GuildChannel, Hashable): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ---------- + name: :class:`str` + The category name. + guild: :class:`Guild` + The guild the category belongs to. + id: :class:`int` + The category channel ID. + position: Optional[:class:`int`] + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. Can be ``None`` if the channel was received in an interaction. + nsfw: :class:`bool` + If the channel is marked as "not safe for work". + + .. note:: + + To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. + flags: :class:`ChannelFlags` + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + __slots__ = ( + "name", + "id", + "guild", + "nsfw", + "_state", + "position", + "_overwrites", + "category_id", + "flags", + ) + + def __init__( + self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload + ): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self._update(guild, data) + + def __repr__(self) -> str: + return f"" + + def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: + # This data will always exist + self.guild: Guild = guild + self.name: str = data["name"] + self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + + # This data may be missing depending on how this object is being created/updated + if not data.pop("_invoke_flag", False): + self.nsfw: bool = data.get("nsfw", False) + self.position: int = data.get("position") + self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) + self._fill_overwrites(data) + + @property + def _sorting_bucket(self) -> int: + return ChannelType.category.value + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.category + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the category is NSFW.""" + return self.nsfw + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: str | None = None, reason: str | None = None + ) -> CategoryChannel: + return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) + + @overload + async def edit( + self, + *, + name: str = ..., + position: int = ..., + nsfw: bool = ..., + overwrites: Mapping[Role | Member, PermissionOverwrite] = ..., + reason: str | None = ..., + ) -> CategoryChannel | None: + ... + + @overload + async def edit(self) -> CategoryChannel | None: + ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new category's name. + position: :class:`int` + The new category's position. + nsfw: :class:`bool` + To mark the category as NSFW or not. + reason: Optional[:class:`str`] + The reason for editing this category. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + + Returns + ------- + Optional[:class:`.CategoryChannel`] + The newly edited category channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of categories. + Forbidden + You do not have permissions to edit the category. + HTTPException + Editing the category failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + @utils.copy_doc(discord.abc.GuildChannel.move) + async def move(self, **kwargs): + kwargs.pop("category", None) + await super().move(**kwargs) + + @property + def channels(self) -> list[GuildChannelType]: + """List[:class:`abc.GuildChannel`]: Returns the channels that are under this category. + + These are sorted by the official Discord UI, which places voice channels below the text channels. + """ + + def comparator(channel): + return not isinstance(channel, _TextChannel), (channel.position or -1) + + ret = [c for c in self.guild.channels if c.category_id == self.id] + ret.sort(key=comparator) + return ret + + @property + def text_channels(self) -> list[TextChannel]: + """List[:class:`TextChannel`]: Returns the text channels that are under this category.""" + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, TextChannel) + ] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def voice_channels(self) -> list[VoiceChannel]: + """List[:class:`VoiceChannel`]: Returns the voice channels that are under this category.""" + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, VoiceChannel) + ] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def stage_channels(self) -> list[StageChannel]: + """List[:class:`StageChannel`]: Returns the stage channels that are under this category. + + .. versionadded:: 1.7 + """ + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, StageChannel) + ] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def forum_channels(self) -> list[ForumChannel]: + """List[:class:`ForumChannel`]: Returns the forum channels that are under this category. + + .. versionadded:: 2.0 + """ + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, ForumChannel) + ] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + async def create_text_channel(self, name: str, **options: Any) -> TextChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category. + + Returns + ------- + :class:`TextChannel` + The channel that was just created. + """ + return await self.guild.create_text_channel(name, category=self, **options) + + async def create_voice_channel(self, name: str, **options: Any) -> VoiceChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category. + + Returns + ------- + :class:`VoiceChannel` + The channel that was just created. + """ + return await self.guild.create_voice_channel(name, category=self, **options) + + async def create_stage_channel(self, name: str, **options: Any) -> StageChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category. + + .. versionadded:: 1.7 + + Returns + ------- + :class:`StageChannel` + The channel that was just created. + """ + return await self.guild.create_stage_channel(name, category=self, **options) + + async def create_forum_channel(self, name: str, **options: Any) -> ForumChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_forum_channel` to create a :class:`ForumChannel` in the category. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`ForumChannel` + The channel that was just created. + """ + return await self.guild.create_forum_channel(name, category=self, **options) + + +DMC = TypeVar("DMC", bound="DMChannel") + + +class DMChannel(discord.abc.Messageable, Hashable): + """Represents a Discord direct message channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns a string representation of the channel + + Attributes + ---------- + recipient: Optional[:class:`User`] + The user you are participating with in the direct message channel. + If this channel is received through the gateway, the recipient information + may not be always available. + me: :class:`ClientUser` + The user presenting yourself. + id: :class:`int` + The direct message channel ID. + """ + + __slots__ = ("id", "recipient", "me", "_state") + + def __init__( + self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload + ): + self._state: ConnectionState = state + self.recipient: User | None = state.store_user(data["recipients"][0]) + self.me: ClientUser = me + self.id: int = int(data["id"]) + + async def _get_channel(self): + return self + + def __str__(self) -> str: + if self.recipient: + return f"Direct Message with {self.recipient}" + return "Direct Message with Unknown User" + + def __repr__(self) -> str: + return f"" + + @classmethod + def _from_message(cls: type[DMC], state: ConnectionState, channel_id: int) -> DMC: + self: DMC = cls.__new__(cls) + self._state = state + self.id = channel_id + self.recipient = None + # state.user won't be None here + self.me = state.user # type: ignore + return self + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.private + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f"https://discord.com/channels/@me/{self.id}" + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the direct message channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def permissions_for(self, obj: Any = None, /) -> Permissions: + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Actual direct messages do not really have the concept of permissions. + + This returns all the Text related permissions set to ``True`` except: + + - :attr:`~Permissions.send_tts_messages`: You cannot send TTS messages in a DM. + - :attr:`~Permissions.manage_messages`: You cannot delete others messages in a DM. + + Parameters + ---------- + obj: :class:`User` + The user to check permissions for. This parameter is ignored + but kept for compatibility with other ``permissions_for`` methods. + + Returns + ------- + :class:`Permissions` + The resolved permissions. + """ + + base = Permissions.text() + base.read_messages = True + base.send_tts_messages = False + base.manage_messages = False + return base + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + +class GroupChannel(discord.abc.Messageable, Hashable): + """Represents a Discord group channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns a string representation of the channel + + Attributes + ---------- + recipients: List[:class:`User`] + The users you are participating with in the group channel. + me: :class:`ClientUser` + The user presenting yourself. + id: :class:`int` + The group channel ID. + owner: Optional[:class:`User`] + The user that owns the group channel. + owner_id: :class:`int` + The owner ID that owns the group channel. + + .. versionadded:: 2.0 + name: Optional[:class:`str`] + The group channel's name if provided. + """ + + __slots__ = ( + "id", + "recipients", + "owner_id", + "owner", + "_icon", + "name", + "me", + "_state", + ) + + def __init__( + self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload + ): + self._state: ConnectionState = state + self.id: int = int(data["id"]) + self.me: ClientUser = me + self._update_group(data) + + def _update_group(self, data: GroupChannelPayload) -> None: + self.owner_id: int | None = utils._get_as_snowflake(data, "owner_id") + self._icon: str | None = data.get("icon") + self.name: str | None = data.get("name") + self.recipients: list[User] = [ + self._state.store_user(u) for u in data.get("recipients", []) + ] + + self.owner: BaseUser | None + if self.owner_id == self.me.id: + self.owner = self.me + else: + self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients) + + async def _get_channel(self): + return self + + def __str__(self) -> str: + if self.name: + return self.name + + if len(self.recipients) == 0: + return "Unnamed" + + return ", ".join(map(lambda x: x.name, self.recipients)) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.group + + @property + def icon(self) -> Asset | None: + """Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path="channel") + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f"https://discord.com/channels/@me/{self.id}" + + def permissions_for(self, obj: Snowflake, /) -> Permissions: + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Actual direct messages do not really have the concept of permissions. + + This returns all the Text related permissions set to ``True`` except: + + - :attr:`~Permissions.send_tts_messages`: You cannot send TTS messages in a DM. + - :attr:`~Permissions.manage_messages`: You cannot delete others messages in a DM. + + This also checks the kick_members permission if the user is the owner. + + Parameters + ---------- + obj: :class:`~discord.abc.Snowflake` + The user to check permissions for. + + Returns + ------- + :class:`Permissions` + The resolved permissions for the user. + """ + + base = Permissions.text() + base.read_messages = True + base.send_tts_messages = False + base.manage_messages = False + base.mention_everyone = True + + if obj.id == self.owner_id: + base.kick_members = True + + return base + + async def leave(self) -> None: + """|coro| + + Leave the group. + + If you are the only one in the group, this deletes it as well. + + Raises + ------ + HTTPException + Leaving the group failed. + """ + + await self._state.http.leave_group(self.id) + + +class PartialMessageable(discord.abc.Messageable, Hashable): + """Represents a partial messageable to aid with working messageable channels when + only a channel ID are present. + + The only way to construct this class is through :meth:`Client.get_partial_messageable`. + + Note that this class is trimmed down and has no rich attributes. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two partial messageables are equal. + + .. describe:: x != y + + Checks if two partial messageables are not equal. + + .. describe:: hash(x) + + Returns the partial messageable's hash. + + Attributes + ---------- + id: :class:`int` + The channel ID associated with this partial messageable. + type: Optional[:class:`ChannelType`] + The channel type associated with this partial messageable, if given. + """ + + def __init__( + self, state: ConnectionState, id: int, type: ChannelType | None = None + ): + self._state: ConnectionState = state + self._channel: Object = Object(id=id) + self.id: int = id + self.type: ChannelType | None = type + + async def _get_channel(self) -> Object: + return self._channel + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + +def _guild_channel_factory(channel_type: int): + value = try_enum(ChannelType, channel_type) + if value is ChannelType.text: + return TextChannel, value + elif value is ChannelType.voice: + return VoiceChannel, value + elif value is ChannelType.category: + return CategoryChannel, value + elif value is ChannelType.news: + return TextChannel, value + elif value is ChannelType.stage_voice: + return StageChannel, value + elif value is ChannelType.forum: + return ForumChannel, value + else: + return None, value + + +def _channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return cls, value + + +def _threaded_channel_factory(channel_type: int): + cls, value = _channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value + + +def _threaded_guild_channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value diff --git a/discord/client.py b/discord/client.py new file mode 100644 index 0000000..e6fa165 --- /dev/null +++ b/discord/client.py @@ -0,0 +1,1782 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +import sys +import traceback +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generator, Sequence, TypeVar + +import aiohttp + +from . import utils +from .activity import ActivityTypes, BaseActivity, create_activity +from .appinfo import AppInfo, PartialAppInfo +from .backoff import ExponentialBackoff +from .channel import PartialMessageable, _threaded_channel_factory +from .emoji import Emoji +from .enums import ChannelType, Status +from .errors import * +from .flags import ApplicationFlags, Intents +from .gateway import * +from .guild import Guild +from .http import HTTPClient +from .invite import Invite +from .iterators import GuildIterator +from .mentions import AllowedMentions +from .object import Object +from .stage_instance import StageInstance +from .state import ConnectionState +from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory +from .template import Template +from .threads import Thread +from .ui.view import View +from .user import ClientUser, User +from .utils import MISSING +from .voice_client import VoiceClient +from .webhook import Webhook +from .widget import Widget + +if TYPE_CHECKING: + from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime + from .channel import DMChannel + from .member import Member + from .message import Message + from .voice_client import VoiceProtocol + +__all__ = ("Client",) + +Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) + + +_log = logging.getLogger(__name__) + + +def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: + tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} + + if not tasks: + return + + _log.info("Cleaning up after %d tasks.", len(tasks)) + for task in tasks: + task.cancel() + + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + _log.info("All tasks finished cancelling.") + + for task in tasks: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "Unhandled exception during Client.run shutdown.", + "exception": task.exception(), + "task": task, + } + ) + + +def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: + try: + _cancel_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + _log.info("Closing the event loop.") + loop.close() + + +class Client: + r"""Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + A number of options can be passed to the :class:`Client`. + + Parameters + ----------- + max_messages: Optional[:class:`int`] + The maximum number of messages to store in the internal message cache. + This defaults to ``1000``. Passing in ``None`` disables the message cache. + + .. versionchanged:: 1.3 + Allow disabling the message cache and change the default size to ``1000``. + loop: Optional[:class:`asyncio.AbstractEventLoop`] + The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations. + Defaults to ``None``, in which case the default event loop is used via + :func:`asyncio.get_event_loop()`. + connector: Optional[:class:`aiohttp.BaseConnector`] + The connector to use for connection pooling. + proxy: Optional[:class:`str`] + Proxy URL. + proxy_auth: Optional[:class:`aiohttp.BasicAuth`] + An object that represents proxy HTTP Basic Authorization. + shard_id: Optional[:class:`int`] + Integer starting at ``0`` and less than :attr:`.shard_count`. + shard_count: Optional[:class:`int`] + The total number of shards. + application_id: :class:`int` + The client's application ID. + intents: :class:`Intents` + The intents that you want to enable for the session. This is a way of + disabling and enabling certain gateway events from triggering and being sent. + If not given, defaults to a regularly constructed :class:`Intents` class. + + .. versionadded:: 1.5 + member_cache_flags: :class:`MemberCacheFlags` + Allows for finer control over how the library caches members. + If not given, defaults to cache as much as possible with the + currently selected intents. + + .. versionadded:: 1.5 + chunk_guilds_at_startup: :class:`bool` + Indicates if :func:`.on_ready` should be delayed to chunk all guilds + at start-up if necessary. This operation is incredibly slow for large + amounts of guilds. The default is ``True`` if :attr:`Intents.members` + is ``True``. + + .. versionadded:: 1.5 + status: Optional[:class:`.Status`] + A status to start your presence with upon logging on to Discord. + activity: Optional[:class:`.BaseActivity`] + An activity to start your presence with upon logging on to Discord. + allowed_mentions: Optional[:class:`AllowedMentions`] + Control how the client handles mentions by default on every message sent. + + .. versionadded:: 1.4 + heartbeat_timeout: :class:`float` + The maximum numbers of seconds before timing out and restarting the + WebSocket in the case of not receiving a HEARTBEAT_ACK. Useful if + processing the initial packets take too long to the point of disconnecting + you. The default timeout is 60 seconds. + guild_ready_timeout: :class:`float` + The maximum number of seconds to wait for the GUILD_CREATE stream to end before + preparing the member cache and firing READY. The default timeout is 2 seconds. + + .. versionadded:: 1.4 + assume_unsync_clock: :class:`bool` + Whether to assume the system clock is unsynced. This applies to the ratelimit handling + code. If this is set to ``True``, the default, then the library uses the time to reset + a rate limit bucket given by Discord. If this is ``False`` then your system clock is + used to calculate how long to sleep for. If this is set to ``False`` it is recommended to + sync your system clock to Google's NTP server. + + .. versionadded:: 1.3 + enable_debug_events: :class:`bool` + Whether to enable events that are useful only for debugging gateway related information. + + Right now this involves :func:`on_socket_raw_receive` and :func:`on_socket_raw_send`. If + this is ``False`` then those events will not be dispatched (due to performance considerations). + To enable these events, this must be set to ``True``. Defaults to ``False``. + + .. versionadded:: 2.0 + + Attributes + ----------- + ws + The WebSocket gateway the client is currently connected to. Could be ``None``. + loop: :class:`asyncio.AbstractEventLoop` + The event loop that the client uses for asynchronous operations. + """ + + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + **options: Any, + ): + # self.ws is set in the connect method + self.ws: DiscordWebSocket = None # type: ignore + self.loop: asyncio.AbstractEventLoop = ( + asyncio.get_event_loop() if loop is None else loop + ) + self._listeners: dict[ + str, list[tuple[asyncio.Future, Callable[..., bool]]] + ] = {} + self.shard_id: int | None = options.get("shard_id") + self.shard_count: int | None = options.get("shard_count") + + connector: aiohttp.BaseConnector | None = options.pop("connector", None) + proxy: str | None = options.pop("proxy", None) + proxy_auth: aiohttp.BasicAuth | None = options.pop("proxy_auth", None) + unsync_clock: bool = options.pop("assume_unsync_clock", True) + self.http: HTTPClient = HTTPClient( + connector, + proxy=proxy, + proxy_auth=proxy_auth, + unsync_clock=unsync_clock, + loop=self.loop, + ) + + self._handlers: dict[str, Callable] = {"ready": self._handle_ready} + + self._hooks: dict[str, Callable] = { + "before_identify": self._call_before_identify_hook + } + + self._enable_debug_events: bool = options.pop("enable_debug_events", False) + self._connection: ConnectionState = self._get_state(**options) + self._connection.shard_count = self.shard_count + self._closed: bool = False + self._ready: asyncio.Event = asyncio.Event() + self._connection._get_websocket = self._get_websocket + self._connection._get_client = lambda: self + + if VoiceClient.warn_nacl: + VoiceClient.warn_nacl = False + _log.warning("PyNaCl is not installed, voice will NOT be supported") + + # internals + + def _get_websocket( + self, guild_id: int | None = None, *, shard_id: int | None = None + ) -> DiscordWebSocket: + return self.ws + + def _get_state(self, **options: Any) -> ConnectionState: + return ConnectionState( + dispatch=self.dispatch, + handlers=self._handlers, + hooks=self._hooks, + http=self.http, + loop=self.loop, + **options, + ) + + def _handle_ready(self) -> None: + self._ready.set() + + @property + def latency(self) -> float: + """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. + + This could be referred to as the Discord WebSocket protocol latency. + """ + ws = self.ws + return float("nan") if not ws else ws.latency + + def is_ws_ratelimited(self) -> bool: + """:class:`bool`: Whether the WebSocket is currently rate limited. + + This can be useful to know when deciding whether you should query members + using HTTP or via the gateway. + + .. versionadded:: 1.6 + """ + if self.ws: + return self.ws.is_ratelimited() + return False + + @property + def user(self) -> ClientUser | None: + """Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in.""" + return self._connection.user + + @property + def guilds(self) -> list[Guild]: + """List[:class:`.Guild`]: The guilds that the connected client is a member of.""" + return self._connection.guilds + + @property + def emojis(self) -> list[Emoji]: + """List[:class:`.Emoji`]: The emojis that the connected client has.""" + return self._connection.emojis + + @property + def stickers(self) -> list[GuildSticker]: + """List[:class:`.GuildSticker`]: The stickers that the connected client has. + + .. versionadded:: 2.0 + """ + return self._connection.stickers + + @property + def cached_messages(self) -> Sequence[Message]: + """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached. + + .. versionadded:: 1.1 + """ + return utils.SequenceProxy(self._connection._messages or []) + + @property + def private_channels(self) -> list[PrivateChannel]: + """List[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on. + + .. note:: + + This returns only up to 128 most recent private channels due to an internal working + on how Discord deals with private channels. + """ + return self._connection.private_channels + + @property + def voice_clients(self) -> list[VoiceProtocol]: + """List[:class:`.VoiceProtocol`]: Represents a list of voice connections. + + These are usually :class:`.VoiceClient` instances. + """ + return self._connection.voice_clients + + @property + def application_id(self) -> int | None: + """Optional[:class:`int`]: The client's application ID. + + If this is not passed via ``__init__`` then this is retrieved + through the gateway when an event contains the data. Usually + after :func:`~discord.on_connect` is called. + + .. versionadded:: 2.0 + """ + return self._connection.application_id + + @property + def application_flags(self) -> ApplicationFlags: + """:class:`~discord.ApplicationFlags`: The client's application flags. + + .. versionadded:: 2.0 + """ + return self._connection.application_flags # type: ignore + + def is_ready(self) -> bool: + """:class:`bool`: Specifies if the client's internal cache is ready for use.""" + return self._ready.is_set() + + async def _run_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> None: + try: + await coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception: + try: + await self.on_error(event_name, *args, **kwargs) + except asyncio.CancelledError: + pass + + def _schedule_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> asyncio.Task: + wrapped = self._run_event(coro, event_name, *args, **kwargs) + # Schedules the task + return asyncio.create_task(wrapped, name=f"pycord: {event_name}") + + def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: + _log.debug("Dispatching event %s", event) + method = f"on_{event}" + + listeners = self._listeners.get(event) + if listeners: + removed = [] + for i, (future, condition) in enumerate(listeners): + if future.cancelled(): + removed.append(i) + continue + + try: + result = condition(*args) + except Exception as exc: + future.set_exception(exc) + removed.append(i) + else: + if result: + if len(args) == 0: + future.set_result(None) + elif len(args) == 1: + future.set_result(args[0]) + else: + future.set_result(args) + removed.append(i) + + if len(removed) == len(listeners): + self._listeners.pop(event) + else: + for idx in reversed(removed): + del listeners[idx] + + try: + coro = getattr(self, method) + except AttributeError: + pass + else: + self._schedule_event(coro, method, *args, **kwargs) + + async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: + """|coro| + + The default error handler provided by the client. + + By default, this prints to :data:`sys.stderr` however it could be + overridden to have a different implementation. + Check :func:`~discord.on_error` for more details. + """ + print(f"Ignoring exception in {event_method}", file=sys.stderr) + traceback.print_exc() + + # hooks + + async def _call_before_identify_hook( + self, shard_id: int | None, *, initial: bool = False + ) -> None: + # This hook is an internal hook that actually calls the public one. + # It allows the library to have its own hook without stepping on the + # toes of those who need to override their own hook. + await self.before_identify_hook(shard_id, initial=initial) + + async def before_identify_hook( + self, shard_id: int | None, *, initial: bool = False + ) -> None: + """|coro| + + A hook that is called before IDENTIFYing a session. This is useful + if you wish to have more control over the synchronization of multiple + IDENTIFYing clients. + + The default implementation sleeps for 5 seconds. + + .. versionadded:: 1.4 + + Parameters + ---------- + shard_id: :class:`int` + The shard ID that requested being IDENTIFY'd + initial: :class:`bool` + Whether this IDENTIFY is the first initial IDENTIFY. + """ + + if not initial: + await asyncio.sleep(5.0) + + # login state management + + async def login(self, token: str) -> None: + """|coro| + + Logs in the client with the specified credentials. + + Parameters + ---------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + + Raises + ------ + TypeError + The token was in invalid type. + :exc:`LoginFailure` + The wrong credentials are passed. + :exc:`HTTPException` + An unknown HTTP related error occurred, + usually when it isn't 200 or the known incorrect credentials + passing status code. + """ + if not isinstance(token, str): + raise TypeError( + f"token must be of type str, not {token.__class__.__name__}" + ) + + _log.info("logging in using static token") + + data = await self.http.static_login(token.strip()) + self._connection.user = ClientUser(state=self._connection, data=data) + + async def connect(self, *, reconnect: bool = True) -> None: + """|coro| + + Creates a WebSocket connection and lets the WebSocket listen + to messages from Discord. This is a loop that runs the entire + event system and miscellaneous aspects of the library. Control + is not resumed until the WebSocket connection is terminated. + + Parameters + ---------- + reconnect: :class:`bool` + If we should attempt reconnecting, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). + + Raises + ------ + :exc:`GatewayNotFound` + The gateway to connect to Discord is not found. Usually if this + is thrown then there is a Discord API outage. + :exc:`ConnectionClosed` + The WebSocket connection has been terminated. + """ + + backoff = ExponentialBackoff() + ws_params = { + "initial": True, + "shard_id": self.shard_id, + } + while not self.is_closed(): + try: + coro = DiscordWebSocket.from_client(self, **ws_params) + self.ws = await asyncio.wait_for(coro, timeout=60.0) + ws_params["initial"] = False + while True: + await self.ws.poll_event() + except ReconnectWebSocket as e: + _log.info("Got a request to %s the websocket.", e.op) + self.dispatch("disconnect") + ws_params.update( + sequence=self.ws.sequence, + resume=e.resume, + session=self.ws.session_id, + ) + continue + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: + + self.dispatch("disconnect") + if not reconnect: + await self.close() + if isinstance(exc, ConnectionClosed) and exc.code == 1000: + # clean close, don't re-raise this + return + raise + + if self.is_closed(): + return + + # If we get connection reset by peer then try to RESUME + if isinstance(exc, OSError) and exc.errno in (54, 10054): + ws_params.update( + sequence=self.ws.sequence, + initial=False, + resume=True, + session=self.ws.session_id, + ) + continue + + # We should only get this when an unhandled close code happens, + # such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc) + # sometimes, discord sends us 1000 for unknown reasons, so we should reconnect + # regardless and rely on is_closed instead + if isinstance(exc, ConnectionClosed): + if exc.code == 4014: + raise PrivilegedIntentsRequired(exc.shard_id) from None + if exc.code != 1000: + await self.close() + raise + + retry = backoff.delay() + _log.exception("Attempting a reconnect in %.2fs", retry) + await asyncio.sleep(retry) + # Always try to RESUME the connection + # If the connection is not RESUME-able then the gateway will invalidate the session. + # This is apparently what the official Discord client does. + ws_params.update( + sequence=self.ws.sequence, resume=True, session=self.ws.session_id + ) + + async def close(self) -> None: + """|coro| + + Closes the connection to Discord. + """ + if self._closed: + return + + self._closed = True + + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # if an error happens during disconnects, disregard it. + pass + + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) + + await self.http.close() + self._ready.clear() + + def clear(self) -> None: + """Clears the internal state of the bot. + + After this, the bot can be considered "re-opened", i.e. :meth:`is_closed` + and :meth:`is_ready` both return ``False`` along with the bot's internal + cache cleared. + """ + self._closed = False + self._ready.clear() + self._connection.clear() + self.http.recreate() + + async def start(self, token: str, *, reconnect: bool = True) -> None: + """|coro| + + A shorthand coroutine for :meth:`login` + :meth:`connect`. + + Raises + ------ + TypeError + An unexpected keyword argument was received. + """ + await self.login(token) + await self.connect(reconnect=reconnect) + + def run(self, *args: Any, **kwargs: Any) -> None: + """A blocking call that abstracts away the event loop + initialisation from you. + + If you want more control over the event loop then this + function should not be used. Use :meth:`start` coroutine + or :meth:`connect` + :meth:`login`. + + Roughly Equivalent to: :: + + try: + loop.run_until_complete(start(*args, **kwargs)) + except KeyboardInterrupt: + loop.run_until_complete(close()) + # cancel all tasks lingering + finally: + loop.close() + + .. warning:: + + This function must be the last function to call due to the fact that it + is blocking. That means that registration of events or anything being + called after this function call will not execute until it returns. + """ + loop = self.loop + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + loop.add_signal_handler(signal.SIGTERM, loop.stop) + except (NotImplementedError, RuntimeError): + pass + + async def runner(): + try: + await self.start(*args, **kwargs) + finally: + if not self.is_closed(): + await self.close() + + def stop_loop_on_completion(f): + loop.stop() + + future = asyncio.ensure_future(runner(), loop=loop) + future.add_done_callback(stop_loop_on_completion) + try: + loop.run_forever() + except KeyboardInterrupt: + _log.info("Received signal to terminate bot and event loop.") + finally: + future.remove_done_callback(stop_loop_on_completion) + _log.info("Cleaning up tasks.") + _cleanup_loop(loop) + + if not future.cancelled(): + try: + return future.result() + except KeyboardInterrupt: + # I am unsure why this gets raised here but suppress it anyway + return None + + # properties + + def is_closed(self) -> bool: + """:class:`bool`: Indicates if the WebSocket connection is closed.""" + return self._closed + + @property + def activity(self) -> ActivityTypes | None: + """Optional[:class:`.BaseActivity`]: The activity being used upon + logging in. + """ + return create_activity(self._connection._activity) + + @activity.setter + def activity(self, value: ActivityTypes | None) -> None: + if value is None: + self._connection._activity = None + elif isinstance(value, BaseActivity): + # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] + self._connection._activity = value.to_dict() # type: ignore + else: + raise TypeError("activity must derive from BaseActivity.") + + @property + def status(self): + """:class:`.Status`: + The status being used upon logging on to Discord. + + .. versionadded: 2.0 + """ + if self._connection._status in {state.value for state in Status}: + return Status(self._connection._status) + return Status.online + + @status.setter + def status(self, value): + if value is Status.offline: + self._connection._status = "invisible" + elif isinstance(value, Status): + self._connection._status = str(value) + else: + raise TypeError("status must derive from Status.") + + @property + def allowed_mentions(self) -> AllowedMentions | None: + """Optional[:class:`~discord.AllowedMentions`]: The allowed mention configuration. + + .. versionadded:: 1.4 + """ + return self._connection.allowed_mentions + + @allowed_mentions.setter + def allowed_mentions(self, value: AllowedMentions | None) -> None: + if value is None or isinstance(value, AllowedMentions): + self._connection.allowed_mentions = value + else: + raise TypeError( + f"allowed_mentions must be AllowedMentions not {value.__class__!r}" + ) + + @property + def intents(self) -> Intents: + """:class:`~discord.Intents`: The intents configured for this connection. + + .. versionadded:: 1.5 + """ + return self._connection.intents + + # helpers/getters + + @property + def users(self) -> list[User]: + """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" + return list(self._connection._users.values()) + + async def fetch_application(self, application_id: int, /) -> PartialAppInfo: + """|coro| + Retrieves a :class:`.PartialAppInfo` from an application ID. + + Parameters + ---------- + application_id: :class:`int` + The application ID to retrieve information from. + + Returns + ------- + :class:`.PartialAppInfo` + The application information. + + Raises + ------ + NotFound + An application with this ID does not exist. + HTTPException + Retrieving the application failed. + """ + data = await self.http.get_application(application_id) + return PartialAppInfo(state=self._connection, data=data) + + def get_channel(self, id: int, /) -> GuildChannel | Thread | PrivateChannel | None: + """Returns a channel or thread with the given ID. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] + The returned channel or ``None`` if not found. + """ + return self._connection.get_channel(id) + + def get_message(self, id: int, /) -> Message | None: + """Returns a message the given ID. + + This is useful if you have a message_id but don't want to do an API call + to access the message. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`.Message`] + The returned message or ``None`` if not found. + """ + return self._connection._get_message(id) + + def get_partial_messageable( + self, id: int, *, type: ChannelType | None = None + ) -> PartialMessageable: + """Returns a partial messageable with the given channel ID. + + This is useful if you have a channel_id but don't want to do an API call + to send messages to it. + + .. versionadded:: 2.0 + + Parameters + ---------- + id: :class:`int` + The channel ID to create a partial messageable for. + type: Optional[:class:`.ChannelType`] + The underlying channel type for the partial messageable. + + Returns + ------- + :class:`.PartialMessageable` + The partial messageable + """ + return PartialMessageable(state=self._connection, id=id, type=type) + + def get_stage_instance(self, id: int, /) -> StageInstance | None: + """Returns a stage instance with the given stage channel ID. + + .. versionadded:: 2.0 + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`.StageInstance`] + The stage instance or ``None`` if not found. + """ + from .channel import StageChannel + + channel = self._connection.get_channel(id) + + if isinstance(channel, StageChannel): + return channel.instance + + def get_guild(self, id: int, /) -> Guild | None: + """Returns a guild with the given ID. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`.Guild`] + The guild or ``None`` if not found. + """ + return self._connection._get_guild(id) + + def get_user(self, id: int, /) -> User | None: + """Returns a user with the given ID. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`~discord.User`] + The user or ``None`` if not found. + """ + return self._connection.get_user(id) + + def get_emoji(self, id: int, /) -> Emoji | None: + """Returns an emoji with the given ID. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`.Emoji`] + The custom emoji or ``None`` if not found. + """ + return self._connection.get_emoji(id) + + def get_sticker(self, id: int, /) -> GuildSticker | None: + """Returns a guild sticker with the given ID. + + .. versionadded:: 2.0 + + .. note:: + + To retrieve standard stickers, use :meth:`.fetch_sticker`. + or :meth:`.fetch_premium_sticker_packs`. + + Returns + ------- + Optional[:class:`.GuildSticker`] + The sticker or ``None`` if not found. + """ + return self._connection.get_sticker(id) + + def get_all_channels(self) -> Generator[GuildChannel, None, None]: + """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. + + This is equivalent to: :: + + for guild in client.guilds: + for channel in guild.channels: + yield channel + + .. note:: + + Just because you receive a :class:`.abc.GuildChannel` does not mean that + you can communicate in said channel. :meth:`.abc.GuildChannel.permissions_for` should + be used for that. + + Yields + ------ + :class:`.abc.GuildChannel` + A channel the client can 'access'. + """ + + for guild in self.guilds: + yield from guild.channels + + def get_all_members(self) -> Generator[Member, None, None]: + """Returns a generator with every :class:`.Member` the client can see. + + This is equivalent to: :: + + for guild in client.guilds: + for member in guild.members: + yield member + + Yields + ------ + :class:`.Member` + A member the client can see. + """ + for guild in self.guilds: + yield from guild.members + + async def get_or_fetch_user(self, id: int, /) -> User | None: + """Looks up a user in the user cache or fetches if not found. + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`~discord.User`] + The user or ``None`` if not found. + """ + + return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None) + + # listeners/waiters + + async def wait_until_ready(self) -> None: + """|coro| + + Waits until the client's internal cache is all ready. + """ + await self._ready.wait() + + def wait_for( + self, + event: str, + *, + check: Callable[..., bool] | None = None, + timeout: float | None = None, + ) -> Any: + """|coro| + + Waits for a WebSocket event to be dispatched. + + This could be used to wait for a user to reply to a message, + or to react to a message, or to edit a message in a self-contained + way. + + The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default, + it does not timeout. Note that this does propagate the + :exc:`asyncio.TimeoutError` for you in case of timeout and is provided for + ease of use. + + In case the event returns multiple arguments, a :class:`tuple` containing those + arguments is returned instead. Please check the + :ref:`documentation ` for a list of events and their + parameters. + + This function returns the **first event that meets the requirements**. + + Parameters + ---------- + event: :class:`str` + The event name, similar to the :ref:`event reference `, + but without the ``on_`` prefix, to wait for. + check: Optional[Callable[..., :class:`bool`]] + A predicate to check what to wait for. The arguments must meet the + parameters of the event being waited for. + timeout: Optional[:class:`float`] + The number of seconds to wait before timing out and raising + :exc:`asyncio.TimeoutError`. + + Returns + ------- + Any + Returns no arguments, a single argument, or a :class:`tuple` of multiple + arguments that mirrors the parameters passed in the + :ref:`event reference `. + + Raises + ------ + asyncio.TimeoutError + Raised if a timeout is provided and reached. + + Examples + -------- + + Waiting for a user reply: :: + + @client.event + async def on_message(message): + if message.content.startswith('$greet'): + channel = message.channel + await channel.send('Say hello!') + + def check(m): + return m.content == 'hello' and m.channel == channel + + msg = await client.wait_for('message', check=check) + await channel.send(f'Hello {msg.author}!') + + Waiting for a thumbs up reaction from the message author: :: + + @client.event + async def on_message(message): + if message.content.startswith('$thumb'): + channel = message.channel + await channel.send('Send me that \N{THUMBS UP SIGN} reaction, mate') + + def check(reaction, user): + return user == message.author and str(reaction.emoji) == '\N{THUMBS UP SIGN}' + + try: + reaction, user = await client.wait_for('reaction_add', timeout=60.0, check=check) + except asyncio.TimeoutError: + await channel.send('\N{THUMBS DOWN SIGN}') + else: + await channel.send('\N{THUMBS UP SIGN}') + """ + + future = self.loop.create_future() + if check is None: + + def _check(*args): + return True + + check = _check + + ev = event.lower() + try: + listeners = self._listeners[ev] + except KeyError: + listeners = [] + self._listeners[ev] = listeners + + listeners.append((future, check)) + return asyncio.wait_for(future, timeout) + + # event registration + + def event(self, coro: Coro) -> Coro: + """A decorator that registers an event to listen to. + + You can find more info about the events on the :ref:`documentation below `. + + The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + + Example + ------- + + .. code-block:: python3 + + @client.event + async def on_ready(): + print('Ready!') + """ + + if not asyncio.iscoroutinefunction(coro): + raise TypeError("event registered must be a coroutine function") + + setattr(self, coro.__name__, coro) + _log.debug("%s has successfully been registered as an event", coro.__name__) + return coro + + async def change_presence( + self, + *, + activity: BaseActivity | None = None, + status: Status | None = None, + ): + """|coro| + + Changes the client's presence. + + Parameters + ---------- + activity: Optional[:class:`.BaseActivity`] + The activity being done. ``None`` if no currently active activity is done. + status: Optional[:class:`.Status`] + Indicates what status to change to. If ``None``, then + :attr:`.Status.online` is used. + + Raises + ------ + :exc:`InvalidArgument` + If the ``activity`` parameter is not the proper type. + + Example + ------- + + .. code-block:: python3 + + game = discord.Game("with the API") + await client.change_presence(status=discord.Status.idle, activity=game) + + .. versionchanged:: 2.0 + Removed the ``afk`` keyword-only parameter. + """ + + if status is None: + status_str = "online" + status = Status.online + elif status is Status.offline: + status_str = "invisible" + status = Status.offline + else: + status_str = str(status) + + await self.ws.change_presence(activity=activity, status=status_str) + + for guild in self._connection.guilds: + me = guild.me + if me is None: + continue + + me.activities = (activity,) if activity is not None else () + me.status = status + + # Guild stuff + + def fetch_guilds( + self, + *, + limit: int | None = 100, + before: SnowflakeTime = None, + after: SnowflakeTime = None, + ) -> GuildIterator: + """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. + + .. note:: + + Using this, you will only receive :attr:`.Guild.owner`, :attr:`.Guild.icon`, + :attr:`.Guild.id`, and :attr:`.Guild.name` per :class:`.Guild`. + + .. note:: + + This method is an API call. For general usage, consider :attr:`guilds` instead. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of guilds to retrieve. + If ``None``, it retrieves every guild you have access to. Note, however, + that this would make it a slow operation. + Defaults to ``100``. + before: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`] + Retrieves guilds before this date or object. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + after: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`] + Retrieve guilds after this date or object. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + + Yields + ------ + :class:`.Guild` + The guild with the guild data parsed. + + Raises + ------ + :exc:`HTTPException` + Getting the guilds failed. + + Examples + -------- + + Usage :: + + async for guild in client.fetch_guilds(limit=150): + print(guild.name) + + Flattening into a list :: + + guilds = await client.fetch_guilds(limit=150).flatten() + # guilds is now a list of Guild... + + All parameters are optional. + """ + return GuildIterator(self, limit=limit, before=before, after=after) + + async def fetch_template(self, code: Template | str) -> Template: + """|coro| + + Gets a :class:`.Template` from a discord.new URL or code. + + Parameters + ---------- + code: Union[:class:`.Template`, :class:`str`] + The Discord Template Code or URL (must be a discord.new URL). + + Returns + ------- + :class:`.Template` + The template from the URL/code. + + Raises + ------ + :exc:`NotFound` + The template is invalid. + :exc:`HTTPException` + Getting the template failed. + """ + code = utils.resolve_template(code) + data = await self.http.get_template(code) + return Template(data=data, state=self._connection) # type: ignore + + async def fetch_guild(self, guild_id: int, /, *, with_counts=True) -> Guild: + """|coro| + + Retrieves a :class:`.Guild` from an ID. + + .. note:: + + Using this, you will **not** receive :attr:`.Guild.channels`, :attr:`.Guild.members`, + :attr:`.Member.activity` and :attr:`.Member.voice` per :class:`.Member`. + + .. note:: + + This method is an API call. For general usage, consider :meth:`get_guild` instead. + + Parameters + ---------- + guild_id: :class:`int` + The guild's ID to fetch from. + + with_counts: :class:`bool` + Whether to include count information in the guild. This fills the + :attr:`.Guild.approximate_member_count` and :attr:`.Guild.approximate_presence_count` + fields. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`.Guild` + The guild from the ID. + + Raises + ------ + :exc:`Forbidden` + You do not have access to the guild. + :exc:`HTTPException` + Getting the guild failed. + """ + data = await self.http.get_guild(guild_id, with_counts=with_counts) + return Guild(data=data, state=self._connection) + + async def create_guild( + self, + *, + name: str, + icon: bytes = MISSING, + code: str = MISSING, + ) -> Guild: + """|coro| + + Creates a :class:`.Guild`. + + Bot accounts in more than 10 guilds are not allowed to create guilds. + + Parameters + ---------- + name: :class:`str` + The name of the guild. + icon: Optional[:class:`bytes`] + The :term:`py:bytes-like object` representing the icon. See :meth:`.ClientUser.edit` + for more details on what is expected. + code: :class:`str` + The code for a template to create the guild with. + + .. versionadded:: 1.4 + + Returns + ------- + :class:`.Guild` + The guild created. This is not the same guild that is + added to cache. + + Raises + ------ + :exc:`HTTPException` + Guild creation failed. + :exc:`InvalidArgument` + Invalid icon image format given. Must be PNG or JPG. + """ + if icon is not MISSING: + icon_base64 = utils._bytes_to_base64_data(icon) + else: + icon_base64 = None + + if code: + data = await self.http.create_from_template(code, name, icon_base64) + else: + data = await self.http.create_guild(name, icon_base64) + return Guild(data=data, state=self._connection) + + async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance: + """|coro| + + Gets a :class:`.StageInstance` for a stage channel id. + + .. versionadded:: 2.0 + + Parameters + ---------- + channel_id: :class:`int` + The stage channel ID. + + Returns + ------- + :class:`.StageInstance` + The stage instance from the stage channel ID. + + Raises + ------ + :exc:`NotFound` + The stage instance or channel could not be found. + :exc:`HTTPException` + Getting the stage instance failed. + """ + data = await self.http.get_stage_instance(channel_id) + guild = self.get_guild(int(data["guild_id"])) + return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore + + # Invite management + + async def fetch_invite( + self, + url: Invite | str, + *, + with_counts: bool = True, + with_expiration: bool = True, + event_id: int | None = None, + ) -> Invite: + """|coro| + + Gets an :class:`.Invite` from a discord.gg URL or ID. + + .. note:: + + If the invite is for a guild you have not joined, the guild and channel + attributes of the returned :class:`.Invite` will be :class:`.PartialInviteGuild` and + :class:`.PartialInviteChannel` respectively. + + Parameters + ---------- + url: Union[:class:`.Invite`, :class:`str`] + The Discord invite ID or URL (must be a discord.gg URL). + with_counts: :class:`bool` + Whether to include count information in the invite. This fills the + :attr:`.Invite.approximate_member_count` and :attr:`.Invite.approximate_presence_count` + fields. + with_expiration: :class:`bool` + Whether to include the expiration date of the invite. This fills the + :attr:`.Invite.expires_at` field. + + .. versionadded:: 2.0 + event_id: Optional[:class:`int`] + The ID of the scheduled event to be associated with the event. + + See :meth:`Invite.set_scheduled_event` for more + info on event invite linking. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`.Invite` + The invite from the URL/ID. + + Raises + ------ + :exc:`NotFound` + The invite has expired or is invalid. + :exc:`HTTPException` + Getting the invite failed. + """ + + invite_id = utils.resolve_invite(url) + data = await self.http.get_invite( + invite_id, + with_counts=with_counts, + with_expiration=with_expiration, + guild_scheduled_event_id=event_id, + ) + return Invite.from_incomplete(state=self._connection, data=data) + + async def delete_invite(self, invite: Invite | str) -> None: + """|coro| + + Revokes an :class:`.Invite`, URL, or ID to an invite. + + You must have the :attr:`~.Permissions.manage_channels` permission in + the associated guild to do this. + + Parameters + ---------- + invite: Union[:class:`.Invite`, :class:`str`] + The invite to revoke. + + Raises + ------ + :exc:`Forbidden` + You do not have permissions to revoke invites. + :exc:`NotFound` + The invite is invalid or expired. + :exc:`HTTPException` + Revoking the invite failed. + """ + + invite_id = utils.resolve_invite(invite) + await self.http.delete_invite(invite_id) + + # Miscellaneous stuff + + async def fetch_widget(self, guild_id: int, /) -> Widget: + """|coro| + + Gets a :class:`.Widget` from a guild ID. + + .. note:: + + The guild must have the widget enabled to get this information. + + Parameters + ---------- + guild_id: :class:`int` + The ID of the guild. + + Returns + ------- + :class:`.Widget` + The guild's widget. + + Raises + ------ + :exc:`Forbidden` + The widget for this guild is disabled. + :exc:`HTTPException` + Retrieving the widget failed. + """ + data = await self.http.get_widget(guild_id) + + return Widget(state=self._connection, data=data) + + async def application_info(self) -> AppInfo: + """|coro| + + Retrieves the bot's application information. + + Returns + ------- + :class:`.AppInfo` + The bot's application information. + + Raises + ------ + :exc:`HTTPException` + Retrieving the information failed somehow. + """ + data = await self.http.application_info() + if "rpc_origins" not in data: + data["rpc_origins"] = None + return AppInfo(self._connection, data) + + async def fetch_user(self, user_id: int, /) -> User: + """|coro| + + Retrieves a :class:`~discord.User` based on their ID. + You do not have to share any guilds with the user to get this information, + however many operations do require that you do. + + .. note:: + + This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, + consider :meth:`get_user` instead. + + Parameters + ---------- + user_id: :class:`int` + The user's ID to fetch from. + + Returns + ------- + :class:`~discord.User` + The user you requested. + + Raises + ------ + :exc:`NotFound` + A user with this ID does not exist. + :exc:`HTTPException` + Fetching the user failed. + """ + data = await self.http.get_user(user_id) + return User(state=self._connection, data=data) + + async def fetch_channel( + self, channel_id: int, / + ) -> GuildChannel | PrivateChannel | Thread: + """|coro| + + Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. + + .. note:: + + This method is an API call. For general usage, consider :meth:`get_channel` instead. + + .. versionadded:: 1.2 + + Returns + ------- + Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`] + The channel from the ID. + + Raises + ------ + :exc:`InvalidData` + An unknown channel type was received from Discord. + :exc:`HTTPException` + Retrieving the channel failed. + :exc:`NotFound` + Invalid Channel ID. + :exc:`Forbidden` + You do not have permission to fetch this channel. + """ + data = await self.http.get_channel(channel_id) + + factory, ch_type = _threaded_channel_factory(data["type"]) + if factory is None: + raise InvalidData( + "Unknown channel type {type} for channel ID {id}.".format_map(data) + ) + + if ch_type in (ChannelType.group, ChannelType.private): + # the factory will be a DMChannel or GroupChannel here + return factory(me=self.user, data=data, state=self._connection) + # the factory can't be a DMChannel or GroupChannel here + guild_id = int(data["guild_id"]) # type: ignore + guild = self.get_guild(guild_id) or Object(id=guild_id) + # GuildChannels expect a Guild, we may be passing an Object + return factory(guild=guild, state=self._connection, data=data) + + async def fetch_webhook(self, webhook_id: int, /) -> Webhook: + """|coro| + + Retrieves a :class:`.Webhook` with the specified ID. + + Returns + ------- + :class:`.Webhook` + The webhook you requested. + + Raises + ------ + :exc:`HTTPException` + Retrieving the webhook failed. + :exc:`NotFound` + Invalid webhook ID. + :exc:`Forbidden` + You do not have permission to fetch this webhook. + """ + data = await self.http.get_webhook(webhook_id) + return Webhook.from_state(data, state=self._connection) + + async def fetch_sticker(self, sticker_id: int, /) -> StandardSticker | GuildSticker: + """|coro| + + Retrieves a :class:`.Sticker` with the specified ID. + + .. versionadded:: 2.0 + + Returns + ------- + Union[:class:`.StandardSticker`, :class:`.GuildSticker`] + The sticker you requested. + + Raises + ------ + :exc:`HTTPException` + Retrieving the sticker failed. + :exc:`NotFound` + Invalid sticker ID. + """ + data = await self.http.get_sticker(sticker_id) + cls, _ = _sticker_factory(data["type"]) # type: ignore + return cls(state=self._connection, data=data) # type: ignore + + async def fetch_premium_sticker_packs(self) -> list[StickerPack]: + """|coro| + + Retrieves all available premium sticker packs. + + .. versionadded:: 2.0 + + Returns + ------- + List[:class:`.StickerPack`] + All available premium sticker packs. + + Raises + ------ + :exc:`HTTPException` + Retrieving the sticker packs failed. + """ + data = await self.http.list_premium_sticker_packs() + return [ + StickerPack(state=self._connection, data=pack) + for pack in data["sticker_packs"] + ] + + async def create_dm(self, user: Snowflake) -> DMChannel: + """|coro| + + Creates a :class:`.DMChannel` with this user. + + This should be rarely called, as this is done transparently for most + people. + + .. versionadded:: 2.0 + + Parameters + ---------- + user: :class:`~discord.abc.Snowflake` + The user to create a DM with. + + Returns + ------- + :class:`.DMChannel` + The channel that was created. + """ + state = self._connection + found = state._get_private_channel_by_user(user.id) + if found: + return found + + data = await state.http.start_private_message(user.id) + return state.add_dm_channel(data) + + def add_view(self, view: View, *, message_id: int | None = None) -> None: + """Registers a :class:`~discord.ui.View` for persistent listening. + + This method should be used for when a view is comprised of components + that last longer than the lifecycle of the program. + + .. versionadded:: 2.0 + + Parameters + ---------- + view: :class:`discord.ui.View` + The view to register for dispatching. + message_id: Optional[:class:`int`] + The message ID that the view is attached to. This is currently used to + refresh the view's state during message update events. If not given + then message update events are not propagated for the view. + + Raises + ------ + TypeError + A view was not passed. + ValueError + The view is not persistent. A persistent view has no timeout + and all their components have an explicitly provided ``custom_id``. + """ + + if not isinstance(view, View): + raise TypeError(f"expected an instance of View not {view.__class__!r}") + + if not view.is_persistent(): + raise ValueError( + "View is not persistent. Items need to have a custom_id set and View must have no timeout" + ) + + self._connection.store_view(view, message_id) + + @property + def persistent_views(self) -> Sequence[View]: + """Sequence[:class:`.View`]: A sequence of persistent views added to the client. + + .. versionadded:: 2.0 + """ + return self._connection.persistent_views diff --git a/discord/cog.py b/discord/cog.py new file mode 100644 index 0000000..a634e36 --- /dev/null +++ b/discord/cog.py @@ -0,0 +1,1118 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import importlib +import inspect +import os +import pathlib +import sys +import types +from typing import Any, Callable, ClassVar, Generator, Mapping, TypeVar, overload + +import discord.utils + +from . import errors +from .commands import ( + ApplicationCommand, + ApplicationContext, + SlashCommandGroup, + _BaseCommand, +) + +__all__ = ( + "CogMeta", + "Cog", + "CogMixin", +) + +CogT = TypeVar("CogT", bound="Cog") +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + +MISSING: Any = discord.utils.MISSING + + +def _is_submodule(parent: str, child: str) -> bool: + return parent == child or child.startswith(f"{parent}.") + + +class CogMeta(type): + """A metaclass for defining a cog. + + Note that you should probably not use this directly. It is exposed + purely for documentation purposes along with making custom metaclasses to intermix + with other metaclasses such as the :class:`abc.ABCMeta` metaclass. + + For example, to create an abstract cog mixin class, the following would be done. + + .. code-block:: python3 + + import abc + + class CogABCMeta(discord.CogMeta, abc.ABCMeta): + pass + + class SomeMixin(metaclass=abc.ABCMeta): + pass + + class SomeCogMixin(SomeMixin, discord.Cog, metaclass=CogABCMeta): + pass + + .. note:: + + When passing an attribute of a metaclass that is documented below, note + that you must pass it as a keyword-only argument to the class creation + like the following example: + + .. code-block:: python3 + + class MyCog(discord.Cog, name='My Cog'): + pass + + Attributes + ---------- + name: :class:`str` + The cog name. By default, it is the name of the class with no modification. + description: :class:`str` + The cog description. By default, it is the cleaned docstring of the class. + + .. versionadded:: 1.6 + + command_attrs: :class:`dict` + A list of attributes to apply to every command inside this cog. The dictionary + is passed into the :class:`Command` options at ``__init__``. + If you specify attributes inside the command attribute in the class, it will + override the one specified inside this attribute. For example: + + .. code-block:: python3 + + class MyCog(discord.Cog, command_attrs=dict(hidden=True)): + @discord.slash_command() + async def foo(self, ctx): + pass # hidden -> True + + @discord.slash_command(hidden=False) + async def bar(self, ctx): + pass # hidden -> False + + guild_ids: Optional[List[:class:`int`]] + A shortcut to :attr:`.command_attrs`, what ``guild_ids`` should all application commands have + in the cog. You can override this by setting ``guild_ids`` per command. + + .. versionadded:: 2.0 + """ + + __cog_name__: str + __cog_settings__: dict[str, Any] + __cog_commands__: list[ApplicationCommand] + __cog_listeners__: list[tuple[str, str]] + __cog_guild_ids__: list[int] + + def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: + name, bases, attrs = args + attrs["__cog_name__"] = kwargs.pop("name", name) + attrs["__cog_settings__"] = kwargs.pop("command_attrs", {}) + attrs["__cog_guild_ids__"] = kwargs.pop("guild_ids", []) + + description = kwargs.pop("description", None) + if description is None: + description = inspect.cleandoc(attrs.get("__doc__", "")) + attrs["__cog_description__"] = description + + commands = {} + listeners = {} + no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})" + + new_cls = super().__new__(cls, name, bases, attrs, **kwargs) + + valid_commands = [ + (c for i, c in j.__dict__.items() if isinstance(c, _BaseCommand)) + for j in reversed(new_cls.__mro__) + ] + if any(isinstance(i, ApplicationCommand) for i in valid_commands) and any( + not isinstance(i, _BaseCommand) for i in valid_commands + ): + _filter = ApplicationCommand + else: + _filter = _BaseCommand + + for base in reversed(new_cls.__mro__): + for elem, value in base.__dict__.items(): + if elem in commands: + del commands[elem] + if elem in listeners: + del listeners[elem] + + if getattr(value, "parent", None) and isinstance( + value, ApplicationCommand + ): + # Skip commands if they are a part of a group + continue + + is_static_method = isinstance(value, staticmethod) + if is_static_method: + value = value.__func__ + if isinstance(value, _filter): + if is_static_method: + raise TypeError( + f"Command in method {base}.{elem!r} must not be staticmethod." + ) + if elem.startswith(("cog_", "bot_")): + raise TypeError(no_bot_cog.format(base, elem)) + commands[elem] = value + + # a test to see if this value is a BridgeCommand + if hasattr(value, "add_to") and not getattr(value, "parent", None): + if is_static_method: + raise TypeError( + f"Command in method {base}.{elem!r} must not be staticmethod." + ) + if elem.startswith(("cog_", "bot_")): + raise TypeError(no_bot_cog.format(base, elem)) + + commands[f"ext_{elem}"] = value.ext_variant + commands[f"app_{elem}"] = value.slash_variant + for cmd in getattr(value, "subcommands", []): + commands[ + f"ext_{cmd.ext_variant.qualified_name}" + ] = cmd.ext_variant + + if inspect.iscoroutinefunction(value): + try: + getattr(value, "__cog_listener__") + except AttributeError: + continue + else: + if elem.startswith(("cog_", "bot_")): + raise TypeError(no_bot_cog.format(base, elem)) + listeners[elem] = value + + new_cls.__cog_commands__ = list(commands.values()) + + listeners_as_list = [] + for listener in listeners.values(): + for listener_name in listener.__cog_listener_names__: + # I use __name__ instead of just storing the value, so I can inject + # the self attribute when the time comes to add them to the bot + listeners_as_list.append((listener_name, listener.__name__)) + + new_cls.__cog_listeners__ = listeners_as_list + + cmd_attrs = new_cls.__cog_settings__ + + # Either update the command with the cog provided defaults or copy it. + # r.e type ignore, type-checker complains about overriding a ClassVar + new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in new_cls.__cog_commands__) # type: ignore + + name_filter = lambda c: "app" if isinstance(c, ApplicationCommand) else "ext" + + lookup = { + f"{name_filter(cmd)}_{cmd.qualified_name}": cmd + for cmd in new_cls.__cog_commands__ + } + + # Update the Command instances dynamically as well + for command in new_cls.__cog_commands__: + if ( + isinstance(command, ApplicationCommand) + and not command.guild_ids + and new_cls.__cog_guild_ids__ + ): + command.guild_ids = new_cls.__cog_guild_ids__ + + if not isinstance(command, SlashCommandGroup): + # ignore bridge commands + cmd = getattr(new_cls, command.callback.__name__, None) + if hasattr(cmd, "add_to"): + setattr( + cmd, + f"{name_filter(command).replace('app', 'slash')}_variant", + command, + ) + else: + setattr(new_cls, command.callback.__name__, command) + + parent = command.parent + if parent is not None: + # Get the latest parent reference + parent = lookup[f"{name_filter(command)}_{parent.qualified_name}"] # type: ignore + + # Update our parent's reference to our self + parent.remove_command(command.name) # type: ignore + parent.add_command(command) # type: ignore + + return new_cls + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args) + + @classmethod + def qualified_name(cls) -> str: + return cls.__cog_name__ + + +def _cog_special_method(func: FuncT) -> FuncT: + func.__cog_special_method__ = None + return func + + +class Cog(metaclass=CogMeta): + """The base class that all cogs must inherit from. + + A cog is a collection of commands, listeners, and optional state to + help group commands together. More information on them can be found on + the :ref:`ext_commands_cogs` page. + + When inheriting from this class, the options shown in :class:`CogMeta` + are equally valid here. + """ + + __cog_name__: ClassVar[str] + __cog_settings__: ClassVar[dict[str, Any]] + __cog_commands__: ClassVar[list[ApplicationCommand]] + __cog_listeners__: ClassVar[list[tuple[str, str]]] + __cog_guild_ids__: ClassVar[list[int]] + + def __new__(cls: type[CogT], *args: Any, **kwargs: Any) -> CogT: + # For issue 426, we need to store a copy of the command objects + # since we modify them to inject `self` to them. + # To do this, we need to interfere with the Cog creation process. + return super().__new__(cls) + + def get_commands(self) -> list[ApplicationCommand]: + r""" + Returns + -------- + List[:class:`.ApplicationCommand`] + A :class:`list` of :class:`.ApplicationCommand`\s that are + defined inside this cog. + + .. note:: + + This does not include subcommands. + """ + return [ + c + for c in self.__cog_commands__ + if isinstance(c, ApplicationCommand) and c.parent is None + ] + + @property + def qualified_name(self) -> str: + """:class:`str`: Returns the cog's specified name, not the class name.""" + return self.__cog_name__ + + @property + def description(self) -> str: + """:class:`str`: Returns the cog's description, typically the cleaned docstring.""" + return self.__cog_description__ + + @description.setter + def description(self, description: str) -> None: + self.__cog_description__ = description + + def walk_commands(self) -> Generator[ApplicationCommand, None, None]: + """An iterator that recursively walks through this cog's commands and subcommands. + + Yields + ------ + Union[:class:`.Command`, :class:`.Group`] + A command or group from the cog. + """ + for command in self.__cog_commands__: + if isinstance(command, SlashCommandGroup): + yield from command.walk_commands() + + def get_listeners(self) -> list[tuple[str, Callable[..., Any]]]: + """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. + + Returns + ------- + List[Tuple[:class:`str`, :ref:`coroutine `]] + The listeners defined in this cog. + """ + return [ + (name, getattr(self, method_name)) + for name, method_name in self.__cog_listeners__ + ] + + @classmethod + def _get_overridden_method(cls, method: FuncT) -> FuncT | None: + """Return None if the method is not overridden. Otherwise, returns the overridden method.""" + return getattr( + getattr(method, "__func__", method), "__cog_special_method__", method + ) + + @classmethod + def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: + """A decorator that marks a function as a listener. + + This is the cog equivalent of :meth:`.Bot.listen`. + + Parameters + ---------- + name: :class:`str` + The name of the event being listened to. If not provided, it + defaults to the function's name. + + Raises + ------ + TypeError + The function is not a coroutine function or a string was not passed as + the name. + """ + + if name is not MISSING and not isinstance(name, str): + raise TypeError( + f"Cog.listener expected str but received {name.__class__.__name__!r} instead." + ) + + def decorator(func: FuncT) -> FuncT: + actual = func + if isinstance(actual, staticmethod): + actual = actual.__func__ + if not inspect.iscoroutinefunction(actual): + raise TypeError("Listener function must be a coroutine function.") + actual.__cog_listener__ = True + to_assign = name or actual.__name__ + try: + actual.__cog_listener_names__.append(to_assign) + except AttributeError: + actual.__cog_listener_names__ = [to_assign] + # we have to return `func` instead of `actual` because + # we need the type to be `staticmethod` for the metaclass + # to pick it up but the metaclass unfurls the function and + # thus the assignments need to be on the actual function + return func + + return decorator + + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the cog has an error handler. + + .. versionadded:: 1.7 + """ + return not hasattr(self.cog_command_error.__func__, "__cog_special_method__") + + @_cog_special_method + def cog_unload(self) -> None: + """A special method that is called when the cog gets removed. + + This function **cannot** be a coroutine. It must be a regular + function. + + Subclasses must replace this if they want special unloading behaviour. + """ + + @_cog_special_method + def bot_check_once(self, ctx: ApplicationContext) -> bool: + """A special method that registers as a :meth:`.Bot.check_once` + check. + + This function **can** be a coroutine and must take a sole parameter, + ``ctx``, to represent the :class:`.Context` or :class:`.ApplicationContext`. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context. + """ + return True + + @_cog_special_method + def bot_check(self, ctx: ApplicationContext) -> bool: + """A special method that registers as a :meth:`.Bot.check` + check. + + This function **can** be a coroutine and must take a sole parameter, + ``ctx``, to represent the :class:`.Context` or :class:`.ApplicationContext`. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context. + """ + return True + + @_cog_special_method + def cog_check(self, ctx: ApplicationContext) -> bool: + """A special method that registers as a :func:`~discord.ext.commands.check` + for every command and subcommand in this cog. + + This function **can** be a coroutine and must take a sole parameter, + ``ctx``, to represent the :class:`.Context` or :class:`.ApplicationContext`. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context. + """ + return True + + @_cog_special_method + async def cog_command_error( + self, ctx: ApplicationContext, error: Exception + ) -> None: + """A special method that is called whenever an error + is dispatched inside this cog. + + This is similar to :func:`.on_command_error` except only applying + to the commands inside this cog. + + This **must** be a coroutine. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context where the error happened. + error: :class:`ApplicationCommandError` + The error that happened. + """ + + @_cog_special_method + async def cog_before_invoke(self, ctx: ApplicationContext) -> None: + """A special method that acts as a cog local pre-invoke hook. + + This is similar to :meth:`.ApplicationCommand.before_invoke`. + + This **must** be a coroutine. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context. + """ + + @_cog_special_method + async def cog_after_invoke(self, ctx: ApplicationContext) -> None: + """A special method that acts as a cog local post-invoke hook. + + This is similar to :meth:`.ApplicationCommand.after_invoke`. + + This **must** be a coroutine. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context. + """ + + def _inject(self: CogT, bot) -> CogT: + cls = self.__class__ + + # realistically, the only thing that can cause loading errors + # is essentially just the command loading, which raises if there are + # duplicates. When this condition is met, we want to undo all what + # we've added so far for some form of atomic loading. + + for index, command in enumerate(self.__cog_commands__): + command._set_cog(self) + + if isinstance(command, ApplicationCommand): + if isinstance(command, discord.SlashCommandGroup): + for x in command.subcommands: + if isinstance(x, discord.SlashCommandGroup): + for y in x.subcommands: + y.parent = x + x.parent = command + bot.add_application_command(command) + + elif command.parent is None: + try: + bot.add_command(command) + except Exception as e: + # undo our additions + for to_undo in self.__cog_commands__[:index]: + if to_undo.parent is None: + bot.remove_command(to_undo.name) + raise e + # check if we're overriding the default + if cls.bot_check is not Cog.bot_check: + bot.add_check(self.bot_check) + + if cls.bot_check_once is not Cog.bot_check_once: + bot.add_check(self.bot_check_once, call_once=True) + + # while Bot.add_listener can raise if it's not a coroutine, + # this precondition is already met by the listener decorator + # already, thus this should never raise. + # Outside of, memory errors and the like... + for name, method_name in self.__cog_listeners__: + bot.add_listener(getattr(self, method_name), name) + + return self + + def _eject(self, bot) -> None: + cls = self.__class__ + + try: + for command in self.__cog_commands__: + if isinstance(command, ApplicationCommand): + bot.remove_application_command(command) + elif command.parent is None: + bot.remove_command(command.name) + + for _, method_name in self.__cog_listeners__: + bot.remove_listener(getattr(self, method_name)) + + if cls.bot_check is not Cog.bot_check: + bot.remove_check(self.bot_check) + + if cls.bot_check_once is not Cog.bot_check_once: + bot.remove_check(self.bot_check_once, call_once=True) + finally: + try: + self.cog_unload() + except Exception: + pass + + +class CogMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__cogs: dict[str, Cog] = {} + self.__extensions: dict[str, types.ModuleType] = {} + + def add_cog(self, cog: Cog, *, override: bool = False) -> None: + """Adds a "cog" to the bot. + + A cog is a class that has its own event listeners and commands. + + .. versionchanged:: 2.0 + + :exc:`.ClientException` is raised when a cog with the same name + is already loaded. + + Parameters + ---------- + cog: :class:`.Cog` + The cog to register to the bot. + override: :class:`bool` + If a previously loaded cog with the same name should be ejected + instead of raising an error. + + .. versionadded:: 2.0 + + Raises + ------ + TypeError + The cog does not inherit from :class:`.Cog`. + ApplicationCommandError + An error happened during loading. + ClientException + A cog with the same name is already loaded. + """ + + if not isinstance(cog, Cog): + raise TypeError("cogs must derive from Cog") + + cog_name = cog.__cog_name__ + existing = self.__cogs.get(cog_name) + + if existing is not None: + if not override: + raise discord.ClientException(f"Cog named {cog_name!r} already loaded") + self.remove_cog(cog_name) + + cog = cog._inject(self) + self.__cogs[cog_name] = cog + + def get_cog(self, name: str) -> Cog | None: + """Gets the cog instance requested. + + If the cog is not found, ``None`` is returned instead. + + Parameters + ---------- + name: :class:`str` + The name of the cog you are requesting. + This is equivalent to the name passed via keyword + argument in class creation or the class name if unspecified. + + Returns + ------- + Optional[:class:`Cog`] + The cog that was requested. If not found, returns ``None``. + """ + return self.__cogs.get(name) + + def remove_cog(self, name: str) -> Cog | None: + """Removes a cog from the bot and returns it. + + All registered commands and event listeners that the + cog has registered will be removed as well. + + If no cog is found then this method has no effect. + + Parameters + ---------- + name: :class:`str` + The name of the cog to remove. + + Returns + ------- + Optional[:class:`.Cog`] + The cog that was removed. ``None`` if not found. + """ + + cog = self.__cogs.pop(name, None) + if cog is None: + return + + if hasattr(self, "_help_command"): + help_command = self._help_command + if help_command and help_command.cog is cog: + help_command.cog = None + + cog._eject(self) + + return cog + + @property + def cogs(self) -> Mapping[str, Cog]: + """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" + return types.MappingProxyType(self.__cogs) + + # extensions + + def _remove_module_references(self, name: str) -> None: + # find all references to the module + # remove the cogs registered from the module + for cog_name, cog in self.__cogs.copy().items(): + if _is_submodule(name, cog.__module__): + self.remove_cog(cog_name) + + # remove all the commands from the module + if self._supports_prefixed_commands: + for cmd in self.prefixed_commands.copy().values(): + if cmd.module is not None and _is_submodule(name, cmd.module): + # if isinstance(cmd, GroupMixin): + # cmd.recursively_remove_all_commands() + self.remove_command(cmd.name) + for cmd in self._application_commands.copy().values(): + if cmd.module is not None and _is_submodule(name, cmd.module): + # if isinstance(cmd, GroupMixin): + # cmd.recursively_remove_all_commands() + self.remove_application_command(cmd) + + # remove all the listeners from the module + for event_list in self.extra_events.copy().values(): + remove = [ + index + for index, event in enumerate(event_list) + if event.__module__ is not None + and _is_submodule(name, event.__module__) + ] + + for index in reversed(remove): + del event_list[index] + + def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: + try: + func = getattr(lib, "teardown") + except AttributeError: + pass + else: + try: + func(self) + except Exception: + pass + finally: + self.__extensions.pop(key, None) + sys.modules.pop(key, None) + name = lib.__name__ + for module in list(sys.modules.keys()): + if _is_submodule(name, module): + del sys.modules[module] + + def _load_from_module_spec( + self, spec: importlib.machinery.ModuleSpec, key: str + ) -> None: + # precondition: key not in self.__extensions + lib = importlib.util.module_from_spec(spec) + sys.modules[key] = lib + try: + spec.loader.exec_module(lib) # type: ignore + except Exception as e: + del sys.modules[key] + raise errors.ExtensionFailed(key, e) from e + + try: + setup = getattr(lib, "setup") + except AttributeError: + del sys.modules[key] + raise errors.NoEntryPointError(key) + + try: + setup(self) + except Exception as e: + del sys.modules[key] + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, key) + raise errors.ExtensionFailed(key, e) from e + else: + self.__extensions[key] = lib + + def _resolve_name(self, name: str, package: str | None) -> str: + try: + return importlib.util.resolve_name(name, package) + except ImportError: + raise errors.ExtensionNotFound(name) + + @overload + def load_extension( + self, + name: str, + *, + package: str | None = None, + recursive: bool = False, + ) -> list[str]: + ... + + @overload + def load_extension( + self, + name: str, + *, + package: str | None = None, + recursive: bool = False, + store: bool = False, + ) -> dict[str, Exception | bool] | list[str] | None: + ... + + def load_extension( + self, name, *, package=None, recursive=False, store=False + ) -> dict[str, Exception | bool] | list[str] | None: + """Loads an extension. + + An extension is a python module that contains commands, cogs, or + listeners. + + An extension must have a global function, ``setup`` defined as + the entry point on what to do when the extension is loaded. This entry + point must have a single argument, the ``bot``. + + The extension passed can either be the direct name of a file within + the current working directory or a folder that contains multiple extensions. + + Parameters + ---------- + name: :class:`str` + The extension or folder name to load. It must be dot separated + like regular Python imports if accessing a submodule. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when loading an extension using a relative + path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 + recursive: Optional[:class:`bool`] + If subdirectories under the given head directory should be + recursively loaded. + Defaults to ``False``. + + .. versionadded:: 2.0 + store: Optional[:class:`bool`] + If exceptions should be stored or raised. If set to ``True``, + all exceptions encountered will be stored in a returned dictionary + as a load status. If set to ``False``, if any exceptions are + encountered they will be raised and the bot will be closed. + If no exceptions are encountered, a list of loaded + extension names will be returned. + Defaults to ``False``. + + .. versionadded:: 2.0 + + Returns + ------- + Optional[Union[Dict[:class:`str`, Union[:exc:`errors.ExtensionError`, :class:`bool`]], List[:class:`str`]]] + If the store parameter is set to ``True``, a dictionary will be returned that + contains keys to represent the loaded extension names. The values bound to + each key can either be an exception that occurred when loading that extension + or a ``True`` boolean representing a successful load. If the store parameter + is set to ``False``, either a list containing a list of loaded extensions or + nothing due to an encountered exception. + + Raises + ------ + ExtensionNotFound + The extension could not be imported. + This is also raised if the name of the extension could not + be resolved using the provided ``package`` parameter. + ExtensionAlreadyLoaded + The extension is already loaded. + NoEntryPointError + The extension does not have a setup function. + ExtensionFailed + The extension or its setup function had an execution error. + """ + + name = self._resolve_name(name, package) + + if name in self.__extensions: + exc = errors.ExtensionAlreadyLoaded(name) + final_out = {name: exc} if store else exc + # This indicates that there is neither an extension nor folder here + elif (spec := importlib.util.find_spec(name)) is None: + exc = errors.ExtensionNotFound(name) + final_out = {name: exc} if store else exc + # This indicates we've found an extension file to load, and we need to store any exceptions + elif spec.has_location and store: + try: + self._load_from_module_spec(spec, name) + except Exception as exc: + final_out = {name: exc} + else: + final_out = {name: True} + # This indicates we've found an extension file to load, and any encountered exceptions can be raised + elif spec.has_location: + self._load_from_module_spec(spec, name) + final_out = [name] + # This indicates we've been given a folder because the ModuleSpec exists but is not a file + else: + # Split the directory path and join it to get an os-native Path object + path = pathlib.Path(os.path.join(*name.split("."))) + glob = path.rglob if recursive else path.glob + final_out = {} if store else [] + + # Glob all files with a pattern to gather all .py files that don't start with _ + for ext_file in glob("[!_]*.py"): + # Gets all parts leading to the directory minus the file name + parts = list(ext_file.parts[:-1]) + # Gets the file name without the extension + parts.append(ext_file.stem) + loaded = self.load_extension( + ".".join(parts), package=package, recursive=recursive, store=store + ) + final_out.update(loaded) if store else final_out.extend(loaded) + + if isinstance(final_out, Exception): + raise final_out + else: + return final_out + + @overload + def load_extensions( + self, + *names: str, + package: str | None = None, + recursive: bool = False, + ) -> list[str]: + ... + + @overload + def load_extensions( + self, + *names: str, + package: str | None = None, + recursive: bool = False, + store: bool = False, + ) -> dict[str, Exception | bool] | list[str] | None: + ... + + def load_extensions( + self, *names, package=None, recursive=False, store=False + ) -> dict[str, Exception | bool] | list[str] | None: + """Loads multiple extensions at once. + + This method simplifies the process of loading multiple + extensions by handling the looping of ``load_extension``. + + Parameters + ---------- + names: :class:`str` + The extension or folder names to load. It must be dot separated + like regular Python imports if accessing a submodule. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when loading an extension using a relative + path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 + recursive: Optional[:class:`bool`] + If subdirectories under the given head directory should be + recursively loaded. + Defaults to ``False``. + + .. versionadded:: 2.0 + store: Optional[:class:`bool`] + If exceptions should be stored or raised. If set to ``True``, + all exceptions encountered will be stored in a returned dictionary + as a load status. If set to ``False``, if any exceptions are + encountered they will be raised and the bot will be closed. + If no exceptions are encountered, a list of loaded + extension names will be returned. + Defaults to ``False``. + + .. versionadded:: 2.0 + + Returns + ------- + Optional[Union[Dict[:class:`str`, Union[:exc:`errors.ExtensionError`, :class:`bool`]], List[:class:`str`]]] + If the store parameter is set to ``True``, a dictionary will be returned that + contains keys to represent the loaded extension names. The values bound to + each key can either be an exception that occurred when loading that extension + or a ``True`` boolean representing a successful load. If the store parameter + is set to ``False``, either a list containing names of loaded extensions or + nothing due to an encountered exception. + + Raises + ------ + ExtensionNotFound + A given extension could not be imported. + This is also raised if the name of the extension could not + be resolved using the provided ``package`` parameter. + ExtensionAlreadyLoaded + A given extension is already loaded. + NoEntryPointError + A given extension does not have a setup function. + ExtensionFailed + A given extension or its setup function had an execution error. + """ + + loaded_extensions = {} if store else [] + + for ext_path in names: + loaded = self.load_extension( + ext_path, package=package, recursive=recursive, store=store + ) + loaded_extensions.update(loaded) if store else loaded_extensions.extend( + loaded + ) + + return loaded_extensions + + def unload_extension(self, name: str, *, package: str | None = None) -> None: + """Unloads an extension. + + When the extension is unloaded, all commands, listeners, and cogs are + removed from the bot and the module is un-imported. + + The extension can provide an optional global function, ``teardown``, + to do miscellaneous clean-up if necessary. This function takes a single + parameter, the ``bot``, similar to ``setup`` from + :meth:`~.Bot.load_extension`. + + Parameters + ---------- + name: :class:`str` + The extension name to unload. It must be dot separated like + regular Python imports if accessing a submodule. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when unloading an extension using a relative path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 + + Raises + ------ + ExtensionNotFound + The name of the extension could not + be resolved using the provided ``package`` parameter. + ExtensionNotLoaded + The extension was not loaded. + """ + + name = self._resolve_name(name, package) + lib = self.__extensions.get(name) + if lib is None: + raise errors.ExtensionNotLoaded(name) + + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, name) + + def reload_extension(self, name: str, *, package: str | None = None) -> None: + """Atomically reloads an extension. + + This replaces the extension with the same extension, only refreshed. This is + equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension` + except done in an atomic way. That is, if an operation fails mid-reload then + the bot will roll back to the prior working state. + + Parameters + ---------- + name: :class:`str` + The extension name to reload. It must be dot separated like + regular Python imports if accessing a submodule. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when reloading an extension using a relative path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 + + Raises + ------ + ExtensionNotLoaded + The extension was not loaded. + ExtensionNotFound + The extension could not be imported. + This is also raised if the name of the extension could not + be resolved using the provided ``package`` parameter. + NoEntryPointError + The extension does not have a setup function. + ExtensionFailed + The extension setup function had an execution error. + """ + + name = self._resolve_name(name, package) + lib = self.__extensions.get(name) + if lib is None: + raise errors.ExtensionNotLoaded(name) + + # get the previous module states from sys modules + modules = { + name: module + for name, module in sys.modules.items() + if _is_submodule(lib.__name__, name) + } + + try: + # Unload and then load the module... + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, name) + self.load_extension(name) + except Exception: + # if the load failed, the remnants should have been + # cleaned from the load_extension function call + # so let's load it from our old compiled library. + lib.setup(self) # type: ignore + self.__extensions[name] = lib + + # revert sys.modules back to normal and raise back to caller + sys.modules.update(modules) + raise + + @property + def extensions(self) -> Mapping[str, types.ModuleType]: + """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" + return types.MappingProxyType(self.__extensions) diff --git a/discord/colour.py b/discord/colour.py new file mode 100644 index 0000000..eb4f139 --- /dev/null +++ b/discord/colour.py @@ -0,0 +1,363 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import colorsys +import random +from typing import Any, Optional, Tuple, Type, TypeVar, Union + +__all__ = ( + "Colour", + "Color", +) + +CT = TypeVar("CT", bound="Colour") + + +class Colour: + """Represents a Discord role colour. This class is similar + to a (red, green, blue) :class:`tuple`. + + There is an alias for this called Color. + + .. container:: operations + + .. describe:: x == y + + Checks if two colours are equal. + + .. describe:: x != y + + Checks if two colours are not equal. + + .. describe:: hash(x) + + Return the colour's hash. + + .. describe:: str(x) + + Returns the hex format for the colour. + + .. describe:: int(x) + + Returns the raw colour value. + + Attributes + ---------- + value: :class:`int` + The raw integer colour value. + """ + + __slots__ = ("value",) + + def __init__(self, value: int): + if not isinstance(value, int): + raise TypeError( + f"Expected int parameter, received {value.__class__.__name__} instead." + ) + + self.value: int = value + + def _get_byte(self, byte: int) -> int: + return (self.value >> (8 * byte)) & 0xFF + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Colour) and self.value == other.value + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __str__(self) -> str: + return f"#{self.value:0>6x}" + + def __int__(self) -> int: + return self.value + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self.value) + + @property + def r(self) -> int: + """:class:`int`: Returns the red component of the colour.""" + return self._get_byte(2) + + @property + def g(self) -> int: + """:class:`int`: Returns the green component of the colour.""" + return self._get_byte(1) + + @property + def b(self) -> int: + """:class:`int`: Returns the blue component of the colour.""" + return self._get_byte(0) + + def to_rgb(self) -> Tuple[int, int, int]: + """Tuple[:class:`int`, :class:`int`, :class:`int`]: Returns an (r, g, b) tuple representing the colour.""" + return self.r, self.g, self.b + + @classmethod + def from_rgb(cls: Type[CT], r: int, g: int, b: int) -> CT: + """Constructs a :class:`Colour` from an RGB tuple.""" + return cls((r << 16) + (g << 8) + b) + + @classmethod + def from_hsv(cls: Type[CT], h: float, s: float, v: float) -> CT: + """Constructs a :class:`Colour` from an HSV tuple.""" + rgb = colorsys.hsv_to_rgb(h, s, v) + return cls.from_rgb(*(int(x * 255) for x in rgb)) + + @classmethod + def default(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0``.""" + return cls(0) + + @classmethod + def random( + cls: Type[CT], + *, + seed: Optional[Union[int, str, float, bytes, bytearray]] = None, + ) -> CT: + """A factory method that returns a :class:`Colour` with a random hue. + + .. note:: + + The random algorithm works by choosing a colour with a random hue but + with maxed out saturation and value. + + .. versionadded:: 1.6 + + Parameters + ---------- + seed: Optional[Union[:class:`int`, :class:`str`, :class:`float`, :class:`bytes`, :class:`bytearray`]] + The seed to initialize the RNG with. If ``None`` is passed the default RNG is used. + + .. versionadded:: 1.7 + """ + rand = random if seed is None else random.Random(seed) + return cls.from_hsv(rand.random(), 1, 1) + + @classmethod + def teal(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" + return cls(0x1ABC9C) + + @classmethod + def dark_teal(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" + return cls(0x11806A) + + @classmethod + def brand_green(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x57F287``. + + .. versionadded:: 2.0 + """ + return cls(0x57F287) + + @classmethod + def green(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" + return cls(0x2ECC71) + + @classmethod + def dark_green(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" + return cls(0x1F8B4C) + + @classmethod + def blue(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" + return cls(0x3498DB) + + @classmethod + def dark_blue(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x206694``.""" + return cls(0x206694) + + @classmethod + def purple(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" + return cls(0x9B59B6) + + @classmethod + def dark_purple(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" + return cls(0x71368A) + + @classmethod + def magenta(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" + return cls(0xE91E63) + + @classmethod + def dark_magenta(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" + return cls(0xAD1457) + + @classmethod + def gold(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" + return cls(0xF1C40F) + + @classmethod + def dark_gold(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" + return cls(0xC27C0E) + + @classmethod + def orange(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" + return cls(0xE67E22) + + @classmethod + def dark_orange(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" + return cls(0xA84300) + + @classmethod + def brand_red(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xED4245``. + + .. versionadded:: 2.0 + """ + return cls(0xED4245) + + @classmethod + def red(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" + return cls(0xE74C3C) + + @classmethod + def dark_red(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" + return cls(0x992D22) + + @classmethod + def lighter_grey(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" + return cls(0x95A5A6) + + lighter_gray = lighter_grey + + @classmethod + def dark_grey(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" + return cls(0x607D8B) + + dark_gray = dark_grey + + @classmethod + def light_grey(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" + return cls(0x979C9F) + + light_gray = light_grey + + @classmethod + def darker_grey(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" + return cls(0x546E7A) + + darker_gray = darker_grey + + @classmethod + def og_blurple(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" + return cls(0x7289DA) + + @classmethod + def blurple(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x5865F2``.""" + return cls(0x5865F2) + + @classmethod + def greyple(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" + return cls(0x99AAB5) + + @classmethod + def dark_theme(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0x36393F``. + This will appear transparent on Discord's dark theme. + + .. versionadded:: 1.5 + """ + return cls(0x36393F) + + @classmethod + def fuchsia(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xEB459E``. + + .. versionadded:: 2.0 + """ + return cls(0xEB459E) + + @classmethod + def yellow(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xFEE75C``. + + .. versionadded:: 2.0 + """ + return cls(0xFEE75C) + + @classmethod + def nitro_pink(cls: Type[CT]) -> CT: + """A factory method that returns a :class:`Colour` with a value of ``0xf47fff``. + + .. versionadded:: 2.0 + """ + return cls(0xF47FFF) + + @classmethod + def embed_background(cls: Type[CT], theme: str = "dark") -> CT: + """A factory method that returns a :class:`Color` corresponding to the + embed colors on discord clients, with a value of: + + - ``0x2F3136`` (dark) + - ``0xf2f3f5`` (light) + - ``0x000000`` (amoled). + + .. versionadded:: 2.0 + + Parameters + ---------- + theme: :class:`str` + The theme color to apply, must be one of "dark", "light", or "amoled". + """ + themes_cls = { + "dark": 0x2F3136, + "light": 0xF2F3F5, + "amoled": 0x000000, + } + + if theme not in themes_cls: + raise TypeError('Theme must be "dark", "light", or "amoled".') + + return cls(themes_cls[theme]) + + +Color = Colour diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py new file mode 100644 index 0000000..1813faf --- /dev/null +++ b/discord/commands/__init__.py @@ -0,0 +1,29 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .context import * +from .core import * +from .options import * +from .permissions import * diff --git a/discord/commands/context.py b/discord/commands/context.py new file mode 100644 index 0000000..212f796 --- /dev/null +++ b/discord/commands/context.py @@ -0,0 +1,402 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +import discord.abc +from discord.interactions import Interaction, InteractionMessage, InteractionResponse +from discord.webhook.async_ import Webhook + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + import discord + from .. import Bot + from ..state import ConnectionState + from ..voice_client import VoiceProtocol + + from .core import ApplicationCommand, Option + from ..interactions import InteractionChannel + from ..guild import Guild + from ..member import Member + from ..message import Message + from ..user import User + from ..permissions import Permissions + from ..client import ClientUser + + from ..cog import Cog + from ..webhook import WebhookMessage + + from typing import Callable, Awaitable + +from ..utils import cached_property + +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") + +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") + +__all__ = ("ApplicationContext", "AutocompleteContext") + + +class ApplicationContext(discord.abc.Messageable): + """Represents a Discord application command interaction context. + + This class is not created manually and is instead passed to application + commands as the first parameter. + + .. versionadded:: 2.0 + + Attributes + ---------- + bot: :class:`.Bot` + The bot that the command belongs to. + interaction: :class:`.Interaction` + The interaction object that invoked the command. + command: :class:`.ApplicationCommand` + The command that this context belongs to. + """ + + def __init__(self, bot: Bot, interaction: Interaction): + self.bot = bot + self.interaction = interaction + + # below attributes will be set after initialization + self.command: ApplicationCommand = None # type: ignore + self.focused: Option = None # type: ignore + self.value: str = None # type: ignore + self.options: dict = None # type: ignore + + self._state: ConnectionState = self.interaction._state + + async def _get_channel(self) -> InteractionChannel | None: + return self.interaction.channel + + async def invoke( + self, + command: ApplicationCommand[CogT, P, T], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + r"""|coro| + + Calls a command with the arguments given. + This is useful if you want to just call the callback that a + :class:`.ApplicationCommand` holds internally. + + .. note:: + + This does not handle converters, checks, cooldowns, pre-invoke, + or after-invoke hooks in any matter. It calls the internal callback + directly as-if it was a regular function. + You must take care in passing the proper arguments when + using this function. + + Parameters + ----------- + command: :class:`.ApplicationCommand` + The command that is going to be called. + \*args + The arguments to use. + \*\*kwargs + The keyword arguments to use. + + Raises + ------- + TypeError + The command argument to invoke is missing. + """ + return await command(self, *args, **kwargs) + + @cached_property + def channel(self) -> InteractionChannel | None: + """Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]: + Returns the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel`. + """ + return self.interaction.channel + + @cached_property + def channel_id(self) -> int | None: + """:class:`int`: Returns the ID of the channel associated with this context's command. + Shorthand for :attr:`.Interaction.channel_id`. + """ + return self.interaction.channel_id + + @cached_property + def guild(self) -> Guild | None: + """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. + Shorthand for :attr:`.Interaction.guild`. + """ + return self.interaction.guild + + @cached_property + def guild_id(self) -> int | None: + """:class:`int`: Returns the ID of the guild associated with this context's command. + Shorthand for :attr:`.Interaction.guild_id`. + """ + return self.interaction.guild_id + + @cached_property + def locale(self) -> str | None: + """:class:`str`: Returns the locale of the guild associated with this context's command. + Shorthand for :attr:`.Interaction.locale`. + """ + return self.interaction.locale + + @cached_property + def guild_locale(self) -> str | None: + """:class:`str`: Returns the locale of the guild associated with this context's command. + Shorthand for :attr:`.Interaction.guild_locale`. + """ + return self.interaction.guild_locale + + @cached_property + def app_permissions(self) -> Permissions: + return self.interaction.app_permissions + + @cached_property + def me(self) -> Member | ClientUser | None: + """Union[:class:`.Member`, :class:`.ClientUser`]: + Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message + message contexts, or when :meth:`Intents.guilds` is absent. + """ + return ( + self.interaction.guild.me + if self.interaction.guild is not None + else self.bot.user + ) + + @cached_property + def message(self) -> Message | None: + """Optional[:class:`.Message`]: Returns the message sent with this context's command. + Shorthand for :attr:`.Interaction.message`, if applicable. + """ + return self.interaction.message + + @cached_property + def user(self) -> Member | User | None: + """Union[:class:`.Member`, :class:`.User`]: Returns the user that sent this context's command. + Shorthand for :attr:`.Interaction.user`. + """ + return self.interaction.user + + author: Member | User | None = user + + @property + def voice_client(self) -> VoiceProtocol | None: + """Optional[:class:`.VoiceProtocol`]: Returns the voice client associated with this context's command. + Shorthand for :attr:`Interaction.guild.voice_client<~discord.Guild.voice_client>`, if applicable. + """ + if self.interaction.guild is None: + return None + + return self.interaction.guild.voice_client + + @cached_property + def response(self) -> InteractionResponse: + """:class:`.InteractionResponse`: Returns the response object associated with this context's command. + Shorthand for :attr:`.Interaction.response`. + """ + return self.interaction.response + + @property + def selected_options(self) -> list[dict[str, Any]] | None: + """The options and values that were selected by the user when sending the command. + + Returns + ------- + Optional[List[Dict[:class:`str`, Any]]] + A dictionary containing the options and values that were selected by the user when the command + was processed, if applicable. Returns ``None`` if the command has not yet been invoked, + or if there are no options defined for that command. + """ + return self.interaction.data.get("options", None) + + @property + def unselected_options(self) -> list[Option] | None: + """The options that were not provided by the user when sending the command. + + Returns + ------- + Optional[List[:class:`.Option`]] + A list of Option objects (if any) that were not selected by the user when the command was processed. + Returns ``None`` if there are no options defined for that command. + """ + if self.command.options is not None: # type: ignore + if self.selected_options: + return [ + option + for option in self.command.options # type: ignore + if option.to_dict()["name"] + not in [opt["name"] for opt in self.selected_options] + ] + else: + return self.command.options # type: ignore + return None + + @property + @discord.utils.copy_doc(InteractionResponse.send_modal) + def send_modal(self) -> Callable[..., Awaitable[Interaction]]: + return self.interaction.response.send_modal + + async def respond(self, *args, **kwargs) -> Interaction | WebhookMessage: + """|coro| + + Sends either a response or a message using the followup webhook determined by whether the interaction + has been responded to or not. + + Returns + ------- + Union[:class:`discord.Interaction`, :class:`discord.WebhookMessage`]: + The response, its type depending on whether it's an interaction response or a followup. + """ + try: + if not self.interaction.response.is_done(): + return await self.interaction.response.send_message( + *args, **kwargs + ) # self.response + else: + return await self.followup.send(*args, **kwargs) # self.send_followup + except discord.errors.InteractionResponded: + return await self.followup.send(*args, **kwargs) + + @property + @discord.utils.copy_doc(InteractionResponse.send_message) + def send_response(self) -> Callable[..., Awaitable[Interaction]]: + if not self.interaction.response.is_done(): + return self.interaction.response.send_message + else: + raise RuntimeError( + f"Interaction was already issued a response. Try using {type(self).__name__}.send_followup() instead." + ) + + @property + @discord.utils.copy_doc(Webhook.send) + def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: + if self.interaction.response.is_done(): + return self.followup.send + else: + raise RuntimeError( + f"Interaction was not yet issued a response. Try using {type(self).__name__}.respond() first." + ) + + @property + @discord.utils.copy_doc(InteractionResponse.defer) + def defer(self) -> Callable[..., Awaitable[None]]: + return self.interaction.response.defer + + @property + def followup(self) -> Webhook: + """:class:`Webhook`: Returns the followup webhook for followup interactions.""" + return self.interaction.followup + + async def delete(self, *, delay: float | None = None) -> None: + """|coro| + + Deletes the original interaction response message. + + This is a higher level interface to :meth:`Interaction.delete_original_response`. + + Parameters + ---------- + delay: Optional[:class:`float`] + If provided, the number of seconds to wait before deleting the message. + + Raises + ------ + HTTPException + Deleting the message failed. + Forbidden + You do not have proper permissions to delete the message. + """ + if not self.interaction.response.is_done(): + await self.defer() + + return await self.interaction.delete_original_response(delay=delay) + + @property + @discord.utils.copy_doc(Interaction.edit_original_response) + def edit(self) -> Callable[..., Awaitable[InteractionMessage]]: + return self.interaction.edit_original_response + + @property + def cog(self) -> Cog | None: + """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. + ``None`` if it does not exist. + """ + if self.command is None: + return None + + return self.command.cog + + +class AutocompleteContext: + """Represents context for a slash command's option autocomplete. + + This class is not created manually and is instead passed to an :class:`.Option`'s autocomplete callback. + + .. versionadded:: 2.0 + + Attributes + ---------- + bot: :class:`.Bot` + The bot that the command belongs to. + interaction: :class:`.Interaction` + The interaction object that invoked the autocomplete. + command: :class:`.ApplicationCommand` + The command that this context belongs to. + focused: :class:`.Option` + The option the user is currently typing. + value: :class:`.str` + The content of the focused option. + options: Dict[:class:`str`, Any] + A name to value mapping of the options that the user has selected before this option. + """ + + __slots__ = ("bot", "interaction", "command", "focused", "value", "options") + + def __init__(self, bot: Bot, interaction: Interaction): + self.bot = bot + self.interaction = interaction + + self.command: ApplicationCommand = None # type: ignore + self.focused: Option = None # type: ignore + self.value: str = None # type: ignore + self.options: dict = None # type: ignore + + @property + def cog(self) -> Cog | None: + """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. + ``None`` if it does not exist. + """ + if self.command is None: + return None + + return self.command.cog diff --git a/discord/commands/core.py b/discord/commands/core.py new file mode 100644 index 0000000..7862333 --- /dev/null +++ b/discord/commands/core.py @@ -0,0 +1,1869 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import datetime +import functools +import inspect +import re +import types +from collections import OrderedDict +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Generator, + Generic, + TypeVar, + Union, +) + +from ..channel import _threaded_guild_channel_factory +from ..enums import Enum as DiscordEnum +from ..enums import MessageType, SlashCommandOptionType, try_enum +from ..errors import ( + ApplicationCommandError, + ApplicationCommandInvokeError, + CheckFailure, + ClientException, + ValidationError, +) +from ..member import Member +from ..message import Attachment, Message +from ..object import Object +from ..role import Role +from ..threads import Thread +from ..user import User +from ..utils import MISSING, async_all, find, maybe_coroutine, utcnow +from .context import ApplicationContext, AutocompleteContext +from .options import Option, OptionChoice + +__all__ = ( + "_BaseCommand", + "ApplicationCommand", + "SlashCommand", + "slash_command", + "application_command", + "user_command", + "message_command", + "command", + "SlashCommandGroup", + "ContextMenuCommand", + "UserCommand", + "MessageCommand", +) + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec + + from .. import Permissions + from ..cog import Cog + +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") +Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) + +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") + + +def wrap_callback(coro): + from ..ext.commands.errors import CommandError + + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except ApplicationCommandError: + raise + except CommandError: + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise ApplicationCommandInvokeError(exc) from exc + return ret + + return wrapped + + +def hooked_wrapped_callback(command, ctx, coro): + from ..ext.commands.errors import CommandError + + @functools.wraps(coro) + async def wrapped(arg): + try: + ret = await coro(arg) + except ApplicationCommandError: + raise + except CommandError: + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise ApplicationCommandInvokeError(exc) from exc + finally: + if ( + hasattr(command, "_max_concurrency") + and command._max_concurrency is not None + ): + await command._max_concurrency.release(ctx) + await command.call_after_hooks(ctx) + return ret + + return wrapped + + +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def _validate_names(obj): + validate_chat_input_name(obj.name) + if obj.name_localizations: + for locale, string in obj.name_localizations.items(): + validate_chat_input_name(string, locale=locale) + + +def _validate_descriptions(obj): + validate_chat_input_description(obj.description) + if obj.description_localizations: + for locale, string in obj.description_localizations.items(): + validate_chat_input_description(string, locale=locale) + + +class _BaseCommand: + __slots__ = () + + +class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]): + __original_kwargs__: dict[str, Any] + cog = None + + def __init__(self, func: Callable, **kwargs) -> None: + from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency + + cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) + + self._buckets: CooldownMapping = buckets + + max_concurrency = getattr( + func, "__commands_max_concurrency__", kwargs.get("max_concurrency") + ) + + self._max_concurrency: MaxConcurrency | None = max_concurrency + + self._callback = None + self.module = None + + self.name: str = kwargs.get("name", func.__name__) + + try: + checks = func.__commands_checks__ + checks.reverse() + except AttributeError: + checks = kwargs.get("checks", []) + + self.checks = checks + self.id: int | None = kwargs.get("id") + self.guild_ids: list[int] | None = kwargs.get("guild_ids", None) + self.parent = kwargs.get("parent") + + # Permissions + self.default_member_permissions: Permissions | None = getattr( + func, + "__default_member_permissions__", + kwargs.get("default_member_permissions", None), + ) + self.guild_only: bool | None = getattr( + func, "__guild_only__", kwargs.get("guild_only", None) + ) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other) -> bool: + if ( + getattr(self, "id", None) is not None + and getattr(other, "id", None) is not None + ): + check = self.id == other.id + else: + check = self.name == other.name and self.guild_ids == other.guild_ids + return ( + isinstance(other, self.__class__) and self.parent == other.parent and check + ) + + async def __call__(self, ctx, *args, **kwargs): + """|coro| + Calls the command's callback. + + This method bypasses all checks that a command has and does not + convert the arguments beforehand, so take care to pass the correct + arguments in. + """ + if self.cog is not None: + return await self.callback(self.cog, ctx, *args, **kwargs) + return await self.callback(ctx, *args, **kwargs) + + @property + def callback( + self, + ) -> ( + Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]] + | Callable[Concatenate[ApplicationContext, P], Coro[T]] + ): + return self._callback + + @callback.setter + def callback( + self, + function: ( + Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]] + | Callable[Concatenate[ApplicationContext, P], Coro[T]] + ), + ) -> None: + self._callback = function + unwrap = unwrap_function(function) + self.module = unwrap.__module__ + + def _prepare_cooldowns(self, ctx: ApplicationContext): + if self._buckets.valid: + current = datetime.datetime.now().timestamp() + bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message + + if bucket is not None: + retry_after = bucket.update_rate_limit(current) + + if retry_after: + from ..ext.commands.errors import CommandOnCooldown + + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + + async def prepare(self, ctx: ApplicationContext) -> None: + # This should be same across all 3 types + ctx.command = self + + if not await self.can_run(ctx): + raise CheckFailure( + f"The check functions for the command {self.name} failed" + ) + + if hasattr(self, "_max_concurrency"): + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message + + try: + self._prepare_cooldowns(ctx) + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) # type: ignore # ctx instead of non-existent message + raise + + def is_on_cooldown(self, ctx: ApplicationContext) -> bool: + """Checks whether the command is currently on cooldown. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context to use when checking the command's cooldown status. + + Returns + ------- + :class:`bool` + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False + + bucket = self._buckets.get_bucket(ctx) + current = utcnow().timestamp() + return bucket.get_tokens(current) == 0 + + def reset_cooldown(self, ctx: ApplicationContext) -> None: + """Resets the cooldown on this command. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) # type: ignore # ctx instead of non-existent message + bucket.reset() + + def get_cooldown_retry_after(self, ctx: ApplicationContext) -> float: + """Retrieves the amount of seconds before this command can be tried again. + + .. note:: + + This uses the current time instead of the interaction time. + + Parameters + ---------- + ctx: :class:`.ApplicationContext` + The invocation context to retrieve the cooldown from. + + Returns + ------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) + current = utcnow().timestamp() + return bucket.get_retry_after(current) + + return 0.0 + + async def invoke(self, ctx: ApplicationContext) -> None: + await self.prepare(ctx) + + injected = hooked_wrapped_callback(self, ctx, self._invoke) + await injected(ctx) + + async def can_run(self, ctx: ApplicationContext) -> bool: + + if not await ctx.bot.can_run(ctx): + raise CheckFailure( + f"The global check functions for command {self.name} failed." + ) + + predicates = self.checks + if self.parent is not None: + # parent checks should be run first + predicates = self.parent.checks + predicates + + cog = self.cog + if cog is not None: + local_check = cog._get_overridden_method(cog.cog_check) + if local_check is not None: + ret = await maybe_coroutine(local_check, ctx) + if not ret: + return False + + if not predicates: + # since we have no checks, then we just return True. + return True + + return await async_all(predicate(ctx) for predicate in predicates) # type: ignore + + async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> None: + ctx.command_failed = True + cog = self.cog + try: + coro = self.on_error + except AttributeError: + pass + else: + injected = wrap_callback(coro) + if cog is not None: + await injected(cog, ctx, error) + else: + await injected(ctx, error) + + try: + if cog is not None: + local = cog.__class__._get_overridden_method(cog.cog_command_error) + if local is not None: + wrapped = wrap_callback(local) + await wrapped(ctx, error) + finally: + ctx.bot.dispatch("application_command_error", ctx, error) + + def _get_signature_parameters(self): + return OrderedDict(inspect.signature(self.callback).parameters) + + def error(self, coro): + """A decorator that registers a coroutine as a local error handler. + + A local error handler is an :func:`.on_command_error` event limited to + a single command. However, the :func:`.on_command_error` is still + invoked afterwards as the catch-all. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + + self.on_error = coro + return coro + + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the command has an error handler registered.""" + return hasattr(self, "on_error") + + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.ApplicationContext`. + See :meth:`.Bot.before_invoke` for more info. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + """A decorator that registers a coroutine as a post-invoke hook. + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.ApplicationContext`. + See :meth:`.Bot.after_invoke` for more info. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + async def call_before_hooks(self, ctx: ApplicationContext) -> None: + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.cog + if self._before_invoke is not None: + # should be cog if @commands.before_invoke is used + instance = getattr(self._before_invoke, "__self__", cog) + # __self__ only exists for methods, not functions + # however, if @command.before_invoke is used, it will be a function + if instance: + await self._before_invoke(instance, ctx) # type: ignore + else: + await self._before_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = cog.__class__._get_overridden_method(cog.cog_before_invoke) + if hook is not None: + await hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + await hook(ctx) + + async def call_after_hooks(self, ctx: ApplicationContext) -> None: + cog = self.cog + if self._after_invoke is not None: + instance = getattr(self._after_invoke, "__self__", cog) + if instance: + await self._after_invoke(instance, ctx) # type: ignore + else: + await self._after_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = cog.__class__._get_overridden_method(cog.cog_after_invoke) + if hook is not None: + await hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + await hook(ctx) + + @property + def cooldown(self): + return self._buckets._cooldown + + @property + def full_parent_name(self) -> str: + """:class:`str`: Retrieves the fully qualified parent command name. + + This the base command name required to execute it. For example, + in ``/one two three`` the parent name would be ``one two``. + """ + entries = [] + command = self + while command.parent is not None and hasattr(command.parent, "name"): + command = command.parent + entries.append(command.name) + + return " ".join(reversed(entries)) + + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``/one two three`` the qualified name would be + ``one two three``. + """ + + parent = self.full_parent_name + + if parent: + return f"{parent} {self.name}" + else: + return self.name + + @property + def qualified_id(self) -> int: + """:class:`int`: Retrieves the fully qualified command ID. + + This is the root parent ID. For example, in ``/one two three`` + the qualified ID would return ``one.id``. + """ + if self.id is None: + return self.parent.qualified_id + return self.id + + def to_dict(self) -> dict[str, Any]: + raise NotImplementedError + + def __str__(self) -> str: + return self.qualified_name + + def _set_cog(self, cog): + self.cog = cog + + +class SlashCommand(ApplicationCommand): + r"""A class that implements the protocol for a slash command. + + These are not created manually, instead they are created via the + decorator or functional interface. + + .. versionadded:: 2.0 + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + description: Optional[:class:`str`] + The description for the command. + guild_ids: Optional[List[:class:`int`]] + The ids of the guilds where this command will be registered. + options: List[:class:`Option`] + The parameters for this command. + parent: Optional[:class:`SlashCommandGroup`] + The parent group that this command belongs to. ``None`` if there + isn't one. + mention: :class:`str` + Returns a string that allows you to mention the slash command. + guild_only: :class:`bool` + Whether the command should only be usable inside a guild. + default_member_permissions: :class:`~discord.Permissions` + The default permissions a member needs to be able to run the command. + cog: Optional[:class:`Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] + The cooldown applied when the command is invoked. ``None`` if the command + doesn't have a cooldown. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this command. The values of this should be ``"locale": "name"``. See + `here `_ for a list of valid locales. + description_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The description localizations for this command. The values of this should be ``"locale": "description"``. + See `here `_ for a list of valid locales. + """ + type = 1 + + def __new__(cls, *args, **kwargs) -> SlashCommand: + self = super().__new__(cls) + + self.__original_kwargs__ = kwargs.copy() + return self + + def __init__(self, func: Callable, *args, **kwargs) -> None: + super().__init__(func, **kwargs) + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + self.callback = func + + self.name_localizations: dict[str, str] | None = kwargs.get( + "name_localizations", None + ) + _validate_names(self) + + description = kwargs.get("description") or ( + inspect.cleandoc(func.__doc__).splitlines()[0] + if func.__doc__ is not None + else "No description provided" + ) + + self.description: str = description + self.description_localizations: dict[str, str] | None = kwargs.get( + "description_localizations", None + ) + _validate_descriptions(self) + + self.attached_to_group: bool = False + + self.options: list[Option] = kwargs.get("options", []) + + try: + checks = func.__commands_checks__ + checks.reverse() + except AttributeError: + checks = kwargs.get("checks", []) + + self.checks = checks + + self._before_invoke = None + self._after_invoke = None + + def _validate_parameters(self): + params = self._get_signature_parameters() + if kwop := self.options: + self.options: list[Option] = self._match_option_param_names(params, kwop) + else: + self.options: list[Option] = self._parse_options(params) + + def _check_required_params(self, params): + params = iter(params.items()) + required_params = ( + ["self", "context"] if self.attached_to_group or self.cog else ["context"] + ) + for p in required_params: + try: + next(params) + except StopIteration: + raise ClientException( + f'Callback for {self.name} command is missing "{p}" parameter.' + ) + + return params + + def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: + if check_params: + params = self._check_required_params(params) + + final_options = [] + for p_name, p_obj in params: + option = p_obj.annotation + if option == inspect.Parameter.empty: + option = str + + if self._is_typing_union(option): + if self._is_typing_optional(option): + option = Option(option.__args__[0], default=None) + else: + option = Option(option.__args__) + + if not isinstance(option, Option): + if isinstance(p_obj.default, Option): + p_obj.default.input_type = SlashCommandOptionType.from_datatype( + option + ) + option = p_obj.default + else: + option = Option(option) + + if option.default is None and not p_obj.default == inspect.Parameter.empty: + if isinstance(p_obj.default, type) and issubclass( + p_obj.default, (DiscordEnum, Enum) + ): + option = Option(p_obj.default) + elif ( + isinstance(p_obj.default, Option) + and not (default := p_obj.default.default) is None + ): + option.default = default + else: + option.default = p_obj.default + option.required = False + if option.name is None: + option.name = p_name + if option.name != p_name or option._parameter_name is None: + option._parameter_name = p_name + + _validate_names(option) + _validate_descriptions(option) + + final_options.append(option) + + return final_options + + def _match_option_param_names(self, params, options): + params = self._check_required_params(params) + + check_annotations: list[Callable[[Option, type], bool]] = [ + lambda o, a: o.input_type == SlashCommandOptionType.string + and o.converter is not None, # pass on converters + lambda o, a: isinstance( + o.input_type, SlashCommandOptionType + ), # pass on slash cmd option type enums + lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types + lambda o, a: self._is_typing_optional(a) + and not o.required + and o._raw_type in a.__args__, # optional + lambda o, a: isinstance(a, type) + and issubclass(a, o._raw_type), # 'normal' types + ] + for o in options: + _validate_names(o) + _validate_descriptions(o) + try: + p_name, p_obj = next(params) + except StopIteration: # not enough params for all the options + raise ClientException("Too many arguments passed to the options kwarg.") + p_obj = p_obj.annotation + + if not any(check(o, p_obj) for check in check_annotations): + raise TypeError( + f"Parameter {p_name} does not match input type of {o.name}." + ) + o._parameter_name = p_name + + left_out_params = OrderedDict() + for k, v in params: + left_out_params[k] = v + options.extend(self._parse_options(left_out_params, check_params=False)) + + return options + + def _is_typing_union(self, annotation): + return getattr(annotation, "__origin__", None) is Union or type( + annotation + ) is getattr( + types, "UnionType", Union + ) # type: ignore + + def _is_typing_optional(self, annotation): + return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore + + @property + def cog(self): + return getattr(self, "_cog", MISSING) + + @cog.setter + def cog(self, val): + self._cog = val + self._validate_parameters() + + @property + def is_subcommand(self) -> bool: + return self.parent is not None + + @property + def mention(self) -> str: + return f"" + + def to_dict(self) -> dict: + as_dict = { + "name": self.name, + "description": self.description, + "options": [o.to_dict() for o in self.options], + } + if self.name_localizations is not None: + as_dict["name_localizations"] = self.name_localizations + if self.description_localizations is not None: + as_dict["description_localizations"] = self.description_localizations + if self.is_subcommand: + as_dict["type"] = SlashCommandOptionType.sub_command.value + + if self.guild_only is not None: + as_dict["dm_permission"] = not self.guild_only + + if self.default_member_permissions is not None: + as_dict[ + "default_member_permissions" + ] = self.default_member_permissions.value + + return as_dict + + async def _invoke(self, ctx: ApplicationContext) -> None: + # TODO: Parse the args better + kwargs = {} + for arg in ctx.interaction.data.get("options", []): + op = find(lambda x: x.name == arg["name"], self.options) + if op is None: + continue + arg = arg["value"] + + # Checks if input_type is user, role or channel + if op.input_type in ( + SlashCommandOptionType.user, + SlashCommandOptionType.role, + SlashCommandOptionType.channel, + SlashCommandOptionType.attachment, + SlashCommandOptionType.mentionable, + ): + resolved = ctx.interaction.data.get("resolved", {}) + if ( + op.input_type + in (SlashCommandOptionType.user, SlashCommandOptionType.mentionable) + and (_data := resolved.get("members", {}).get(arg)) is not None + ): + # The option type is a user, we resolved a member from the snowflake and assigned it to _data + if (_user_data := resolved.get("users", {}).get(arg)) is not None: + # We resolved the user from the user id + _data["user"] = _user_data + cache_flag = ctx.interaction._state.member_cache_flags.interaction + arg = ctx.guild._get_and_update_member(_data, int(arg), cache_flag) + elif op.input_type is SlashCommandOptionType.mentionable: + if (_data := resolved.get("users", {}).get(arg)) is not None: + arg = User(state=ctx.interaction._state, data=_data) + elif (_data := resolved.get("roles", {}).get(arg)) is not None: + arg = Role( + state=ctx.interaction._state, data=_data, guild=ctx.guild + ) + else: + arg = Object(id=int(arg)) + elif ( + _data := resolved.get(f"{op.input_type.name}s", {}).get(arg) + ) is not None: + if op.input_type is SlashCommandOptionType.channel and ( + int(arg) in ctx.guild._channels + or int(arg) in ctx.guild._threads + ): + arg = ctx.guild.get_channel_or_thread(int(arg)) + _data["_invoke_flag"] = True + arg._update(_data) if isinstance(arg, Thread) else arg._update( + ctx.guild, _data + ) + else: + obj_type = None + kw = {} + if op.input_type is SlashCommandOptionType.user: + obj_type = User + elif op.input_type is SlashCommandOptionType.role: + obj_type = Role + kw["guild"] = ctx.guild + elif op.input_type is SlashCommandOptionType.channel: + # NOTE: + # This is a fallback in case the channel/thread is not found in the + # guild's channels/threads. For channels, if this fallback occurs, at the very minimum, + # permissions will be incorrect due to a lack of permission_overwrite data. + # For threads, if this fallback occurs, info like thread owner id, message count, + # flags, and more will be missing due to a lack of data sent by Discord. + obj_type = _threaded_guild_channel_factory(_data["type"])[0] + kw["guild"] = ctx.guild + elif op.input_type is SlashCommandOptionType.attachment: + obj_type = Attachment + arg = obj_type(state=ctx.interaction._state, data=_data, **kw) + else: + # We couldn't resolve the object, so we just return an empty object + arg = Object(id=int(arg)) + + elif ( + op.input_type == SlashCommandOptionType.string + and (converter := op.converter) is not None + ): + from discord.ext.commands import Converter + + if isinstance(converter, Converter): + if isinstance(converter, type): + arg = await converter().convert(ctx, arg) + else: + arg = await converter.convert(ctx, arg) + + elif op._raw_type in ( + SlashCommandOptionType.integer, + SlashCommandOptionType.number, + SlashCommandOptionType.string, + SlashCommandOptionType.boolean, + ): + pass + + elif issubclass(op._raw_type, Enum): + if isinstance(arg, str) and arg.isdigit(): + try: + arg = op._raw_type(int(arg)) + except ValueError: + arg = op._raw_type(arg) + elif choice := find(lambda c: c.value == arg, op.choices): + arg = getattr(op._raw_type, choice.name) + + kwargs[op._parameter_name] = arg + + for o in self.options: + if o._parameter_name not in kwargs: + kwargs[o._parameter_name] = o.default + + if self.cog is not None: + await self.callback(self.cog, ctx, **kwargs) + elif self.parent is not None and self.attached_to_group is True: + await self.callback(self.parent, ctx, **kwargs) + else: + await self.callback(ctx, **kwargs) + + async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): + values = {i.name: i.default for i in self.options} + + for op in ctx.interaction.data.get("options", []): + if op.get("focused", False): + option = find(lambda o: o.name == op["name"], self.options) + values.update( + {i["name"]: i["value"] for i in ctx.interaction.data["options"]} + ) + ctx.command = self + ctx.focused = option + ctx.value = op.get("value") + ctx.options = values + + if len(inspect.signature(option.autocomplete).parameters) == 2: + instance = getattr(option.autocomplete, "__self__", ctx.cog) + result = option.autocomplete(instance, ctx) + else: + result = option.autocomplete(ctx) + + if asyncio.iscoroutinefunction(option.autocomplete): + result = await result + + choices = [ + o if isinstance(o, OptionChoice) else OptionChoice(o) + for o in result + ][:25] + return await ctx.interaction.response.send_autocomplete_result( + choices=choices + ) + + def copy(self): + """Creates a copy of this command. + + Returns + ------- + :class:`SlashCommand` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) + + def _ensure_assignment_on_copy(self, other): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + # if self._buckets.valid and not other._buckets.valid: + # other._buckets = self._buckets.copy() + # if self._max_concurrency != other._max_concurrency: + # # _max_concurrency won't be None at this point + # other._max_concurrency = self._max_concurrency.copy() # type: ignore + + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def _update_copy(self, kwargs: dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + +class SlashCommandGroup(ApplicationCommand): + r"""A class that implements the protocol for a slash command group. + + These can be created manually, but they should be created via the + decorator or functional interface. + + Attributes + ----------- + name: :class:`str` + The name of the command. + description: Optional[:class:`str`] + The description for the command. + guild_ids: Optional[List[:class:`int`]] + The ids of the guilds where this command will be registered. + parent: Optional[:class:`SlashCommandGroup`] + The parent group that this group belongs to. ``None`` if there + isn't one. + guild_only: :class:`bool` + Whether the command should only be usable inside a guild. + default_member_permissions: :class:`~discord.Permissions` + The default permissions a member needs to be able to run the command. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this command. The values of this should be ``"locale": "name"``. See + `here `_ for a list of valid locales. + description_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The description localizations for this command. The values of this should be ``"locale": "description"``. + See `here `_ for a list of valid locales. + """ + __initial_commands__: list[SlashCommand | SlashCommandGroup] + type = 1 + + def __new__(cls, *args, **kwargs) -> SlashCommandGroup: + self = super().__new__(cls) + self.__original_kwargs__ = kwargs.copy() + + self.__initial_commands__ = [] + for i, c in cls.__dict__.items(): + if isinstance(c, type) and SlashCommandGroup in c.__bases__: + c = c( + c.__name__, + ( + inspect.cleandoc(cls.__doc__).splitlines()[0] + if cls.__doc__ is not None + else "No description provided" + ), + ) + if isinstance(c, (SlashCommand, SlashCommandGroup)): + c.parent = self + c.attached_to_group = True + self.__initial_commands__.append(c) + + return self + + def __init__( + self, + name: str, + description: str | None = None, + guild_ids: list[int] | None = None, + parent: SlashCommandGroup | None = None, + **kwargs, + ) -> None: + self.name = str(name) + self.description = description or "No description provided" + validate_chat_input_name(self.name) + validate_chat_input_description(self.description) + self.input_type = SlashCommandOptionType.sub_command_group + self.subcommands: list[ + SlashCommand | SlashCommandGroup + ] = self.__initial_commands__ + self.guild_ids = guild_ids + self.parent = parent + self.attached_to_group: bool = False + self.checks = kwargs.get("checks", []) + + self._before_invoke = None + self._after_invoke = None + self.cog = MISSING + self.id = None + + # Permissions + self.default_member_permissions: Permissions | None = kwargs.get( + "default_member_permissions", None + ) + self.guild_only: bool | None = kwargs.get("guild_only", None) + + self.name_localizations: dict[str, str] | None = kwargs.get( + "name_localizations", None + ) + self.description_localizations: dict[str, str] | None = kwargs.get( + "description_localizations", None + ) + + @property + def module(self) -> str | None: + return self.__module__ + + def to_dict(self) -> dict: + as_dict = { + "name": self.name, + "description": self.description, + "options": [c.to_dict() for c in self.subcommands], + } + if self.name_localizations is not None: + as_dict["name_localizations"] = self.name_localizations + if self.description_localizations is not None: + as_dict["description_localizations"] = self.description_localizations + + if self.parent is not None: + as_dict["type"] = self.input_type.value + + if self.guild_only is not None: + as_dict["dm_permission"] = not self.guild_only + + if self.default_member_permissions is not None: + as_dict[ + "default_member_permissions" + ] = self.default_member_permissions.value + + return as_dict + + def add_command(self, command: SlashCommand) -> None: + # check if subcommand has no cog set + # also check if cog is MISSING because it + # might not have been set by the cog yet + if command.cog is MISSING and self.cog is not MISSING: + command.cog = self.cog + + self.subcommands.append(command) + + def command( + self, cls: type[T] = SlashCommand, **kwargs + ) -> Callable[[Callable], SlashCommand]: + def wrap(func) -> T: + command = cls(func, parent=self, **kwargs) + self.add_command(command) + return command + + return wrap + + def create_subgroup( + self, + name: str, + description: str | None = None, + guild_ids: list[int] | None = None, + **kwargs, + ) -> SlashCommandGroup: + """ + Creates a new subgroup for this SlashCommandGroup. + + Parameters + ---------- + name: :class:`str` + The name of the group to create. + description: Optional[:class:`str`] + The description of the group to create. + guild_ids: Optional[List[:class:`int`]] + A list of the IDs of each guild this group should be added to, making it a guild command. + This will be a global command if ``None`` is passed. + guild_only: :class:`bool` + Whether the command should only be usable inside a guild. + default_member_permissions: :class:`~discord.Permissions` + The default permissions a member needs to be able to run the command. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this command. The values of this should be ``"locale": "name"``. See + `here `_ for a list of valid locales. + description_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The description localizations for this command. The values of this should be ``"locale": "description"``. + See `here `_ for a list of valid locales. + + Returns + ------- + SlashCommandGroup + The slash command group that was created. + """ + + if self.parent is not None: + # TODO: Improve this error message + raise Exception("a subgroup cannot have a subgroup") + + sub_command_group = SlashCommandGroup( + name, description, guild_ids, parent=self, **kwargs + ) + self.subcommands.append(sub_command_group) + return sub_command_group + + def subgroup( + self, + name: str | None = None, + description: str | None = None, + guild_ids: list[int] | None = None, + ) -> Callable[[type[SlashCommandGroup]], SlashCommandGroup]: + """A shortcut decorator that initializes the provided subclass of :class:`.SlashCommandGroup` + as a subgroup. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: Optional[:class:`str`] + The name of the group to create. This will resolve to the name of the decorated class if ``None`` is passed. + description: Optional[:class:`str`] + The description of the group to create. + guild_ids: Optional[List[:class:`int`]] + A list of the IDs of each guild this group should be added to, making it a guild command. + This will be a global command if ``None`` is passed. + + Returns + ------- + Callable[[Type[SlashCommandGroup]], SlashCommandGroup] + The slash command group that was created. + """ + + def inner(cls: type[SlashCommandGroup]) -> SlashCommandGroup: + group = cls( + name or cls.__name__, + description + or ( + inspect.cleandoc(cls.__doc__).splitlines()[0] + if cls.__doc__ is not None + else "No description provided" + ), + guild_ids=guild_ids, + parent=self, + ) + self.add_command(group) + return group + + return inner + + async def _invoke(self, ctx: ApplicationContext) -> None: + option = ctx.interaction.data["options"][0] + resolved = ctx.interaction.data.get("resolved", None) + command = find(lambda x: x.name == option["name"], self.subcommands) + option["resolved"] = resolved + ctx.interaction.data = option + await command.invoke(ctx) + + async def invoke_autocomplete_callback(self, ctx: AutocompleteContext) -> None: + option = ctx.interaction.data["options"][0] + command = find(lambda x: x.name == option["name"], self.subcommands) + ctx.interaction.data = option + await command.invoke_autocomplete_callback(ctx) + + def walk_commands(self) -> Generator[SlashCommand, None, None]: + """An iterator that recursively walks through all slash commands in this group. + + Yields + ------ + :class:`.SlashCommand` + A slash command from the group. + """ + for command in self.subcommands: + if isinstance(command, SlashCommandGroup): + yield from command.walk_commands() + yield command + + def copy(self): + """Creates a copy of this command group. + + Returns + ------- + :class:`SlashCommandGroup` + A new instance of this command group. + """ + ret = self.__class__( + name=self.name, + description=self.description, + **{ + param: value + for param, value in self.__original_kwargs__.items() + if param not in ("name", "description") + }, + ) + return self._ensure_assignment_on_copy(ret) + + def _ensure_assignment_on_copy(self, other): + other.parent = self.parent + + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + + if self.subcommands != other.subcommands: + other.subcommands = self.subcommands.copy() + + if self.checks != other.checks: + other.checks = self.checks.copy() + + return other + + def _update_copy(self, kwargs: dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + def _set_cog(self, cog): + super()._set_cog(cog) + for subcommand in self.subcommands: + subcommand._set_cog(cog) + + +class ContextMenuCommand(ApplicationCommand): + r"""A class that implements the protocol for context menu commands. + + These are not created manually, instead they are created via the + decorator or functional interface. + + .. versionadded:: 2.0 + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + guild_ids: Optional[List[:class:`int`]] + The ids of the guilds where this command will be registered. + guild_only: :class:`bool` + Whether the command should only be usable inside a guild. + default_member_permissions: :class:`~discord.Permissions` + The default permissions a member needs to be able to run the command. + cog: Optional[:class:`Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + cooldown: Optional[:class:`~discord.ext.commands.Cooldown`] + The cooldown applied when the command is invoked. ``None`` if the command + doesn't have a cooldown. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this command. The values of this should be ``"locale": "name"``. See + `here `_ for a list of valid locales. + """ + + def __new__(cls, *args, **kwargs) -> ContextMenuCommand: + self = super().__new__(cls) + + self.__original_kwargs__ = kwargs.copy() + return self + + def __init__(self, func: Callable, *args, **kwargs) -> None: + super().__init__(func, **kwargs) + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + self.callback = func + + self.name_localizations: dict[str, str] | None = kwargs.get( + "name_localizations", None + ) + + # Discord API doesn't support setting descriptions for context menu commands, so it must be empty + self.description = "" + if not isinstance(self.name, str): + raise TypeError("Name of a command must be a string.") + + self.cog = None + self.id = None + + self._before_invoke = None + self._after_invoke = None + + self.validate_parameters() + + # Context Menu commands can't have parents + self.parent = None + + def validate_parameters(self): + params = self._get_signature_parameters() + if list(params.items())[0][0] == "self": + temp = list(params.items()) + temp.pop(0) + params = dict(temp) + params = iter(params) + + # next we have the 'ctx' as the next parameter + try: + next(params) + except StopIteration: + raise ClientException( + f'Callback for {self.name} command is missing "ctx" parameter.' + ) + + # next we have the 'user/message' as the next parameter + try: + next(params) + except StopIteration: + cmd = "user" if type(self) == UserCommand else "message" + raise ClientException( + f'Callback for {self.name} command is missing "{cmd}" parameter.' + ) + + # next there should be no more parameters + try: + next(params) + raise ClientException( + f"Callback for {self.name} command has too many parameters." + ) + except StopIteration: + pass + + @property + def qualified_name(self): + return self.name + + def to_dict(self) -> dict[str, str | int]: + as_dict = { + "name": self.name, + "description": self.description, + "type": self.type, + } + + if self.guild_only is not None: + as_dict["dm_permission"] = not self.guild_only + + if self.default_member_permissions is not None: + as_dict[ + "default_member_permissions" + ] = self.default_member_permissions.value + + if self.name_localizations is not None: + as_dict["name_localizations"] = self.name_localizations + + return as_dict + + +class UserCommand(ContextMenuCommand): + r"""A class that implements the protocol for user context menu commands. + + These are not created manually, instead they are created via the + decorator or functional interface. + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + guild_ids: Optional[List[:class:`int`]] + The ids of the guilds where this command will be registered. + cog: Optional[:class:`.Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + """ + type = 2 + + def __new__(cls, *args, **kwargs) -> UserCommand: + self = super().__new__(cls) + + self.__original_kwargs__ = kwargs.copy() + return self + + async def _invoke(self, ctx: ApplicationContext) -> None: + if "members" not in ctx.interaction.data["resolved"]: + _data = ctx.interaction.data["resolved"]["users"] + for i, v in _data.items(): + v["id"] = int(i) + user = v + target = User(state=ctx.interaction._state, data=user) + else: + _data = ctx.interaction.data["resolved"]["members"] + for i, v in _data.items(): + v["id"] = int(i) + member = v + _data = ctx.interaction.data["resolved"]["users"] + for i, v in _data.items(): + v["id"] = int(i) + user = v + member["user"] = user + target = Member( + data=member, + guild=ctx.interaction._state._get_guild(ctx.interaction.guild_id), + state=ctx.interaction._state, + ) + + if self.cog is not None: + await self.callback(self.cog, ctx, target) + else: + await self.callback(ctx, target) + + def copy(self): + """Creates a copy of this command. + + Returns + ------- + :class:`UserCommand` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) + + def _ensure_assignment_on_copy(self, other): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + # if self._buckets.valid and not other._buckets.valid: + # other._buckets = self._buckets.copy() + # if self._max_concurrency != other._max_concurrency: + # # _max_concurrency won't be None at this point + # other._max_concurrency = self._max_concurrency.copy() # type: ignore + + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def _update_copy(self, kwargs: dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + +class MessageCommand(ContextMenuCommand): + r"""A class that implements the protocol for message context menu commands. + + These are not created manually, instead they are created via the + decorator or functional interface. + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + guild_ids: Optional[List[:class:`int`]] + The ids of the guilds where this command will be registered. + cog: Optional[:class:`.Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.ApplicationContext`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.ApplicationContext` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.ApplicationCommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error` + event. + """ + type = 3 + + def __new__(cls, *args, **kwargs) -> MessageCommand: + self = super().__new__(cls) + + self.__original_kwargs__ = kwargs.copy() + return self + + async def _invoke(self, ctx: ApplicationContext): + _data = ctx.interaction.data["resolved"]["messages"] + for i, v in _data.items(): + v["id"] = int(i) + message = v + channel = ctx.interaction._state.get_channel(int(message["channel_id"])) + if channel is None: + author_id = int(message["author"]["id"]) + self_or_system_message: bool = ctx.bot.user.id == author_id or try_enum( + MessageType, message["type"] + ) not in ( + MessageType.default, + MessageType.reply, + MessageType.application_command, + MessageType.thread_starter_message, + ) + user_id = ctx.author.id if self_or_system_message else author_id + data = await ctx.interaction._state.http.start_private_message(user_id) + channel = ctx.interaction._state.add_dm_channel(data) + + target = Message(state=ctx.interaction._state, channel=channel, data=message) + + if self.cog is not None: + await self.callback(self.cog, ctx, target) + else: + await self.callback(ctx, target) + + def copy(self): + """Creates a copy of this command. + + Returns + ------- + :class:`MessageCommand` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) + + def _ensure_assignment_on_copy(self, other): + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + # if self._buckets.valid and not other._buckets.valid: + # other._buckets = self._buckets.copy() + # if self._max_concurrency != other._max_concurrency: + # # _max_concurrency won't be None at this point + # other._max_concurrency = self._max_concurrency.copy() # type: ignore + + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def _update_copy(self, kwargs: dict[str, Any]): + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + +def slash_command(**kwargs): + """Decorator for slash commands that invokes :func:`application_command`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`.SlashCommand`] + A decorator that converts the provided method into a :class:`.SlashCommand`. + """ + return application_command(cls=SlashCommand, **kwargs) + + +def user_command(**kwargs): + """Decorator for user commands that invokes :func:`application_command`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`.UserCommand`] + A decorator that converts the provided method into a :class:`.UserCommand`. + """ + return application_command(cls=UserCommand, **kwargs) + + +def message_command(**kwargs): + """Decorator for message commands that invokes :func:`application_command`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`.MessageCommand`] + A decorator that converts the provided method into a :class:`.MessageCommand`. + """ + return application_command(cls=MessageCommand, **kwargs) + + +def application_command(cls=SlashCommand, **attrs): + """A decorator that transforms a function into an :class:`.ApplicationCommand`. More specifically, + usually one of :class:`.SlashCommand`, :class:`.UserCommand`, or :class:`.MessageCommand`. The exact class + depends on the ``cls`` parameter. + By default, the ``description`` attribute is received automatically from the + docstring of the function and is cleaned up with the use of + ``inspect.cleandoc``. If the docstring is ``bytes``, then it is decoded + into :class:`str` using utf-8 encoding. + The ``name`` attribute also defaults to the function name unchanged. + + .. versionadded:: 2.0 + + Parameters + ---------- + cls: :class:`.ApplicationCommand` + The class to construct with. By default, this is :class:`.SlashCommand`. + You usually do not change this. + attrs + Keyword arguments to pass into the construction of the class denoted + by ``cls``. + + Returns + ------- + Callable[..., :class:`.ApplicationCommand`] + A decorator that converts the provided method into an :class:`.ApplicationCommand`, or subclass of it. + + Raises + ------ + TypeError + If the function is not a coroutine or is already a command. + """ + + def decorator(func: Callable) -> cls: + if isinstance(func, ApplicationCommand): + func = func.callback + elif not callable(func): + raise TypeError( + "func needs to be a callable or a subclass of ApplicationCommand." + ) + return cls(func, **attrs) + + return decorator + + +def command(**kwargs): + """An alias for :meth:`application_command`. + + .. note:: + This decorator is overridden by :func:`ext.commands.command`. + + .. versionadded:: 2.0 + + Returns + ------- + Callable[..., :class:`.ApplicationCommand`] + A decorator that converts the provided method into an :class:`.ApplicationCommand`. + """ + return application_command(**kwargs) + + +docs = "https://discord.com/developers/docs" +valid_locales = [ + "da", + "de", + "en-GB", + "en-US", + "es-ES", + "fr", + "hr", + "it", + "lt", + "hu", + "nl", + "no", + "pl", + "pt-BR", + "ro", + "fi", + "sv-SE", + "vi", + "tr", + "cs", + "el", + "bg", + "ru", + "uk", + "hi", + "th", + "zh-CN", + "ja", + "zh-TW", + "ko", +] + + +# Validation +def validate_chat_input_name(name: Any, locale: str | None = None): + # Must meet the regex ^[-_\w\d\u0901-\u097D\u0E00-\u0E7F]{1,32}$ + if locale is not None and locale not in valid_locales: + raise ValidationError( + f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for list of supported locales." + ) + error = None + if not isinstance(name, str): + error = TypeError( + f'Command names and options must be of type str. Received "{name}"' + ) + elif not re.match(r"^[-_\w\d\u0901-\u097D\u0E00-\u0E7F]{1,32}$", name): + error = ValidationError( + r"Command names and options must follow the regex \"^[-_\w\d\u0901-\u097D\u0E00-\u0E7F]{1,32}$\". " + f"For more information, see {docs}/interactions/application-commands#application-command-object-" + f'application-command-naming. Received "{name}"' + ) + elif ( + name.lower() != name + ): # Can't use islower() as it fails if none of the chars can be lowered. See #512. + error = ValidationError( + f'Command names and options must be lowercase. Received "{name}"' + ) + + if error: + if locale: + error.args = (f"{error.args[0]} in locale {locale}",) + raise error + + +def validate_chat_input_description(description: Any, locale: str | None = None): + if locale is not None and locale not in valid_locales: + raise ValidationError( + f"Locale '{locale}' is not a valid locale, see {docs}/reference#locales for list of supported locales." + ) + error = None + if not isinstance(description, str): + error = TypeError( + f'Command and option description must be of type str. Received "{description}"' + ) + elif not 1 <= len(description) <= 100: + error = ValidationError( + f'Command and option description must be 1-100 characters long. Received "{description}"' + ) + + if error: + if locale: + error.args = (f"{error.args[0]} in locale {locale}",) + raise error diff --git a/discord/commands/options.py b/discord/commands/options.py new file mode 100644 index 0000000..c5c013a --- /dev/null +++ b/discord/commands/options.py @@ -0,0 +1,391 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import inspect +from enum import Enum +from typing import TYPE_CHECKING, Literal, Optional, Type, Union + +from ..abc import GuildChannel, Mentionable +from ..channel import CategoryChannel, StageChannel, TextChannel, Thread, VoiceChannel +from ..enums import ChannelType +from ..enums import Enum as DiscordEnum +from ..enums import SlashCommandOptionType + +if TYPE_CHECKING: + from ..ext.commands import Converter + from ..member import Member + from ..message import Attachment + from ..role import Role + from ..user import User + + InputType = Union[ + Type[str], + Type[bool], + Type[int], + Type[float], + Type[GuildChannel], + Type[Thread], + Type[Member], + Type[User], + Type[Attachment], + Type[Role], + Type[Mentionable], + SlashCommandOptionType, + Converter, + Type[Converter], + Type[Enum], + Type[DiscordEnum], + ] + +__all__ = ( + "ThreadOption", + "Option", + "OptionChoice", + "option", +) + +CHANNEL_TYPE_MAP = { + TextChannel: ChannelType.text, + VoiceChannel: ChannelType.voice, + StageChannel: ChannelType.stage_voice, + CategoryChannel: ChannelType.category, + Thread: ChannelType.public_thread, +} + + +class ThreadOption: + """Represents a class that can be passed as the ``input_type`` for an :class:`Option` class. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_type: Literal["public", "private", "news"] + The thread type to expect for this options input. + """ + + def __init__(self, thread_type: Literal["public", "private", "news"]): + type_map = { + "public": ChannelType.public_thread, + "private": ChannelType.private_thread, + "news": ChannelType.news_thread, + } + self._type = type_map[thread_type] + + +class Option: + """Represents a selectable option for a slash command. + + Attributes + ---------- + input_type: Union[Type[:class:`str`], Type[:class:`bool`], Type[:class:`int`], Type[:class:`float`], Type[:class:`.abc.GuildChannel`], Type[:class:`Thread`], Type[:class:`Member`], Type[:class:`User`], Type[:class:`Attachment`], Type[:class:`Role`], Type[:class:`.abc.Mentionable`], :class:`SlashCommandOptionType`, Type[:class:`.ext.commands.Converter`], Type[:class:`enums.Enum`], Type[:class:`Enum`]] + The type of input that is expected for this option. This can be a :class:`SlashCommandOptionType`, + an associated class, a channel type, a :class:`Converter`, a converter class or an :class:`enum.Enum`. + name: :class:`str` + The name of this option visible in the UI. + Inherits from the variable name if not provided as a parameter. + description: Optional[:class:`str`] + The description of this option. + Must be 100 characters or fewer. + choices: Optional[List[Union[:class:`Any`, :class:`OptionChoice`]]] + The list of available choices for this option. + Can be a list of values or :class:`OptionChoice` objects (which represent a name:value pair). + If provided, the input from the user must match one of the choices in the list. + required: Optional[:class:`bool`] + Whether this option is required. + default: Optional[:class:`Any`] + The default value for this option. If provided, ``required`` will be considered ``False``. + min_value: Optional[:class:`int`] + The minimum value that can be entered. + Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. + max_value: Optional[:class:`int`] + The maximum value that can be entered. + Only applies to Options with an :attr:`.input_type` of :class:`int` or :class:`float`. + min_length: Optional[:class:`int`] + The minimum length of the string that can be entered. Must be between 0 and 6000 (inclusive). + Only applies to Options with an :attr:`input_type` of :class:`str`. + max_length: Optional[:class:`int`] + The maximum length of the string that can be entered. Must be between 1 and 6000 (inclusive). + Only applies to Options with an :attr:`input_type` of :class:`str`. + autocomplete: Optional[:class:`Any`] + The autocomplete handler for the option. Accepts an iterable of :class:`str`, a callable (sync or async) + that takes a single argument of :class:`AutocompleteContext`, or a coroutine. + Must resolve to an iterable of :class:`str`. + + .. note:: + + Does not validate the input value against the autocomplete results. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this option. The values of this should be ``"locale": "name"``. + See `here `_ for a list of valid locales. + description_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The description localizations for this option. The values of this should be ``"locale": "description"``. + See `here `_ for a list of valid locales. + + Examples + -------- + Basic usage: :: + + @bot.slash_command(guild_ids=[...]) + async def hello( + ctx: discord.ApplicationContext, + name: Option(str, "Enter your name"), + age: Option(int, "Enter your age", min_value=1, max_value=99, default=18) + # passing the default value makes an argument optional + # you also can create optional argument using: + # age: Option(int, "Enter your age") = 18 + ): + await ctx.respond(f"Hello! Your name is {name} and you are {age} years old.") + + .. versionadded:: 2.0 + """ + + input_type: SlashCommandOptionType + converter: Converter | type[Converter] | None = None + + def __init__( + self, input_type: InputType = str, /, description: str | None = None, **kwargs + ) -> None: + self.name: str | None = kwargs.pop("name", None) + if self.name is not None: + self.name = str(self.name) + self._parameter_name = self.name # default + self._raw_type: InputType | tuple = input_type + + enum_choices = [] + input_type_is_class = isinstance(input_type, type) + if input_type_is_class and issubclass(input_type, (Enum, DiscordEnum)): + description = inspect.getdoc(input_type) + enum_choices = [OptionChoice(e.name, e.value) for e in input_type] + value_class = enum_choices[0].value.__class__ + if all(isinstance(elem.value, value_class) for elem in enum_choices): + input_type = SlashCommandOptionType.from_datatype( + enum_choices[0].value.__class__ + ) + else: + enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] + input_type = SlashCommandOptionType.string + + self.description = description or "No description provided" + self.channel_types: list[ChannelType] = kwargs.pop("channel_types", []) + + if isinstance(input_type, SlashCommandOptionType): + self.input_type = input_type + else: + from ..ext.commands import Converter + + if ( + isinstance(input_type, Converter) + or input_type_is_class + and issubclass(input_type, Converter) + ): + self.converter = input_type + self._raw_type = str + self.input_type = SlashCommandOptionType.string + else: + try: + self.input_type = SlashCommandOptionType.from_datatype(input_type) + except TypeError as exc: + from ..ext.commands.converter import CONVERTER_MAPPING + + if input_type not in CONVERTER_MAPPING: + raise exc + self.converter = CONVERTER_MAPPING[input_type] + self._raw_type = str + self.input_type = SlashCommandOptionType.string + else: + if self.input_type == SlashCommandOptionType.channel: + if not isinstance(self._raw_type, tuple): + if hasattr(input_type, "__args__"): + self._raw_type = input_type.__args__ # type: ignore # Union.__args__ + else: + self._raw_type = (input_type,) + self.channel_types = [ + CHANNEL_TYPE_MAP[t] + for t in self._raw_type + if t is not GuildChannel + ] + self.required: bool = ( + kwargs.pop("required", True) if "default" not in kwargs else False + ) + self.default = kwargs.pop("default", None) + self.choices: list[OptionChoice] = enum_choices or [ + o if isinstance(o, OptionChoice) else OptionChoice(o) + for o in kwargs.pop("choices", list()) + ] + + if self.input_type == SlashCommandOptionType.integer: + minmax_types = (int, type(None)) + minmax_typehint = Optional[int] + elif self.input_type == SlashCommandOptionType.number: + minmax_types = (int, float, type(None)) + minmax_typehint = Optional[Union[int, float]] + else: + minmax_types = (type(None),) + minmax_typehint = type(None) + + if self.input_type == SlashCommandOptionType.string: + minmax_length_types = (int, type(None)) + minmax_length_typehint = Optional[int] + else: + minmax_length_types = (type(None),) + minmax_length_typehint = type(None) + + self.min_value: int | float | None = kwargs.pop("min_value", None) + self.max_value: int | float | None = kwargs.pop("max_value", None) + self.min_length: int | None = kwargs.pop("min_length", None) + self.max_length: int | None = kwargs.pop("max_length", None) + + if ( + self.input_type != SlashCommandOptionType.integer + and self.input_type != SlashCommandOptionType.number + and (self.min_value or self.max_value) + ): + raise AttributeError( + "Option does not take min_value or max_value if not of type " + "SlashCommandOptionType.integer or SlashCommandOptionType.number" + ) + if self.input_type != SlashCommandOptionType.string and ( + self.min_length or self.max_length + ): + raise AttributeError( + "Option does not take min_length or max_length if not of type str" + ) + + if self.min_value is not None and not isinstance(self.min_value, minmax_types): + raise TypeError( + f'Expected {minmax_typehint} for min_value, got "{type(self.min_value).__name__}"' + ) + if self.max_value is not None and not isinstance(self.max_value, minmax_types): + raise TypeError( + f'Expected {minmax_typehint} for max_value, got "{type(self.max_value).__name__}"' + ) + + if self.min_length is not None: + if not isinstance(self.min_length, minmax_length_types): + raise TypeError( + f"Expected {minmax_length_typehint} for min_length," + f' got "{type(self.min_length).__name__}"' + ) + if self.min_length < 0 or self.min_length > 6000: + raise AttributeError( + "min_length must be between 0 and 6000 (inclusive)" + ) + if self.max_length is not None: + if not isinstance(self.max_length, minmax_length_types): + raise TypeError( + f"Expected {minmax_length_typehint} for max_length," + f' got "{type(self.max_length).__name__}"' + ) + if self.max_length < 1 or self.max_length > 6000: + raise AttributeError("max_length must between 1 and 6000 (inclusive)") + + self.autocomplete = kwargs.pop("autocomplete", None) + + self.name_localizations = kwargs.pop("name_localizations", None) + self.description_localizations = kwargs.pop("description_localizations", None) + + def to_dict(self) -> dict: + as_dict = { + "name": self.name, + "description": self.description, + "type": self.input_type.value, + "required": self.required, + "choices": [c.to_dict() for c in self.choices], + "autocomplete": bool(self.autocomplete), + } + if self.name_localizations is not None: + as_dict["name_localizations"] = self.name_localizations + if self.description_localizations is not None: + as_dict["description_localizations"] = self.description_localizations + if self.channel_types: + as_dict["channel_types"] = [t.value for t in self.channel_types] + if self.min_value is not None: + as_dict["min_value"] = self.min_value + if self.max_value is not None: + as_dict["max_value"] = self.max_value + if self.min_length is not None: + as_dict["min_length"] = self.min_length + if self.max_length is not None: + as_dict["max_length"] = self.max_length + + return as_dict + + def __repr__(self): + return f"" + + +class OptionChoice: + """ + Represents a name:value pairing for a selected :class:`.Option`. + + .. versionadded:: 2.0 + + Attributes + ---------- + name: :class:`str` + The name of the choice. Shown in the UI when selecting an option. + value: Optional[Union[:class:`str`, :class:`int`, :class:`float`]] + The value of the choice. If not provided, will use the value of ``name``. + name_localizations: Optional[Dict[:class:`str`, :class:`str`]] + The name localizations for this choice. The values of this should be ``"locale": "name"``. + See `here `_ for a list of valid locales. + """ + + def __init__( + self, + name: str, + value: str | int | float | None = None, + name_localizations: dict[str, str] | None = None, + ): + self.name = str(name) + self.value = value if value is not None else name + self.name_localizations = name_localizations + + def to_dict(self) -> dict[str, str | int | float]: + as_dict = {"name": self.name, "value": self.value} + if self.name_localizations is not None: + as_dict["name_localizations"] = self.name_localizations + + return as_dict + + +def option(name, type=None, **kwargs): + """A decorator that can be used instead of typehinting :class:`Option`. + + .. versionadded:: 2.0 + """ + + def decorator(func): + nonlocal type + type = type or func.__annotations__.get(name, str) + if parameter := kwargs.get("parameter_name"): + func.__annotations__[parameter] = Option(type, name=name, **kwargs) + else: + func.__annotations__[name] = Option(type, **kwargs) + return func + + return decorator diff --git a/discord/commands/permissions.py b/discord/commands/permissions.py new file mode 100644 index 0000000..b6c1cc0 --- /dev/null +++ b/discord/commands/permissions.py @@ -0,0 +1,110 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from typing import Callable + +from ..permissions import Permissions +from .core import ApplicationCommand + +__all__ = ( + "default_permissions", + "guild_only", +) + + +def default_permissions(**perms: bool) -> Callable: + """A decorator that limits the usage of a slash command to members with certain + permissions. + + The permissions passed in must be exactly like the properties shown under + :class:`.discord.Permissions`. + + .. note:: + These permissions can be updated by server administrators per-guild. As such, these are only "defaults", as the + name suggests. If you want to make sure that a user **always** has the specified permissions regardless, you + should use an internal check such as :func:`~.ext.commands.has_permissions`. + + Parameters + ---------- + **perms: Dict[:class:`str`, :class:`bool`] + An argument list of permissions to check for. + + Example + ------- + + .. code-block:: python3 + + from discord import default_permissions + + @bot.slash_command() + @default_permissions(manage_messages=True) + async def test(ctx): + await ctx.respond('You can manage messages.') + """ + + invalid = set(perms) - set(Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") + + def inner(command: Callable): + if isinstance(command, ApplicationCommand): + if command.parent is not None: + raise RuntimeError( + "Permission restrictions can only be set on top-level commands" + ) + command.default_member_permissions = Permissions(**perms) + else: + command.__default_member_permissions__ = Permissions(**perms) + return command + + return inner + + +def guild_only() -> Callable: + """A decorator that limits the usage of a slash command to guild contexts. + The command won't be able to be used in private message channels. + + Example + ------- + + .. code-block:: python3 + + from discord import guild_only + + @bot.slash_command() + @guild_only() + async def test(ctx): + await ctx.respond("You're in a guild.") + """ + + def inner(command: Callable): + if isinstance(command, ApplicationCommand): + command.guild_only = True + else: + command.__guild_only__ = True + + return command + + return inner diff --git a/discord/components.py b/discord/components.py new file mode 100644 index 0000000..5934b51 --- /dev/null +++ b/discord/components.py @@ -0,0 +1,479 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar + +from .enums import ButtonStyle, ComponentType, InputTextStyle, try_enum +from .partial_emoji import PartialEmoji, _EmojiTag +from .utils import MISSING, get_slots + +if TYPE_CHECKING: + from .emoji import Emoji + from .types.components import ActionRow as ActionRowPayload + from .types.components import ButtonComponent as ButtonComponentPayload + from .types.components import Component as ComponentPayload + from .types.components import InputText as InputTextComponentPayload + from .types.components import SelectMenu as SelectMenuPayload + from .types.components import SelectOption as SelectOptionPayload + + +__all__ = ( + "Component", + "ActionRow", + "Button", + "SelectMenu", + "SelectOption", + "InputText", +) + +C = TypeVar("C", bound="Component") + + +class Component: + """Represents a Discord Bot UI Kit Component. + + Currently, the only components supported by Discord are: + + - :class:`ActionRow` + - :class:`Button` + - :class:`SelectMenu` + + This class is abstract and cannot be instantiated. + + .. versionadded:: 2.0 + + Attributes + ---------- + type: :class:`ComponentType` + The type of component. + """ + + __slots__: tuple[str, ...] = ("type",) + + __repr_info__: ClassVar[tuple[str, ...]] + type: ComponentType + + def __repr__(self) -> str: + attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__) + return f"<{self.__class__.__name__} {attrs}>" + + @classmethod + def _raw_construct(cls: type[C], **kwargs) -> C: + self: C = cls.__new__(cls) + for slot in get_slots(cls): + try: + value = kwargs[slot] + except KeyError: + pass + else: + setattr(self, slot, value) + return self + + def to_dict(self) -> dict[str, Any]: + raise NotImplementedError + + +class ActionRow(Component): + """Represents a Discord Bot UI Kit Action Row. + + This is a component that holds up to 5 children components in a row. + + This inherits from :class:`Component`. + + .. versionadded:: 2.0 + + Attributes + ---------- + type: :class:`ComponentType` + The type of component. + children: List[:class:`Component`] + The children components that this holds, if any. + """ + + __slots__: tuple[str, ...] = ("children",) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + + def __init__(self, data: ComponentPayload): + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.children: list[Component] = [ + _component_factory(d) for d in data.get("components", []) + ] + + def to_dict(self) -> ActionRowPayload: + return { + "type": int(self.type), + "components": [child.to_dict() for child in self.children], + } # type: ignore + + +class InputText(Component): + """Represents an Input Text field from the Discord Bot UI Kit. + This inherits from :class:`Component`. + + Attributes + ---------- + style: :class:`.InputTextStyle` + The style of the input text field. + custom_id: Optional[:class:`str`] + The ID of the input text field that gets received during an interaction. + label: :class:`str` + The label for the input text field. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_length: Optional[:class:`int`] + The minimum number of characters that must be entered + Defaults to 0 + max_length: Optional[:class:`int`] + The maximum number of characters that can be entered + required: Optional[:class:`bool`] + Whether the input text field is required or not. Defaults to `True`. + value: Optional[:class:`str`] + The value that has been entered in the input text field. + """ + + __slots__: tuple[str, ...] = ( + "type", + "style", + "custom_id", + "label", + "placeholder", + "min_length", + "max_length", + "required", + "value", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + + def __init__(self, data: InputTextComponentPayload): + self.type = ComponentType.input_text + self.style: InputTextStyle = try_enum(InputTextStyle, data["style"]) + self.custom_id = data["custom_id"] + self.label: str = data.get("label", None) + self.placeholder: str | None = data.get("placeholder", None) + self.min_length: int | None = data.get("min_length", None) + self.max_length: int | None = data.get("max_length", None) + self.required: bool = data.get("required", True) + self.value: str | None = data.get("value", None) + + def to_dict(self) -> InputTextComponentPayload: + payload = { + "type": 4, + "style": self.style.value, + "label": self.label, + } + if self.custom_id: + payload["custom_id"] = self.custom_id + + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.min_length: + payload["min_length"] = self.min_length + + if self.max_length: + payload["max_length"] = self.max_length + + if not self.required: + payload["required"] = self.required + + if self.value: + payload["value"] = self.value + + return payload # type: ignore + + +class Button(Component): + """Represents a button from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type to create a button is :class:`discord.ui.Button` + not this one. + + .. versionadded:: 2.0 + + Attributes + ---------- + style: :class:`.ButtonStyle` + The style of the button. + custom_id: Optional[:class:`str`] + The ID of the button that gets received during an interaction. + If this button is for a URL, it does not have a custom ID. + url: Optional[:class:`str`] + The URL this button sends you to. + disabled: :class:`bool` + Whether the button is disabled or not. + label: Optional[:class:`str`] + The label of the button, if any. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the button, if available. + """ + + __slots__: tuple[str, ...] = ( + "style", + "custom_id", + "url", + "disabled", + "label", + "emoji", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + + def __init__(self, data: ButtonComponentPayload): + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) + self.custom_id: str | None = data.get("custom_id") + self.url: str | None = data.get("url") + self.disabled: bool = data.get("disabled", False) + self.label: str | None = data.get("label") + self.emoji: PartialEmoji | None + try: + self.emoji = PartialEmoji.from_dict(data["emoji"]) + except KeyError: + self.emoji = None + + def to_dict(self) -> ButtonComponentPayload: + payload = { + "type": 2, + "style": int(self.style), + "label": self.label, + "disabled": self.disabled, + } + if self.custom_id: + payload["custom_id"] = self.custom_id + + if self.url: + payload["url"] = self.url + + if self.emoji: + payload["emoji"] = self.emoji.to_dict() + + return payload # type: ignore + + +class SelectMenu(Component): + """Represents a select menu from the Discord Bot UI Kit. + + A select menu is functionally the same as a dropdown, however + on mobile it renders a bit differently. + + .. note:: + + The user constructible and usable type to create a select menu is + :class:`discord.ui.Select` not this one. + + .. versionadded:: 2.0 + + Attributes + ---------- + custom_id: Optional[:class:`str`] + The ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of items that must be chosen for this select menu. + Defaults to 1 and must be between 0 and 25. + max_values: :class:`int` + The maximum number of items that must be chosen for this select menu. + Defaults to 1 and must be between 1 and 25. + options: List[:class:`SelectOption`] + A list of options that can be selected in this menu. + disabled: :class:`bool` + Whether the select is disabled or not. + """ + + __slots__: tuple[str, ...] = ( + "custom_id", + "placeholder", + "min_values", + "max_values", + "options", + "disabled", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + + def __init__(self, data: SelectMenuPayload): + self.type = ComponentType.select + self.custom_id: str = data["custom_id"] + self.placeholder: str | None = data.get("placeholder") + self.min_values: int = data.get("min_values", 1) + self.max_values: int = data.get("max_values", 1) + self.options: list[SelectOption] = [ + SelectOption.from_dict(option) for option in data.get("options", []) + ] + self.disabled: bool = data.get("disabled", False) + + def to_dict(self) -> SelectMenuPayload: + payload: SelectMenuPayload = { + "type": self.type.value, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + "options": [op.to_dict() for op in self.options], + "disabled": self.disabled, + } + + if self.placeholder: + payload["placeholder"] = self.placeholder + + return payload + + +class SelectOption: + """Represents a select menu's option. + + These can be created by users. + + .. versionadded:: 2.0 + + Attributes + ---------- + label: :class:`str` + The label of the option. This is displayed to users. + Can only be up to 100 characters. + value: :class:`str` + The value of the option. This is not displayed to users. + If not provided when constructed then it defaults to the + label. Can only be up to 100 characters. + description: Optional[:class:`str`] + An additional description of the option, if any. + Can only be up to 100 characters. + default: :class:`bool` + Whether this option is selected by default. + """ + + __slots__: tuple[str, ...] = ( + "label", + "value", + "description", + "_emoji", + "default", + ) + + def __init__( + self, + *, + label: str, + value: str = MISSING, + description: str | None = None, + emoji: str | Emoji | PartialEmoji | None = None, + default: bool = False, + ) -> None: + if len(label) > 100: + raise ValueError("label must be 100 characters or fewer") + + if value is not MISSING and len(value) > 100: + raise ValueError("value must be 100 characters or fewer") + + if description is not None and len(description) > 100: + raise ValueError("description must be 100 characters or fewer") + + self.label = label + self.value = label if value is MISSING else value + self.description = description + self.emoji = emoji + self.default = default + + def __repr__(self) -> str: + return ( + f"" + ) + + def __str__(self) -> str: + base = f"{self.emoji} {self.label}" if self.emoji else self.label + if self.description: + return f"{base}\n{self.description}" + return base + + @property + def emoji(self) -> str | Emoji | PartialEmoji | None: + """Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]: The emoji of the option, if available.""" + return self._emoji + + @emoji.setter + def emoji(self, value) -> None: + if value is not None: + if isinstance(value, str): + value = PartialEmoji.from_str(value) + elif isinstance(value, _EmojiTag): + value = value._to_partial() + else: + raise TypeError( + f"expected emoji to be str, Emoji, or PartialEmoji not {value.__class__}" + ) + + self._emoji = value + + @classmethod + def from_dict(cls, data: SelectOptionPayload) -> SelectOption: + try: + emoji = PartialEmoji.from_dict(data["emoji"]) + except KeyError: + emoji = None + + return cls( + label=data["label"], + value=data["value"], + description=data.get("description"), + emoji=emoji, + default=data.get("default", False), + ) + + def to_dict(self) -> SelectOptionPayload: + payload: SelectOptionPayload = { + "label": self.label, + "value": self.value, + "default": self.default, + } + + if self.emoji: + payload["emoji"] = self.emoji.to_dict() # type: ignore + + if self.description: + payload["description"] = self.description + + return payload + + +def _component_factory(data: ComponentPayload) -> Component: + component_type = data["type"] + if component_type == 1: + return ActionRow(data) + elif component_type == 2: + return Button(data) # type: ignore + elif component_type == 3: + return SelectMenu(data) # type: ignore + else: + as_enum = try_enum(ComponentType, component_type) + return Component._raw_construct(type=as_enum) diff --git a/discord/context_managers.py b/discord/context_managers.py new file mode 100644 index 0000000..c9d930b --- /dev/null +++ b/discord/context_managers.py @@ -0,0 +1,90 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from types import TracebackType + + from .abc import Messageable + + TypingT = TypeVar("TypingT", bound="Typing") + +__all__ = ("Typing",) + + +def _typing_done_callback(fut: asyncio.Future) -> None: + # just retrieve any exception and call it a day + try: + fut.exception() + except (asyncio.CancelledError, Exception): + pass + + +class Typing: + def __init__(self, messageable: Messageable) -> None: + self.loop: asyncio.AbstractEventLoop = messageable._state.loop + self.messageable: Messageable = messageable + + async def do_typing(self) -> None: + try: + channel = self._channel + except AttributeError: + channel = await self.messageable._get_channel() + + typing = channel._state.http.send_typing + + while True: + await typing(channel.id) + await asyncio.sleep(5) + + def __enter__(self: TypingT) -> TypingT: + self.task: asyncio.Task = self.loop.create_task(self.do_typing()) + self.task.add_done_callback(_typing_done_callback) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.task.cancel() + + async def __aenter__(self: TypingT) -> TypingT: + self._channel = channel = await self.messageable._get_channel() + await channel._state.http.send_typing(channel.id) + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.task.cancel() diff --git a/discord/embeds.py b/discord/embeds.py new file mode 100644 index 0000000..8b7c9f3 --- /dev/null +++ b/discord/embeds.py @@ -0,0 +1,890 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Final, Mapping, Protocol, TypeVar, Union + +from . import utils +from .colour import Colour + +__all__ = ( + "Embed", + "EmbedField", +) + + +class _EmptyEmbed: + def __bool__(self) -> bool: + return False + + def __repr__(self) -> str: + return "Embed.Empty" + + def __len__(self) -> int: + return 0 + + +EmptyEmbed: Final = _EmptyEmbed() + + +class EmbedProxy: + def __init__(self, layer: dict[str, Any]): + self.__dict__.update(layer) + + def __len__(self) -> int: + return len(self.__dict__) + + def __repr__(self) -> str: + inner = ", ".join( + (f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")) + ) + return f"EmbedProxy({inner})" + + def __getattr__(self, attr: str) -> _EmptyEmbed: + return EmptyEmbed + + +E = TypeVar("E", bound="Embed") + +if TYPE_CHECKING: + from discord.types.embed import Embed as EmbedData + from discord.types.embed import EmbedType + + T = TypeVar("T") + MaybeEmpty = Union[T, _EmptyEmbed] + + class _EmbedFooterProxy(Protocol): + text: MaybeEmpty[str] + icon_url: MaybeEmpty[str] + + class _EmbedMediaProxy(Protocol): + url: MaybeEmpty[str] + proxy_url: MaybeEmpty[str] + height: MaybeEmpty[int] + width: MaybeEmpty[int] + + class _EmbedVideoProxy(Protocol): + url: MaybeEmpty[str] + height: MaybeEmpty[int] + width: MaybeEmpty[int] + + class _EmbedProviderProxy(Protocol): + name: MaybeEmpty[str] + url: MaybeEmpty[str] + + class _EmbedAuthorProxy(Protocol): + name: MaybeEmpty[str] + url: MaybeEmpty[str] + icon_url: MaybeEmpty[str] + proxy_icon_url: MaybeEmpty[str] + + +class EmbedField: + """Represents a field on the :class:`Embed` object. + + .. versionadded:: 2.0 + + Attributes + ---------- + name: :class:`str` + The name of the field. + value: :class:`str` + The value of the field. + inline: :class:`bool` + Whether the field should be displayed inline. + """ + + def __init__(self, name: str, value: str, inline: bool | None = False): + self.name = name + self.value = value + self.inline = inline + + @classmethod + def from_dict(cls: type[E], data: Mapping[str, Any]) -> E: + """Converts a :class:`dict` to a :class:`EmbedField` provided it is in the + format that Discord expects it to be in. + + You can find out about this format in the `official Discord documentation`__. + + .. _DiscordDocsEF: https://discord.com/developers/docs/resources/channel#embed-object-embed-field-structure + + __ DiscordDocsEF_ + + Parameters + ---------- + data: :class:`dict` + The dictionary to convert into an EmbedField object. + """ + self: E = cls.__new__(cls) + + self.name = data["name"] + self.value = data["value"] + self.inline = data.get("inline", False) + + return self + + def to_dict(self) -> dict[str, str | bool]: + """Converts this EmbedField object into a dict. + + Returns + ------- + Dict[:class:`str`, Union[:class:`str`, :class:`bool`]] + A dictionary of :class:`str` embed field keys bound to the respective value. + """ + return { + "name": self.name, + "value": self.value, + "inline": self.inline, + } + + +class Embed: + """Represents a Discord embed. + + .. container:: operations + + .. describe:: len(x) + + Returns the total size of the embed. + Useful for checking if it's within the 6000 character limit. + + .. describe:: bool(b) + + Returns whether the embed has any data set. + + .. versionadded:: 2.0 + + Certain properties return an ``EmbedProxy``, a type + that acts similar to a regular :class:`dict` except using dotted access, + e.g. ``embed.author.icon_url``. If the attribute + is invalid or empty, then a special sentinel value is returned, + :attr:`Embed.Empty`. + + For ease of use, all parameters that expect a :class:`str` are implicitly + cast to :class:`str` for you. + + Attributes + ---------- + title: :class:`str` + The title of the embed. + This can be set during initialisation. + Must be 256 characters or fewer. + type: :class:`str` + The type of embed. Usually "rich". + This can be set during initialisation. + Possible strings for embed types can be found on discord's + `api docs `_ + description: :class:`str` + The description of the embed. + This can be set during initialisation. + Must be 4096 characters or fewer. + url: :class:`str` + The URL of the embed. + This can be set during initialisation. + timestamp: :class:`datetime.datetime` + The timestamp of the embed content. This is an aware datetime. + If a naive datetime is passed, it is converted to an aware + datetime with the local timezone. + colour: Union[:class:`Colour`, :class:`int`] + The colour code of the embed. Aliased to ``color`` as well. + This can be set during initialisation. + Empty + A special sentinel value used by ``EmbedProxy`` and this class + to denote that the value or attribute is empty. + """ + + __slots__ = ( + "title", + "url", + "type", + "_timestamp", + "_colour", + "_footer", + "_image", + "_thumbnail", + "_video", + "_provider", + "_author", + "_fields", + "description", + ) + + Empty: Final = EmptyEmbed + + def __init__( + self, + *, + colour: int | Colour | _EmptyEmbed = EmptyEmbed, + color: int | Colour | _EmptyEmbed = EmptyEmbed, + title: MaybeEmpty[Any] = EmptyEmbed, + type: EmbedType = "rich", + url: MaybeEmpty[Any] = EmptyEmbed, + description: MaybeEmpty[Any] = EmptyEmbed, + timestamp: datetime.datetime = None, + fields: list[EmbedField] | None = None, + ): + + self.colour = colour if colour is not EmptyEmbed else color + self.title = title + self.type = type + self.url = url + self.description = description + + if self.title is not EmptyEmbed and self.title is not None: + self.title = str(self.title) + + if self.description is not EmptyEmbed and self.description is not None: + self.description = str(self.description) + + if self.url is not EmptyEmbed and self.url is not None: + self.url = str(self.url) + + if timestamp: + self.timestamp = timestamp + self._fields: list[EmbedField] = fields or [] + + @classmethod + def from_dict(cls: type[E], data: Mapping[str, Any]) -> E: + """Converts a :class:`dict` to a :class:`Embed` provided it is in the + format that Discord expects it to be in. + + You can find out about this format in the `official Discord documentation`__. + + .. _DiscordDocs: https://discord.com/developers/docs/resources/channel#embed-object + + __ DiscordDocs_ + + Parameters + ---------- + data: :class:`dict` + The dictionary to convert into an embed. + + Returns + ------- + :class:`Embed` + The converted embed object. + """ + # we are bypassing __init__ here since it doesn't apply here + self: E = cls.__new__(cls) + + # fill in the basic fields + + self.title = data.get("title", EmptyEmbed) + self.type = data.get("type", EmptyEmbed) + self.description = data.get("description", EmptyEmbed) + self.url = data.get("url", EmptyEmbed) + + if self.title is not EmptyEmbed: + self.title = str(self.title) + + if self.description is not EmptyEmbed: + self.description = str(self.description) + + if self.url is not EmptyEmbed: + self.url = str(self.url) + + # try to fill in the more rich fields + + try: + self._colour = Colour(value=data["color"]) + except KeyError: + pass + + try: + self._timestamp = utils.parse_time(data["timestamp"]) + except KeyError: + pass + + for attr in ( + "thumbnail", + "video", + "provider", + "author", + "fields", + "image", + "footer", + ): + if attr == "fields": + value = data.get(attr, []) + self._fields = [EmbedField.from_dict(d) for d in value] if value else [] + else: + try: + value = data[attr] + except KeyError: + continue + else: + setattr(self, f"_{attr}", value) + + return self + + def copy(self: E) -> E: + """Creates a shallow copy of the :class:`Embed` object. + + Returns + ------- + :class:`Embed` + The copied embed object. + """ + return self.__class__.from_dict(self.to_dict()) + + def __len__(self) -> int: + total = len(self.title) + len(self.description) + for field in getattr(self, "_fields", []): + total += len(field.name) + len(field.value) + + try: + footer_text = self._footer["text"] + except (AttributeError, KeyError): + pass + else: + total += len(footer_text) + + try: + author = self._author + except AttributeError: + pass + else: + total += len(author["name"]) + + return total + + def __bool__(self) -> bool: + return any( + ( + self.title, + self.url, + self.description, + self.colour, + self.fields, + self.timestamp, + self.author, + self.thumbnail, + self.footer, + self.image, + self.provider, + self.video, + ) + ) + + @property + def colour(self) -> MaybeEmpty[Colour]: + return getattr(self, "_colour", EmptyEmbed) + + @colour.setter + def colour(self, value: int | Colour | _EmptyEmbed): # type: ignore + if isinstance(value, (Colour, _EmptyEmbed)): + self._colour = value + elif isinstance(value, int): + self._colour = Colour(value=value) + else: + raise TypeError( + f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead." + ) + + color = colour + + @property + def timestamp(self) -> MaybeEmpty[datetime.datetime]: + return getattr(self, "_timestamp", EmptyEmbed) + + @timestamp.setter + def timestamp(self, value: MaybeEmpty[datetime.datetime]): + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.astimezone() + self._timestamp = value + elif isinstance(value, _EmptyEmbed): + self._timestamp = value + else: + raise TypeError( + f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead" + ) + + @property + def footer(self) -> _EmbedFooterProxy: + """Returns an ``EmbedProxy`` denoting the footer contents. + + See :meth:`set_footer` for possible values you can access. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_footer", {})) # type: ignore + + def set_footer( + self: E, + *, + text: MaybeEmpty[Any] = EmptyEmbed, + icon_url: MaybeEmpty[Any] = EmptyEmbed, + ) -> E: + """Sets the footer for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ---------- + text: :class:`str` + The footer text. + Must be 2048 characters or fewer. + icon_url: :class:`str` + The URL of the footer icon. Only HTTP(S) is supported. + """ + + self._footer = {} + if text is not EmptyEmbed: + self._footer["text"] = str(text) + + if icon_url is not EmptyEmbed: + self._footer["icon_url"] = str(icon_url) + + return self + + def remove_footer(self: E) -> E: + """Clears embed's footer information. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 2.0 + """ + try: + del self._footer + except AttributeError: + pass + + return self + + @property + def image(self) -> _EmbedMediaProxy: + """Returns an ``EmbedProxy`` denoting the image contents. + + Possible attributes you can access are: + + - ``url`` + - ``proxy_url`` + - ``width`` + - ``height`` + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_image", {})) # type: ignore + + def set_image(self: E, *, url: MaybeEmpty[Any]) -> E: + """Sets the image for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionchanged:: 1.4 + Passing :attr:`Empty` removes the image. + + Parameters + ---------- + url: :class:`str` + The source URL for the image. Only HTTP(S) is supported. + """ + + if url is EmptyEmbed: + try: + del self._image + except AttributeError: + pass + else: + self._image = { + "url": str(url), + } + + return self + + def remove_image(self: E) -> E: + """Removes the embed's image. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 2.0 + """ + try: + del self._image + except AttributeError: + pass + + return self + + @property + def thumbnail(self) -> _EmbedMediaProxy: + """Returns an ``EmbedProxy`` denoting the thumbnail contents. + + Possible attributes you can access are: + + - ``url`` + - ``proxy_url`` + - ``width`` + - ``height`` + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore + + def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E: + """Sets the thumbnail for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionchanged:: 1.4 + Passing :attr:`Empty` removes the thumbnail. + + Parameters + ---------- + url: :class:`str` + The source URL for the thumbnail. Only HTTP(S) is supported. + """ + + if url is EmptyEmbed: + try: + del self._thumbnail + except AttributeError: + pass + else: + self._thumbnail = { + "url": str(url), + } + + return self + + def remove_thumbnail(self: E) -> E: + """Removes the embed's thumbnail. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 2.0 + """ + try: + del self._thumbnail + except AttributeError: + pass + + return self + + @property + def video(self) -> _EmbedVideoProxy: + """Returns an ``EmbedProxy`` denoting the video contents. + + Possible attributes include: + + - ``url`` for the video URL. + - ``height`` for the video height. + - ``width`` for the video width. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_video", {})) # type: ignore + + @property + def provider(self) -> _EmbedProviderProxy: + """Returns an ``EmbedProxy`` denoting the provider contents. + + The only attributes that might be accessed are ``name`` and ``url``. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_provider", {})) # type: ignore + + @property + def author(self) -> _EmbedAuthorProxy: + """Returns an ``EmbedProxy`` denoting the author contents. + + See :meth:`set_author` for possible values you can access. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_author", {})) # type: ignore + + def set_author( + self: E, + *, + name: Any, + url: MaybeEmpty[Any] = EmptyEmbed, + icon_url: MaybeEmpty[Any] = EmptyEmbed, + ) -> E: + """Sets the author for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ---------- + name: :class:`str` + The name of the author. + Must be 256 characters or fewer. + url: :class:`str` + The URL for the author. + icon_url: :class:`str` + The URL of the author icon. Only HTTP(S) is supported. + """ + + self._author = { + "name": str(name), + } + + if url is not EmptyEmbed: + self._author["url"] = str(url) + + if icon_url is not EmptyEmbed: + self._author["icon_url"] = str(icon_url) + + return self + + def remove_author(self: E) -> E: + """Clears embed's author information. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 1.4 + """ + try: + del self._author + except AttributeError: + pass + + return self + + @property + def fields(self) -> list[EmbedField]: + """Returns a :class:`list` of :class:`EmbedField` objects denoting the field contents. + + See :meth:`add_field` for possible values you can access. + + If the attribute has no value then ``None`` is returned. + """ + return self._fields + + @fields.setter + def fields(self, value: list[EmbedField]) -> None: + """Sets the fields for the embed. This overwrites any existing fields. + + Parameters + ---------- + value: List[:class:`EmbedField`] + The list of :class:`EmbedField` objects to include in the embed. + """ + if not all(isinstance(x, EmbedField) for x in value): + raise TypeError("Expected a list of EmbedField objects.") + + self._fields = value + + def append_field(self, field: EmbedField) -> None: + """Appends an :class:`EmbedField` object to the embed. + + .. versionadded:: 2.0 + + Parameters + ---------- + field: :class:`EmbedField` + The field to add. + """ + if not isinstance(field, EmbedField): + raise TypeError("Expected an EmbedField object.") + + self._fields.append(field) + + def add_field(self: E, *, name: str, value: str, inline: bool = True) -> E: + """Adds a field to the embed object. + + This function returns the class instance to allow for fluent-style + chaining. There must be 25 fields or fewer. + + Parameters + ---------- + name: :class:`str` + The name of the field. + Must be 256 characters or fewer. + value: :class:`str` + The value of the field. + Must be 1024 characters or fewer. + inline: :class:`bool` + Whether the field should be displayed inline. + """ + self._fields.append(EmbedField(name=str(name), value=str(value), inline=inline)) + + return self + + def insert_field_at( + self: E, index: int, *, name: Any, value: Any, inline: bool = True + ) -> E: + """Inserts a field before a specified index to the embed. + + This function returns the class instance to allow for fluent-style + chaining. There must be 25 fields or fewer. + + .. versionadded:: 1.2 + + Parameters + ---------- + index: :class:`int` + The index of where to insert the field. + name: :class:`str` + The name of the field. + Must be 256 characters or fewer. + value: :class:`str` + The value of the field. + Must be 1024 characters or fewer. + inline: :class:`bool` + Whether the field should be displayed inline. + """ + + field = EmbedField(name=str(name), value=str(value), inline=inline) + + self._fields.insert(index, field) + + return self + + def clear_fields(self) -> None: + """Removes all fields from this embed.""" + self._fields.clear() + + def remove_field(self, index: int) -> None: + """Removes a field at a specified index. + + If the index is invalid or out of bounds then the error is + silently swallowed. + + .. note:: + + When deleting a field by index, the index of the other fields + shift to fill the gap just like a regular list. + + Parameters + ---------- + index: :class:`int` + The index of the field to remove. + """ + try: + del self._fields[index] + except IndexError: + pass + + def set_field_at( + self: E, index: int, *, name: Any, value: Any, inline: bool = True + ) -> E: + """Modifies a field to the embed object. + + The index must point to a valid pre-existing field. There must be 25 fields or fewer. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ---------- + index: :class:`int` + The index of the field to modify. + name: :class:`str` + The name of the field. + Must be 256 characters or fewer. + value: :class:`str` + The value of the field. + Must be 1024 characters or fewer. + inline: :class:`bool` + Whether the field should be displayed inline. + + Raises + ------ + IndexError + An invalid index was provided. + """ + + try: + field = self._fields[index] + except (TypeError, IndexError): + raise IndexError("field index out of range") + + field.name = str(name) + field.value = str(value) + field.inline = inline + return self + + def to_dict(self) -> EmbedData: + """Converts this embed object into a dict. + + Returns + ------- + Dict[:class:`str`, Union[:class:`str`, :class:`int`, :class:`bool`]] + A dictionary of :class:`str` embed keys bound to the respective value. + """ + + # add in the raw data into the dict + result = { + key[1:]: getattr(self, key) + for key in self.__slots__ + if key != "_fields" and key[0] == "_" and hasattr(self, key) + } + + # add in the fields + result["fields"] = [field.to_dict() for field in self._fields] + + # deal with basic convenience wrappers + + try: + colour = result.pop("colour") + except KeyError: + pass + else: + if colour: + result["color"] = colour.value + + try: + timestamp = result.pop("timestamp") + except KeyError: + pass + else: + if timestamp: + if timestamp.tzinfo: + result["timestamp"] = timestamp.astimezone( + tz=datetime.timezone.utc + ).isoformat() + else: + result["timestamp"] = timestamp.replace( + tzinfo=datetime.timezone.utc + ).isoformat() + + # add in the non-raw attribute ones + if self.type: + result["type"] = self.type + + if self.description: + result["description"] = self.description + + if self.url: + result["url"] = self.url + + if self.title: + result["title"] = self.title + + return result # type: ignore diff --git a/discord/emoji.py b/discord/emoji.py new file mode 100644 index 0000000..08dad90 --- /dev/null +++ b/discord/emoji.py @@ -0,0 +1,266 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterator + +from .asset import Asset, AssetMixin +from .partial_emoji import PartialEmoji, _EmojiTag +from .user import User +from .utils import MISSING, SnowflakeList, snowflake_time + +__all__ = ("Emoji",) + +if TYPE_CHECKING: + from datetime import datetime + + from .abc import Snowflake + from .guild import Guild + from .role import Role + from .state import ConnectionState + from .types.emoji import Emoji as EmojiPayload + + +class Emoji(_EmojiTag, AssetMixin): + """Represents a custom emoji. + + Depending on the way this object was created, some attributes can + have a value of ``None``. + + .. container:: operations + + .. describe:: x == y + + Checks if two emoji are the same. + + .. describe:: x != y + + Checks if two emoji are not the same. + + .. describe:: hash(x) + + Return the emoji's hash. + + .. describe:: iter(x) + + Returns an iterator of ``(field, value)`` pairs. This allows this class + to be used as an iterable in list/dict/etc constructions. + + .. describe:: str(x) + + Returns the emoji rendered for discord. + + Attributes + ---------- + name: :class:`str` + The name of the emoji. + id: :class:`int` + The emoji's ID. + require_colons: :class:`bool` + If colons are required to use this emoji in the client (:PJSalt: vs PJSalt). + animated: :class:`bool` + Whether an emoji is animated or not. + managed: :class:`bool` + If this emoji is managed by a Twitch integration. + guild_id: :class:`int` + The guild ID the emoji belongs to. + available: :class:`bool` + Whether the emoji is available for use. + user: Optional[:class:`User`] + The user that created the emoji. This can only be retrieved using :meth:`Guild.fetch_emoji` and + having the :attr:`~Permissions.manage_emojis` permission. + """ + + __slots__: tuple[str, ...] = ( + "require_colons", + "animated", + "managed", + "id", + "name", + "_roles", + "guild_id", + "_state", + "user", + "available", + ) + + def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload): + self.guild_id: int = guild.id + self._state: ConnectionState = state + self._from_data(data) + + def _from_data(self, emoji: EmojiPayload): + self.require_colons: bool = emoji.get("require_colons", False) + self.managed: bool = emoji.get("managed", False) + self.id: int = int(emoji["id"]) # type: ignore + self.name: str = emoji["name"] # type: ignore + self.animated: bool = emoji.get("animated", False) + self.available: bool = emoji.get("available", True) + self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", []))) + user = emoji.get("user") + self.user: User | None = User(state=self._state, data=user) if user else None + + def _to_partial(self) -> PartialEmoji: + return PartialEmoji(name=self.name, animated=self.animated, id=self.id) + + def __iter__(self) -> Iterator[tuple[str, Any]]: + for attr in self.__slots__: + if attr[0] != "_": + value = getattr(self, attr, None) + if value is not None: + yield attr, value + + def __str__(self) -> str: + if self.animated: + return f"" + return f"<:{self.name}:{self.id}>" + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, _EmojiTag) and self.id == other.id + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return self.id >> 22 + + @property + def created_at(self) -> datetime: + """:class:`datetime.datetime`: Returns the emoji's creation time in UTC.""" + return snowflake_time(self.id) + + @property + def url(self) -> str: + """:class:`str`: Returns the URL of the emoji.""" + fmt = "gif" if self.animated else "png" + return f"{Asset.BASE}/emojis/{self.id}.{fmt}" + + @property + def roles(self) -> list[Role]: + """List[:class:`Role`]: A :class:`list` of roles that is allowed to use this emoji. + + If roles is empty, the emoji is unrestricted. + """ + guild = self.guild + if guild is None: + return [] + + return [role for role in guild.roles if self._roles.has(role.id)] + + @property + def guild(self) -> Guild: + """:class:`Guild`: The guild this emoji belongs to.""" + return self._state._get_guild(self.guild_id) + + def is_usable(self) -> bool: + """:class:`bool`: Whether the bot can use this emoji. + + .. versionadded:: 1.3 + """ + if not self.available: + return False + if not self._roles: + return True + emoji_roles, my_roles = self._roles, self.guild.me._roles + return any(my_roles.has(role_id) for role_id in emoji_roles) + + async def delete(self, *, reason: str | None = None) -> None: + """|coro| + + Deletes the custom emoji. + + You must have :attr:`~Permissions.manage_emojis` permission to + do this. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this emoji. Shows up on the audit log. + + Raises + ------ + Forbidden + You are not allowed to delete emojis. + HTTPException + An error occurred deleting the emoji. + """ + + await self._state.http.delete_custom_emoji( + self.guild.id, self.id, reason=reason + ) + + async def edit( + self, + *, + name: str = MISSING, + roles: list[Snowflake] = MISSING, + reason: str | None = None, + ) -> Emoji: + r"""|coro| + + Edits the custom emoji. + + You must have :attr:`~Permissions.manage_emojis` permission to + do this. + + .. versionchanged:: 2.0 + The newly updated emoji is returned. + + Parameters + ----------- + name: :class:`str` + The new emoji name. + roles: Optional[List[:class:`~discord.abc.Snowflake`]] + A list of roles that can use this emoji. An empty list can be passed to make it available to everyone. + reason: Optional[:class:`str`] + The reason for editing this emoji. Shows up on the audit log. + + Raises + ------- + Forbidden + You are not allowed to edit emojis. + HTTPException + An error occurred editing the emoji. + + Returns + -------- + :class:`Emoji` + The newly updated emoji. + """ + + payload = {} + if name is not MISSING: + payload["name"] = name + if roles is not MISSING: + payload["roles"] = [role.id for role in roles] + + data = await self._state.http.edit_custom_emoji( + self.guild.id, self.id, payload=payload, reason=reason + ) + return Emoji(guild=self.guild, data=data, state=self._state) diff --git a/discord/enums.py b/discord/enums.py new file mode 100644 index 0000000..dc952d0 --- /dev/null +++ b/discord/enums.py @@ -0,0 +1,844 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import types +from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) + +__all__ = ( + "Enum", + "ChannelType", + "MessageType", + "VoiceRegion", + "SpeakingState", + "VerificationLevel", + "ContentFilter", + "Status", + "DefaultAvatar", + "AuditLogAction", + "AuditLogActionCategory", + "UserFlags", + "ActivityType", + "NotificationLevel", + "TeamMembershipState", + "WebhookType", + "ExpireBehaviour", + "ExpireBehavior", + "StickerType", + "StickerFormatType", + "InviteTarget", + "VideoQualityMode", + "ComponentType", + "ButtonStyle", + "StagePrivacyLevel", + "InteractionType", + "InteractionResponseType", + "NSFWLevel", + "EmbeddedActivity", + "ScheduledEventStatus", + "ScheduledEventPrivacyLevel", + "ScheduledEventLocationType", + "InputTextStyle", + "SlashCommandOptionType", + "AutoModTriggerType", + "AutoModEventType", + "AutoModActionType", + "AutoModKeywordPresetType", +) + + +def _create_value_cls(name, comparable): + cls = namedtuple(f"_EnumValue_{name}", "name value") + cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" + cls.__str__ = lambda self: f"{name}.{self.name}" + if comparable: + cls.__le__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value <= other.value + ) + cls.__ge__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value >= other.value + ) + cls.__lt__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value < other.value + ) + cls.__gt__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value > other.value + ) + return cls + + +def _is_descriptor(obj): + return ( + hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") + ) + + +class EnumMeta(type): + if TYPE_CHECKING: + __name__: ClassVar[str] + _enum_member_names_: ClassVar[List[str]] + _enum_member_map_: ClassVar[Dict[str, Any]] + _enum_value_map_: ClassVar[Dict[Any, Any]] + + def __new__(cls, name, bases, attrs, *, comparable: bool = False): + value_mapping = {} + member_mapping = {} + member_names = [] + + value_cls = _create_value_cls(name, comparable) + for key, value in list(attrs.items()): + is_descriptor = _is_descriptor(value) + if key[0] == "_" and not is_descriptor: + continue + + # Special case classmethod to just pass through + if isinstance(value, classmethod): + continue + + if is_descriptor: + setattr(value_cls, key, value) + del attrs[key] + continue + + try: + new_value = value_mapping[value] + except KeyError: + new_value = value_cls(name=key, value=value) + value_mapping[value] = new_value + member_names.append(key) + + member_mapping[key] = new_value + attrs[key] = new_value + + attrs["_enum_value_map_"] = value_mapping + attrs["_enum_member_map_"] = member_mapping + attrs["_enum_member_names_"] = member_names + attrs["_enum_value_cls_"] = value_cls + actual_cls = super().__new__(cls, name, bases, attrs) + value_cls._actual_enum_cls_ = actual_cls # type: ignore + return actual_cls + + def __iter__(cls): + return (cls._enum_member_map_[name] for name in cls._enum_member_names_) + + def __reversed__(cls): + return ( + cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_) + ) + + def __len__(cls): + return len(cls._enum_member_names_) + + def __repr__(cls): + return f"" + + @property + def __members__(cls): + return types.MappingProxyType(cls._enum_member_map_) + + def __call__(cls, value): + try: + return cls._enum_value_map_[value] + except (KeyError, TypeError): + raise ValueError(f"{value!r} is not a valid {cls.__name__}") + + def __getitem__(cls, key): + return cls._enum_member_map_[key] + + def __setattr__(cls, name, value): + raise TypeError("Enums are immutable.") + + def __delattr__(cls, attr): + raise TypeError("Enums are immutable") + + def __instancecheck__(self, instance): + # isinstance(x, Y) + # -> __instancecheck__(Y, x) + try: + return instance._actual_enum_cls_ is self + except AttributeError: + return False + + +if TYPE_CHECKING: + from enum import Enum +else: + + class Enum(metaclass=EnumMeta): + @classmethod + def try_value(cls, value): + try: + return cls._enum_value_map_[value] + except (KeyError, TypeError): + return value + + +class ChannelType(Enum): + text = 0 + private = 1 + voice = 2 + group = 3 + category = 4 + news = 5 + news_thread = 10 + public_thread = 11 + private_thread = 12 + stage_voice = 13 + directory = 14 + forum = 15 + + def __str__(self): + return self.name + + +class MessageType(Enum): + default = 0 + recipient_add = 1 + recipient_remove = 2 + call = 3 + channel_name_change = 4 + channel_icon_change = 5 + pins_add = 6 + new_member = 7 + premium_guild_subscription = 8 + premium_guild_tier_1 = 9 + premium_guild_tier_2 = 10 + premium_guild_tier_3 = 11 + channel_follow_add = 12 + guild_stream = 13 + guild_discovery_disqualified = 14 + guild_discovery_requalified = 15 + guild_discovery_grace_period_initial_warning = 16 + guild_discovery_grace_period_final_warning = 17 + thread_created = 18 + reply = 19 + application_command = 20 + thread_starter_message = 21 + guild_invite_reminder = 22 + context_menu_command = 23 + auto_moderation_action = 24 + + +class VoiceRegion(Enum): + us_west = "us-west" + us_east = "us-east" + us_south = "us-south" + us_central = "us-central" + eu_west = "eu-west" + eu_central = "eu-central" + singapore = "singapore" + london = "london" + sydney = "sydney" + amsterdam = "amsterdam" + frankfurt = "frankfurt" + brazil = "brazil" + hongkong = "hongkong" + russia = "russia" + japan = "japan" + southafrica = "southafrica" + south_korea = "south-korea" + india = "india" + europe = "europe" + dubai = "dubai" + vip_us_east = "vip-us-east" + vip_us_west = "vip-us-west" + vip_amsterdam = "vip-amsterdam" + + def __str__(self): + return self.value + + +class SpeakingState(Enum): + none = 0 + voice = 1 + soundshare = 2 + priority = 4 + + def __str__(self): + return self.name + + def __int__(self): + return self.value + + +class VerificationLevel(Enum, comparable=True): + none = 0 + low = 1 + medium = 2 + high = 3 + highest = 4 + + def __str__(self): + return self.name + + +class ContentFilter(Enum, comparable=True): + disabled = 0 + no_role = 1 + all_members = 2 + + def __str__(self): + return self.name + + +class Status(Enum): + online = "online" + offline = "offline" + idle = "idle" + dnd = "dnd" + do_not_disturb = "dnd" + invisible = "invisible" + streaming = "streaming" + + def __str__(self): + return self.value + + +class DefaultAvatar(Enum): + blurple = 0 + grey = 1 + gray = 1 + green = 2 + orange = 3 + red = 4 + + def __str__(self): + return self.name + + +class NotificationLevel(Enum, comparable=True): + all_messages = 0 + only_mentions = 1 + + +class AuditLogActionCategory(Enum): + create = 1 + delete = 2 + update = 3 + + +class AuditLogAction(Enum): + guild_update = 1 + channel_create = 10 + channel_update = 11 + channel_delete = 12 + overwrite_create = 13 + overwrite_update = 14 + overwrite_delete = 15 + kick = 20 + member_prune = 21 + ban = 22 + unban = 23 + member_update = 24 + member_role_update = 25 + member_move = 26 + member_disconnect = 27 + bot_add = 28 + role_create = 30 + role_update = 31 + role_delete = 32 + invite_create = 40 + invite_update = 41 + invite_delete = 42 + webhook_create = 50 + webhook_update = 51 + webhook_delete = 52 + emoji_create = 60 + emoji_update = 61 + emoji_delete = 62 + message_delete = 72 + message_bulk_delete = 73 + message_pin = 74 + message_unpin = 75 + integration_create = 80 + integration_update = 81 + integration_delete = 82 + stage_instance_create = 83 + stage_instance_update = 84 + stage_instance_delete = 85 + sticker_create = 90 + sticker_update = 91 + sticker_delete = 92 + scheduled_event_create = 100 + scheduled_event_update = 101 + scheduled_event_delete = 102 + thread_create = 110 + thread_update = 111 + thread_delete = 112 + application_command_permission_update = 121 + auto_moderation_rule_create = 140 + auto_moderation_rule_update = 141 + auto_moderation_rule_delete = 142 + auto_moderation_block_message = 143 + + @property + def category(self) -> Optional[AuditLogActionCategory]: + lookup: Dict[AuditLogAction, Optional[AuditLogActionCategory]] = { + AuditLogAction.guild_update: AuditLogActionCategory.update, + AuditLogAction.channel_create: AuditLogActionCategory.create, + AuditLogAction.channel_update: AuditLogActionCategory.update, + AuditLogAction.channel_delete: AuditLogActionCategory.delete, + AuditLogAction.overwrite_create: AuditLogActionCategory.create, + AuditLogAction.overwrite_update: AuditLogActionCategory.update, + AuditLogAction.overwrite_delete: AuditLogActionCategory.delete, + AuditLogAction.kick: None, + AuditLogAction.member_prune: None, + AuditLogAction.ban: None, + AuditLogAction.unban: None, + AuditLogAction.member_update: AuditLogActionCategory.update, + AuditLogAction.member_role_update: AuditLogActionCategory.update, + AuditLogAction.member_move: None, + AuditLogAction.member_disconnect: None, + AuditLogAction.bot_add: None, + AuditLogAction.role_create: AuditLogActionCategory.create, + AuditLogAction.role_update: AuditLogActionCategory.update, + AuditLogAction.role_delete: AuditLogActionCategory.delete, + AuditLogAction.invite_create: AuditLogActionCategory.create, + AuditLogAction.invite_update: AuditLogActionCategory.update, + AuditLogAction.invite_delete: AuditLogActionCategory.delete, + AuditLogAction.webhook_create: AuditLogActionCategory.create, + AuditLogAction.webhook_update: AuditLogActionCategory.update, + AuditLogAction.webhook_delete: AuditLogActionCategory.delete, + AuditLogAction.emoji_create: AuditLogActionCategory.create, + AuditLogAction.emoji_update: AuditLogActionCategory.update, + AuditLogAction.emoji_delete: AuditLogActionCategory.delete, + AuditLogAction.message_delete: AuditLogActionCategory.delete, + AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete, + AuditLogAction.message_pin: None, + AuditLogAction.message_unpin: None, + AuditLogAction.integration_create: AuditLogActionCategory.create, + AuditLogAction.integration_update: AuditLogActionCategory.update, + AuditLogAction.integration_delete: AuditLogActionCategory.delete, + AuditLogAction.stage_instance_create: AuditLogActionCategory.create, + AuditLogAction.stage_instance_update: AuditLogActionCategory.update, + AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete, + AuditLogAction.sticker_create: AuditLogActionCategory.create, + AuditLogAction.sticker_update: AuditLogActionCategory.update, + AuditLogAction.sticker_delete: AuditLogActionCategory.delete, + AuditLogAction.scheduled_event_create: AuditLogActionCategory.create, + AuditLogAction.scheduled_event_update: AuditLogActionCategory.update, + AuditLogAction.scheduled_event_delete: AuditLogActionCategory.delete, + AuditLogAction.thread_create: AuditLogActionCategory.create, + AuditLogAction.thread_update: AuditLogActionCategory.update, + AuditLogAction.thread_delete: AuditLogActionCategory.delete, + AuditLogAction.application_command_permission_update: AuditLogActionCategory.update, + AuditLogAction.auto_moderation_rule_create: AuditLogActionCategory.create, + AuditLogAction.auto_moderation_rule_update: AuditLogActionCategory.update, + AuditLogAction.auto_moderation_rule_delete: AuditLogActionCategory.delete, + AuditLogAction.auto_moderation_block_message: None, + } + return lookup[self] + + @property + def target_type(self) -> Optional[str]: + v = self.value + if v == -1: + return "all" + elif v < 10: + return "guild" + elif v < 20: + return "channel" + elif v < 30: + return "user" + elif v < 40: + return "role" + elif v < 50: + return "invite" + elif v < 60: + return "webhook" + elif v < 70: + return "emoji" + elif v == 73: + return "channel" + elif v < 80: + return "message" + elif v < 83: + return "integration" + elif v < 90: + return "stage_instance" + elif v < 93: + return "sticker" + elif v < 103: + return "scheduled_event" + elif v < 113: + return "thread" + elif v < 122: + return "application_command_permission" + elif v < 144: + return "auto_moderation_rule" + + +class UserFlags(Enum): + staff = 1 + partner = 2 + hypesquad = 4 + bug_hunter = 8 + mfa_sms = 16 + premium_promo_dismissed = 32 + hypesquad_bravery = 64 + hypesquad_brilliance = 128 + hypesquad_balance = 256 + early_supporter = 512 + team_user = 1024 + partner_or_verification_application = 2048 + system = 4096 + has_unread_urgent_messages = 8192 + bug_hunter_level_2 = 16384 + underage_deleted = 32768 + verified_bot = 65536 + verified_bot_developer = 131072 + discord_certified_moderator = 262144 + bot_http_interactions = 524288 + spammer = 1048576 + + +class ActivityType(Enum): + unknown = -1 + playing = 0 + streaming = 1 + listening = 2 + watching = 3 + custom = 4 + competing = 5 + + def __int__(self): + return self.value + + +class TeamMembershipState(Enum): + invited = 1 + accepted = 2 + + +class WebhookType(Enum): + incoming = 1 + channel_follower = 2 + application = 3 + + +class ExpireBehaviour(Enum): + remove_role = 0 + kick = 1 + + +ExpireBehavior = ExpireBehaviour + + +class StickerType(Enum): + standard = 1 + guild = 2 + + +class StickerFormatType(Enum): + png = 1 + apng = 2 + lottie = 3 + + @property + def file_extension(self) -> str: + lookup: Dict[StickerFormatType, str] = { + StickerFormatType.png: "png", + StickerFormatType.apng: "png", + StickerFormatType.lottie: "json", + } + return lookup[self] + + +class InviteTarget(Enum): + unknown = 0 + stream = 1 + embedded_application = 2 + + +class InteractionType(Enum): + ping = 1 + application_command = 2 + component = 3 + auto_complete = 4 + modal_submit = 5 + + +class InteractionResponseType(Enum): + pong = 1 + # ack = 2 (deprecated) + # channel_message = 3 (deprecated) + channel_message = 4 # (with source) + deferred_channel_message = 5 # (with source) + deferred_message_update = 6 # for components + message_update = 7 # for components + auto_complete_result = 8 # for autocomplete interactions + modal = 9 # for modal dialogs + + +class VideoQualityMode(Enum): + auto = 1 + full = 2 + + def __int__(self): + return self.value + + +class ComponentType(Enum): + action_row = 1 + button = 2 + select = 3 + input_text = 4 + + def __int__(self): + return self.value + + +class ButtonStyle(Enum): + primary = 1 + secondary = 2 + success = 3 + danger = 4 + link = 5 + + # Aliases + blurple = 1 + grey = 2 + gray = 2 + green = 3 + red = 4 + url = 5 + + def __int__(self): + return self.value + + +class InputTextStyle(Enum): + short = 1 + singleline = 1 + paragraph = 2 + multiline = 2 + long = 2 + + +class ApplicationType(Enum): + game = 1 + music = 2 + ticketed_events = 3 + guild_role_subscriptions = 4 + + +class StagePrivacyLevel(Enum): + # public = 1 (deprecated) + closed = 2 + guild_only = 2 + + +class NSFWLevel(Enum, comparable=True): + default = 0 + explicit = 1 + safe = 2 + age_restricted = 3 + + +class SlashCommandOptionType(Enum): + sub_command = 1 + sub_command_group = 2 + string = 3 + integer = 4 + boolean = 5 + user = 6 + channel = 7 + role = 8 + mentionable = 9 + number = 10 + attachment = 11 + + @classmethod + def from_datatype(cls, datatype): + if isinstance(datatype, tuple): # typing.Union has been used + datatypes = [cls.from_datatype(op) for op in datatype] + if all(x == cls.channel for x in datatypes): + return cls.channel + elif set(datatypes) <= {cls.role, cls.user}: + return cls.mentionable + else: + raise TypeError("Invalid usage of typing.Union") + + py_3_10_union_type = hasattr(types, "UnionType") and isinstance( + datatype, types.UnionType + ) + + if py_3_10_union_type or getattr(datatype, "__origin__", None) is Union: + # Python 3.10+ "|" operator or typing.Union has been used. The __args__ attribute is a tuple of the types. + # Type checking fails for this case, so ignore it. + return cls.from_datatype(datatype.__args__) # type: ignore + + if datatype.__name__ in ["Member", "User"]: + return cls.user + if datatype.__name__ in [ + "GuildChannel", + "TextChannel", + "VoiceChannel", + "StageChannel", + "CategoryChannel", + "ThreadOption", + "Thread", + ]: + return cls.channel + if datatype.__name__ == "Role": + return cls.role + if datatype.__name__ == "Attachment": + return cls.attachment + if datatype.__name__ == "Mentionable": + return cls.mentionable + + if issubclass(datatype, str): + return cls.string + if issubclass(datatype, bool): + return cls.boolean + if issubclass(datatype, int): + return cls.integer + if issubclass(datatype, float): + return cls.number + + from .commands.context import ApplicationContext + + if not issubclass( + datatype, ApplicationContext + ): # TODO: prevent ctx being passed here in cog commands + raise TypeError( + f"Invalid class {datatype} used as an input type for an Option" + ) # TODO: Improve the error message + + +class EmbeddedActivity(Enum): + awkword = 879863881349087252 + betrayal = 773336526917861400 + checkers_in_the_park = 832013003968348200 + checkers_in_the_park_dev = 832012682520428625 + checkers_in_the_park_staging = 832012938398400562 + checkers_in_the_park_qa = 832012894068801636 + chess_in_the_park = 832012774040141894 + chess_in_the_park_dev = 832012586023256104 + chest_in_the_park_staging = 832012730599735326 + chest_in_the_park_qa = 832012815819604009 + decoders_dev = 891001866073296967 + doodle_crew = 878067389634314250 + doodle_crew_dev = 878067427668275241 + fishington = 814288819477020702 + letter_tile = 879863686565621790 + ocho = 832025144389533716 + ocho_dev = 832013108234289153 + ocho_staging = 832025061657280566 + ocho_qa = 832025114077298718 + poker_night = 755827207812677713 + poker_night_staging = 763116274876022855 + poker_night_qa = 801133024841957428 + putts = 832012854282158180 + sketchy_artist = 879864070101172255 + sketchy_artist_dev = 879864104980979792 + spell_cast = 852509694341283871 + watch_together = 880218394199220334 + watch_together_dev = 880218832743055411 + word_snacks = 879863976006127627 + word_snacks_dev = 879864010126786570 + youtube_together = 755600276941176913 + + +class ScheduledEventStatus(Enum): + scheduled = 1 + active = 2 + completed = 3 + canceled = 4 + cancelled = 4 + + def __int__(self): + return self.value + + +class ScheduledEventPrivacyLevel(Enum): + guild_only = 2 + + def __int__(self): + return self.value + + +class ScheduledEventLocationType(Enum): + stage_instance = 1 + voice = 2 + external = 3 + + +class AutoModTriggerType(Enum): + keyword = 1 + harmful_link = 2 + spam = 3 + keyword_preset = 4 + + +class AutoModEventType(Enum): + message_send = 1 + + +class AutoModActionType(Enum): + block_message = 1 + send_alert_message = 2 + timeout = 3 + + +class AutoModKeywordPresetType(Enum): + profanity = 1 + sexual_content = 2 + slurs = 3 + + +T = TypeVar("T") + + +def create_unknown_value(cls: Type[T], val: Any) -> T: + value_cls = cls._enum_value_cls_ # type: ignore + name = f"unknown_{val}" + return value_cls(name=name, value=val) + + +def try_enum(cls: Type[T], val: Any) -> T: + """A function that tries to turn the value into enum ``cls``. + + If it fails it returns a proxy invalid value instead. + """ + + try: + return cls._enum_value_map_[val] # type: ignore + except (KeyError, TypeError, AttributeError): + return create_unknown_value(cls, val) diff --git a/discord/errors.py b/discord/errors.py new file mode 100644 index 0000000..0d13944 --- /dev/null +++ b/discord/errors.py @@ -0,0 +1,409 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Union + +if TYPE_CHECKING: + from aiohttp import ClientResponse, ClientWebSocketResponse + + try: + from requests import Response + + _ResponseType = Union[ClientResponse, Response] + except ModuleNotFoundError: + _ResponseType = ClientResponse + + from .interactions import Interaction + +__all__ = ( + "DiscordException", + "ClientException", + "NoMoreItems", + "GatewayNotFound", + "ValidationError", + "HTTPException", + "Forbidden", + "NotFound", + "DiscordServerError", + "InvalidData", + "InvalidArgument", + "LoginFailure", + "ConnectionClosed", + "PrivilegedIntentsRequired", + "InteractionResponded", + "ExtensionError", + "ExtensionAlreadyLoaded", + "ExtensionNotLoaded", + "NoEntryPointError", + "ExtensionFailed", + "ExtensionNotFound", + "ApplicationCommandError", + "CheckFailure", + "ApplicationCommandInvokeError", +) + + +class DiscordException(Exception): + """Base exception class for pycord + + Ideally speaking, this could be caught to handle any exceptions raised from this library. + """ + + +class ClientException(DiscordException): + """Exception that's raised when an operation in the :class:`Client` fails. + + These are usually for exceptions that happened due to user input. + """ + + +class NoMoreItems(DiscordException): + """Exception that is raised when an async iteration operation has no more items.""" + + +class GatewayNotFound(DiscordException): + """An exception that is raised when the gateway for Discord could not be found""" + + def __init__(self): + message = "The gateway to connect to discord was not found." + super().__init__(message) + + +class ValidationError(DiscordException): + """An Exception that is raised when there is a Validation Error.""" + + +def _flatten_error_dict(d: dict[str, Any], key: str = "") -> dict[str, str]: + items: list[tuple[str, str]] = [] + for k, v in d.items(): + new_key = f"{key}.{k}" if key else k + + if isinstance(v, dict): + try: + _errors: list[dict[str, Any]] = v["_errors"] + except KeyError: + items.extend(_flatten_error_dict(v, new_key).items()) + else: + items.append((new_key, " ".join(x.get("message", "") for x in _errors))) + else: + items.append((new_key, v)) + + return dict(items) + + +class HTTPException(DiscordException): + """Exception that's raised when an HTTP request operation fails. + + Attributes + ---------- + response: :class:`aiohttp.ClientResponse` + The response of the failed HTTP request. This is an + instance of :class:`aiohttp.ClientResponse`. In some cases + this could also be a :class:`requests.Response`. + + text: :class:`str` + The text of the error. Could be an empty string. + status: :class:`int` + The status code of the HTTP request. + code: :class:`int` + The Discord specific error code for the failure. + """ + + def __init__(self, response: _ResponseType, message: str | dict[str, Any] | None): + self.response: _ResponseType = response + self.status: int = response.status # type: ignore + self.code: int + self.text: str + if isinstance(message, dict): + self.code = message.get("code", 0) + base = message.get("message", "") + errors = message.get("errors") + if errors: + errors = _flatten_error_dict(errors) + helpful = "\n".join("In %s: %s" % t for t in errors.items()) + self.text = f"{base}\n{helpful}" + else: + self.text = base + else: + self.text = message or "" + self.code = 0 + + fmt = "{0.status} {0.reason} (error code: {1})" + if len(self.text): + fmt += ": {2}" + + super().__init__(fmt.format(self.response, self.code, self.text)) + + +class Forbidden(HTTPException): + """Exception that's raised for when status code 403 occurs. + + Subclass of :exc:`HTTPException` + """ + + +class NotFound(HTTPException): + """Exception that's raised for when status code 404 occurs. + + Subclass of :exc:`HTTPException` + """ + + +class DiscordServerError(HTTPException): + """Exception that's raised for when a 500 range status code occurs. + + Subclass of :exc:`HTTPException`. + + .. versionadded:: 1.5 + """ + + +class InvalidData(ClientException): + """Exception that's raised when the library encounters unknown + or invalid data from Discord. + """ + + +class InvalidArgument(ClientException): + """Exception that's raised when an argument to a function + is invalid some way (e.g. wrong value or wrong type). + + This could be considered the parallel of ``ValueError`` and + ``TypeError`` except inherited from :exc:`ClientException` and thus + :exc:`DiscordException`. + """ + + +class LoginFailure(ClientException): + """Exception that's raised when the :meth:`Client.login` function + fails to log you in from improper credentials or some other misc. + failure. + """ + + +class ConnectionClosed(ClientException): + """Exception that's raised when the gateway connection is + closed for reasons that could not be handled internally. + + Attributes + ---------- + code: :class:`int` + The close code of the websocket. + reason: :class:`str` + The reason provided for the closure. + shard_id: Optional[:class:`int`] + The shard ID that got closed if applicable. + """ + + def __init__( + self, + socket: ClientWebSocketResponse, + *, + shard_id: int | None, + code: int | None = None, + ): + # This exception is just the same exception except + # reconfigured to subclass ClientException for users + self.code: int = code or socket.close_code or -1 + # aiohttp doesn't seem to consistently provide close reason + self.reason: str = "" + self.shard_id: int | None = shard_id + super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}") + + +class PrivilegedIntentsRequired(ClientException): + """Exception that's raised when the gateway is requesting privileged intents, but + they're not ticked in the developer page yet. + + Go to https://discord.com/developers/applications/ and enable the intents + that are required. Currently, these are as follows: + + - :attr:`Intents.members` + - :attr:`Intents.presences` + - :attr:`Intents.message_content` + + Attributes + ---------- + shard_id: Optional[:class:`int`] + The shard ID that got closed if applicable. + """ + + def __init__(self, shard_id: int | None): + self.shard_id: int | None = shard_id + msg = ( + "Shard ID %s is requesting privileged intents that have not been explicitly enabled in the " + "developer portal. It is recommended to go to https://discord.com/developers/applications/ " + "and explicitly enable the privileged intents within your application's page. If this is not " + "possible, then consider disabling the privileged intents instead." + ) + super().__init__(msg % shard_id) + + +class InteractionResponded(ClientException): + """Exception that's raised when sending another interaction response using + :class:`InteractionResponse` when one has already been done before. + + An interaction can only respond once. + + .. versionadded:: 2.0 + + Attributes + ---------- + interaction: :class:`Interaction` + The interaction that's already been responded to. + """ + + def __init__(self, interaction: Interaction): + self.interaction: Interaction = interaction + super().__init__("This interaction has already been responded to before") + + +class ExtensionError(DiscordException): + """Base exception for extension related errors. + + This inherits from :exc:`~discord.DiscordException`. + + Attributes + ---------- + name: :class:`str` + The extension that had an error. + """ + + def __init__(self, message: str | None = None, *args: Any, name: str) -> None: + self.name: str = name + message = message or f"Extension {name!r} had an error." + # clean-up @everyone and @here mentions + m = message.replace("@everyone", "@\u200beveryone").replace( + "@here", "@\u200bhere" + ) + super().__init__(m, *args) + + +class ExtensionAlreadyLoaded(ExtensionError): + """An exception raised when an extension has already been loaded. + + This inherits from :exc:`ExtensionError` + """ + + def __init__(self, name: str) -> None: + super().__init__(f"Extension {name!r} is already loaded.", name=name) + + +class ExtensionNotLoaded(ExtensionError): + """An exception raised when an extension was not loaded. + + This inherits from :exc:`ExtensionError` + """ + + def __init__(self, name: str) -> None: + super().__init__(f"Extension {name!r} has not been loaded.", name=name) + + +class NoEntryPointError(ExtensionError): + """An exception raised when an extension does not have a ``setup`` entry point function. + + This inherits from :exc:`ExtensionError` + """ + + def __init__(self, name: str) -> None: + super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) + + +class ExtensionFailed(ExtensionError): + """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. + + This inherits from :exc:`ExtensionError` + + Attributes + ---------- + name: :class:`str` + The extension that had the error. + original: :exc:`Exception` + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + """ + + def __init__(self, name: str, original: Exception) -> None: + self.original: Exception = original + msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}" + super().__init__(msg, name=name) + + +class ExtensionNotFound(ExtensionError): + """An exception raised when an extension is not found. + + This inherits from :exc:`ExtensionError` + + .. versionchanged:: 1.3 + Made the ``original`` attribute always None. + + Attributes + ---------- + name: :class:`str` + The extension that had the error. + """ + + def __init__(self, name: str) -> None: + msg = f"Extension {name!r} could not be found." + super().__init__(msg, name=name) + + +class ApplicationCommandError(DiscordException): + r"""The base exception type for all application command related errors. + + This inherits from :exc:`DiscordException`. + + This exception and exceptions inherited from it are handled + in a special way as they are caught and passed into a special event + from :class:`.Bot`\, :func:`.on_command_error`. + """ + + +class CheckFailure(ApplicationCommandError): + """Exception raised when the predicates in :attr:`.Command.checks` have failed. + + This inherits from :exc:`ApplicationCommandError` + """ + + +class ApplicationCommandInvokeError(ApplicationCommandError): + """Exception raised when the command being invoked raised an exception. + + This inherits from :exc:`ApplicationCommandError` + + Attributes + ---------- + original: :exc:`Exception` + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + """ + + def __init__(self, e: Exception) -> None: + self.original: Exception = e + super().__init__( + f"Application Command raised an exception: {e.__class__.__name__}: {e}" + ) diff --git a/discord/ext/bridge/__init__.py b/discord/ext/bridge/__init__.py new file mode 100644 index 0000000..b92a536 --- /dev/null +++ b/discord/ext/bridge/__init__.py @@ -0,0 +1,28 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .bot import * +from .context import * +from .core import * diff --git a/discord/ext/bridge/bot.py b/discord/ext/bridge/bot.py new file mode 100644 index 0000000..3db78d5 --- /dev/null +++ b/discord/ext/bridge/bot.py @@ -0,0 +1,111 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from abc import ABC + +from discord.interactions import Interaction +from discord.message import Message + +from ..commands import AutoShardedBot as ExtAutoShardedBot +from ..commands import Bot as ExtBot +from .context import BridgeApplicationContext, BridgeExtContext +from .core import BridgeCommand, BridgeCommandGroup, bridge_command, bridge_group + +__all__ = ("Bot", "AutoShardedBot") + + +class BotBase(ABC): + async def get_application_context( + self, interaction: Interaction, cls=None + ) -> BridgeApplicationContext: + cls = cls if cls is not None else BridgeApplicationContext + # Ignore the type hinting error here. BridgeApplicationContext is a subclass of ApplicationContext, and since + # we gave it cls, it will be used instead. + return await super().get_application_context(interaction, cls=cls) # type: ignore + + async def get_context(self, message: Message, cls=None) -> BridgeExtContext: + cls = cls if cls is not None else BridgeExtContext + # Ignore the type hinting error here. BridgeExtContext is a subclass of Context, and since we gave it cls, it + # will be used instead. + return await super().get_context(message, cls=cls) # type: ignore + + def add_bridge_command(self, command: BridgeCommand): + """Takes a :class:`.BridgeCommand` and adds both a slash and traditional (prefix-based) version of the command + to the bot. + """ + # Ignore the type hinting error here. All subclasses of BotBase pass the type checks. + command.add_to(self) # type: ignore + + def bridge_command(self, **kwargs): + """A shortcut decorator that invokes :func:`bridge_command` and adds it to + the internal command list via :meth:`~.Bot.add_bridge_command`. + + Returns + ------- + Callable[..., :class:`BridgeCommand`] + A decorator that converts the provided method into an :class:`.BridgeCommand`, adds both a slash and + traditional (prefix-based) version of the command to the bot, and returns the :class:`.BridgeCommand`. + """ + + def decorator(func) -> BridgeCommand: + result = bridge_command(**kwargs)(func) + self.add_bridge_command(result) + return result + + return decorator + + def bridge_group(self, **kwargs): + """A decorator that is used to wrap a function as a bridge command group. + + Parameters + ---------- + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommandGroup` and :class:`.ext.commands.Group`) + """ + + def decorator(func) -> BridgeCommandGroup: + result = bridge_group(**kwargs)(func) + self.add_bridge_command(result) + return result + + return decorator + + +class Bot(BotBase, ExtBot): + """Represents a discord bot, with support for cross-compatibility between command types. + + This class is a subclass of :class:`.ext.commands.Bot` and as a result + anything that you can do with a :class:`.ext.commands.Bot` you can do with + this bot. + + .. versionadded:: 2.0 + """ + + +class AutoShardedBot(BotBase, ExtAutoShardedBot): + """This is similar to :class:`.Bot` except that it is inherited from + :class:`.ext.commands.AutoShardedBot` instead. + + .. versionadded:: 2.0 + """ diff --git a/discord/ext/bridge/context.py b/discord/ext/bridge/context.py new file mode 100644 index 0000000..8035572 --- /dev/null +++ b/discord/ext/bridge/context.py @@ -0,0 +1,197 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, overload + +from discord.commands import ApplicationContext +from discord.interactions import Interaction, InteractionMessage +from discord.message import Message +from discord.webhook import WebhookMessage + +from ..commands import Context + +if TYPE_CHECKING: + from .core import BridgeExtCommand, BridgeSlashCommand + + +__all__ = ("BridgeContext", "BridgeExtContext", "BridgeApplicationContext") + + +class BridgeContext(ABC): + """ + The base context class for compatibility commands. This class is an :term:`abstract base class` (also known as an + ``abc``), which is subclassed by :class:`BridgeExtContext` and :class:`BridgeApplicationContext`. The methods in + this class are meant to give parity between the two contexts, while still allowing for all of their functionality. + + When this is passed to a command, it will either be passed as :class:`BridgeExtContext`, or + :class:`BridgeApplicationContext`. Since they are two separate classes, it's easy to use the :attr:`BridgeContext.is_app` attribute. + to make different functionality for each context. For example, if you want to respond to a command with the command + type that it was invoked with, you can do the following: + + .. code-block:: python3 + + @bot.bridge_command() + async def example(ctx: BridgeContext): + if ctx.is_app: + command_type = "Application command" + else: + command_type = "Traditional (prefix-based) command" + await ctx.send(f"This command was invoked with a(n) {command_type}.") + + .. versionadded:: 2.0 + """ + + @abstractmethod + async def _respond(self, *args, **kwargs) -> Interaction | WebhookMessage | Message: + ... + + @abstractmethod + async def _defer(self, *args, **kwargs) -> None: + ... + + @abstractmethod + async def _edit(self, *args, **kwargs) -> InteractionMessage | Message: + ... + + @overload + async def invoke( + self, command: BridgeSlashCommand | BridgeExtCommand, *args, **kwargs + ) -> None: + ... + + async def respond(self, *args, **kwargs) -> Interaction | WebhookMessage | Message: + """|coro| + + Responds to the command with the respective response type to the current context. In :class:`BridgeExtContext`, + this will be :meth:`~.Context.reply` while in :class:`BridgeApplicationContext`, this will be + :meth:`~.ApplicationContext.respond`. + """ + return await self._respond(*args, **kwargs) + + async def reply(self, *args, **kwargs) -> Interaction | WebhookMessage | Message: + """|coro| + + Alias for :meth:`~.BridgeContext.respond`. + """ + return await self.respond(*args, **kwargs) + + async def defer(self, *args, **kwargs) -> None: + """|coro| + + Defers the command with the respective approach to the current context. In :class:`BridgeExtContext`, this will + be :meth:`~discord.abc.Messageable.trigger_typing` while in :class:`BridgeApplicationContext`, this will be + :attr:`~.ApplicationContext.defer`. + + .. note:: + There is no ``trigger_typing`` alias for this method. ``trigger_typing`` will always provide the same + functionality across contexts. + """ + return await self._defer(*args, **kwargs) + + async def edit(self, *args, **kwargs) -> InteractionMessage | Message: + """|coro| + + Edits the original response message with the respective approach to the current context. In + :class:`BridgeExtContext`, this will have a custom approach where :meth:`.respond` caches the message to be + edited here. In :class:`BridgeApplicationContext`, this will be :attr:`~.ApplicationContext.edit`. + """ + return await self._edit(*args, **kwargs) + + def _get_super(self, attr: str) -> Any: + return getattr(super(), attr) + + @property + def is_app(self) -> bool: + """bool: Whether the context is an :class:`BridgeApplicationContext` or not.""" + return isinstance(self, BridgeApplicationContext) + + +class BridgeApplicationContext(BridgeContext, ApplicationContext): + """ + The application context class for compatibility commands. This class is a subclass of :class:`BridgeContext` and + :class:`~.ApplicationContext`. This class is meant to be used with :class:`BridgeCommand`. + + .. versionadded:: 2.0 + """ + + def __init__(self, *args, **kwargs): + # This is needed in order to represent the correct class init signature on the docs + super().__init__(*args, **kwargs) + + async def _respond(self, *args, **kwargs) -> Interaction | WebhookMessage: + return await self._get_super("respond")(*args, **kwargs) + + async def _defer(self, *args, **kwargs) -> None: + return await self._get_super("defer")(*args, **kwargs) + + async def _edit(self, *args, **kwargs) -> InteractionMessage: + return await self._get_super("edit")(*args, **kwargs) + + +class BridgeExtContext(BridgeContext, Context): + """ + The ext.commands context class for compatibility commands. This class is a subclass of :class:`BridgeContext` and + :class:`~.Context`. This class is meant to be used with :class:`BridgeCommand`. + + .. versionadded:: 2.0 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._original_response_message: Message | None = None + + async def _respond(self, *args, **kwargs) -> Message: + kwargs.pop("ephemeral", None) + message = await self._get_super("reply")(*args, **kwargs) + if self._original_response_message is None: + self._original_response_message = message + return message + + async def _defer(self, *args, **kwargs) -> None: + kwargs.pop("ephemeral", None) + return await self._get_super("trigger_typing")(*args, **kwargs) + + async def _edit(self, *args, **kwargs) -> Message | None: + if self._original_response_message: + return await self._original_response_message.edit(*args, **kwargs) + + async def delete( + self, *, delay: float | None = None, reason: str | None = None + ) -> None: + """|coro| + + Deletes the original response message, if it exists. + + Parameters + ---------- + delay: Optional[:class:`float`] + If provided, the number of seconds to wait before deleting the message. + reason: Optional[:class:`str`] + The reason for deleting the message. Shows up on the audit log. + """ + if self._original_response_message: + await self._original_response_message.delete(delay=delay, reason=reason) diff --git a/discord/ext/bridge/core.py b/discord/ext/bridge/core.py new file mode 100644 index 0000000..f305615 --- /dev/null +++ b/discord/ext/bridge/core.py @@ -0,0 +1,533 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Callable + +import discord.commands.options +from discord import ( + ApplicationCommand, + Attachment, + Option, + Permissions, + SlashCommand, + SlashCommandGroup, + SlashCommandOptionType, +) + +from ...utils import filter_params, find, get +from ..commands import BadArgument +from ..commands import Bot as ExtBot +from ..commands import ( + Command, + Context, + Converter, + Group, + GuildChannelConverter, + RoleConverter, + UserConverter, +) +from ..commands.converter import _convert_to_bool, run_converters + +if TYPE_CHECKING: + from .context import BridgeApplicationContext, BridgeExtContext + + +__all__ = ( + "BridgeCommand", + "BridgeCommandGroup", + "bridge_command", + "bridge_group", + "BridgeExtCommand", + "BridgeSlashCommand", + "BridgeExtGroup", + "BridgeSlashGroup", + "map_to", + "guild_only", + "has_permissions", +) + + +class BridgeSlashCommand(SlashCommand): + """A subclass of :class:`.SlashCommand` that is used for bridge commands.""" + + def __init__(self, func, **kwargs): + kwargs = filter_params(kwargs, brief="description") + super().__init__(func, **kwargs) + + +class BridgeExtCommand(Command): + """A subclass of :class:`.ext.commands.Command` that is used for bridge commands.""" + + def __init__(self, func, **kwargs): + kwargs = filter_params(kwargs, description="brief") + super().__init__(func, **kwargs) + + async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: + if param.annotation is Attachment: + # skip the parameter checks for bridge attachments + return await run_converters(ctx, AttachmentConverter, None, param) + else: + return await super().transform(ctx, param) + + +class BridgeSlashGroup(SlashCommandGroup): + """A subclass of :class:`.SlashCommandGroup` that is used for bridge commands.""" + + __slots__ = ("module",) + + def __init__(self, callback, *args, **kwargs): + super().__init__(*args, **kwargs) + self.callback = callback + self.__original_kwargs__["callback"] = callback + self.__command = None + + async def _invoke(self, ctx: BridgeApplicationContext) -> None: + if not (options := ctx.interaction.data.get("options")): + if not self.__command: + self.__command = BridgeSlashCommand(self.callback) + ctx.command = self.__command + return await ctx.command.invoke(ctx) + option = options[0] + resolved = ctx.interaction.data.get("resolved", None) + command = find(lambda x: x.name == option["name"], self.subcommands) + option["resolved"] = resolved + ctx.interaction.data = option + await command.invoke(ctx) + + +class BridgeExtGroup(BridgeExtCommand, Group): + """A subclass of :class:`.ext.commands.Group` that is used for bridge commands.""" + + +class BridgeCommand: + """Compatibility class between prefixed-based commands and slash commands. + + Parameters + ---------- + callback: Callable[[:class:`.BridgeContext`, ...], Awaitable[Any]] + The callback to invoke when the command is executed. The first argument will be a :class:`BridgeContext`, + and any additional arguments will be passed to the callback. This callback must be a coroutine. + parent: Optional[:class:`.BridgeCommandGroup`]: + Parent of the BridgeCommand. + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommand` and :class:`.ext.commands.Command`) + + Attributes + ---------- + slash_variant: :class:`.BridgeSlashCommand` + The slash command version of this bridge command. + ext_variant: :class:`.BridgeExtCommand` + The prefix-based version of this bridge command. + """ + + def __init__(self, callback, **kwargs): + self.parent = kwargs.pop("parent", None) + self.slash_variant: BridgeSlashCommand = kwargs.pop( + "slash_variant", None + ) or BridgeSlashCommand(callback, **kwargs) + self.ext_variant: BridgeExtCommand = kwargs.pop( + "ext_variant", None + ) or BridgeExtCommand(callback, **kwargs) + + @property + def name_localizations(self): + """Dict[:class:`str`, :class:`str`]: Returns name_localizations from :attr:`slash_variant` + + You can edit/set name_localizations directly with + + .. code-block:: python3 + + bridge_command.name_localizations["en-UK"] = ... # or any other locale + # or + bridge_command.name_localizations = {"en-UK": ..., "fr-FR": ...} + """ + return self.slash_variant.name_localizations + + @name_localizations.setter + def name_localizations(self, value): + self.slash_variant.name_localizations = value + + @property + def description_localizations(self): + """Dict[:class:`str`, :class:`str`]: Returns description_localizations from :attr:`slash_variant` + + You can edit/set description_localizations directly with + + .. code-block:: python3 + + bridge_command.description_localizations["en-UK"] = ... # or any other locale + # or + bridge_command.description_localizations = {"en-UK": ..., "fr-FR": ...} + """ + return self.slash_variant.description_localizations + + @description_localizations.setter + def description_localizations(self, value): + self.slash_variant.description_localizations = value + + def add_to(self, bot: ExtBot) -> None: + """Adds the command to a bot. This method is inherited by :class:`.BridgeCommandGroup`. + + Parameters + ---------- + bot: Union[:class:`.Bot`, :class:`.AutoShardedBot`] + The bot to add the command to. + """ + bot.add_application_command(self.slash_variant) + bot.add_command(self.ext_variant) + + async def invoke( + self, ctx: BridgeExtContext | BridgeApplicationContext, /, *args, **kwargs + ): + if ctx.is_app: + return await self.slash_variant.invoke(ctx) + return await self.ext_variant.invoke(ctx) + + def error(self, coro): + """A decorator that registers a coroutine as a local error handler. + + This error handler is limited to the command it is defined to. + However, higher scope handlers (per-cog and global) are still + invoked afterwards as a catch-all. This handler also functions as + the handler for both the prefixed and slash versions of the command. + + This error handler takes two parameters, a :class:`.BridgeContext` and + a :class:`~discord.DiscordException`. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + self.slash_variant.error(coro) + self.ext_variant.on_error = coro + + return coro + + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + + This hook is called directly before the command is called, making + it useful for any sort of set up required. This hook is called + for both the prefixed and slash versions of the command. + + This pre-invoke hook takes a sole parameter, a :class:`.BridgeContext`. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + self.slash_variant.before_invoke(coro) + self.ext_variant._before_invoke = coro + + return coro + + def after_invoke(self, coro): + """A decorator that registers a coroutine as a post-invoke hook. + + This hook is called directly after the command is called, making it + useful for any sort of clean up required. This hook is called for + both the prefixed and slash versions of the command. + + This post-invoke hook takes a sole parameter, a :class:`.BridgeContext`. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + self.slash_variant.after_invoke(coro) + self.ext_variant._after_invoke = coro + + return coro + + +class BridgeCommandGroup(BridgeCommand): + """Compatibility class between prefixed-based commands and slash commands. + + Parameters + ---------- + callback: Callable[[:class:`.BridgeContext`, ...], Awaitable[Any]] + The callback to invoke when the command is executed. The first argument will be a :class:`BridgeContext`, + and any additional arguments will be passed to the callback. This callback must be a coroutine. + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommand` and :class:`.ext.commands.Command`) + + Attributes + ---------- + slash_variant: :class:`.SlashCommandGroup` + The slash command version of this command group. + ext_variant: :class:`.ext.commands.Group` + The prefix-based version of this command group. + subcommands: List[:class:`.BridgeCommand`] + List of bridge commands in this group + mapped: Optional[:class:`.SlashCommand`] + If :func:`map_to` is used, the mapped slash command. + """ + + def __init__(self, callback, *args, **kwargs): + self.ext_variant: BridgeExtGroup = BridgeExtGroup(callback, *args, **kwargs) + name = kwargs.pop("name", self.ext_variant.name) + self.slash_variant: BridgeSlashGroup = BridgeSlashGroup( + callback, name, *args, **kwargs + ) + self.subcommands: list[BridgeCommand] = [] + + self.mapped: SlashCommand | None = None + if map_to := getattr(callback, "__custom_map_to__", None): + kwargs.update(map_to) + self.mapped = self.slash_variant.command(**kwargs)(callback) + + def command(self, *args, **kwargs): + """A decorator to register a function as a subcommand. + + Parameters + ---------- + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommand` and :class:`.ext.commands.Command`) + """ + + def wrap(callback): + slash = self.slash_variant.command( + *args, + **filter_params(kwargs, brief="description"), + cls=BridgeSlashCommand, + )(callback) + ext = self.ext_variant.command( + *args, + **filter_params(kwargs, description="brief"), + cls=BridgeExtCommand, + )(callback) + command = BridgeCommand( + callback, parent=self, slash_variant=slash, ext_variant=ext + ) + self.subcommands.append(command) + return command + + return wrap + + +def bridge_command(**kwargs): + """A decorator that is used to wrap a function as a bridge command. + + Parameters + ---------- + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors. (:class:`.SlashCommand` and :class:`.ext.commands.Command`) + """ + + def decorator(callback): + return BridgeCommand(callback, **kwargs) + + return decorator + + +def bridge_group(**kwargs): + """A decorator that is used to wrap a function as a bridge command group. + + Parameters + ---------- + kwargs: Optional[Dict[:class:`str`, Any]] + Keyword arguments that are directly passed to the respective command constructors (:class:`.SlashCommandGroup` and :class:`.ext.commands.Group`). + """ + + def decorator(callback): + return BridgeCommandGroup(callback, **kwargs) + + return decorator + + +def map_to(name, description=None): + """To be used with bridge command groups, map the main command to a slash subcommand. + + Parameters + ---------- + name: :class:`str` + The new name of the mapped command. + description: Optional[:class:`str`] + The new description of the mapped command. + + Example + ------- + + .. code-block:: python3 + + @bot.bridge_group() + @bridge.map_to("show") + async def config(ctx: BridgeContext): + ... + + @config.command() + async def toggle(ctx: BridgeContext): + ... + + Prefixed commands will not be affected, but slash commands will appear as: + + .. code-block:: + + /config show + /config toggle + """ + + def decorator(callback): + callback.__custom_map_to__ = {"name": name, "description": description} + return callback + + return decorator + + +def guild_only(): + """Intended to work with :class:`.ApplicationCommand` and :class:`BridgeCommand`, adds a :func:`~ext.commands.check` + that locks the command to only run in guilds, and also registers the command as guild only client-side (on discord). + + Basically a utility function that wraps both :func:`discord.ext.commands.guild_only` and :func:`discord.commands.guild_only`. + """ + + def predicate(func: Callable | ApplicationCommand): + if isinstance(func, ApplicationCommand): + func.guild_only = True + else: + func.__guild_only__ = True + + from ..commands import guild_only + + return guild_only()(func) + + return predicate + + +def has_permissions(**perms: dict[str, bool]): + r"""Intended to work with :class:`.SlashCommand` and :class:`BridgeCommand`, adds a + :func:`~ext.commands.check` that locks the command to be run by people with certain + permissions inside guilds, and also registers the command as locked behind said permissions. + + Basically a utility function that wraps both :func:`discord.ext.commands.has_permissions` + and :func:`discord.commands.default_permissions`. + + Parameters + ---------- + \*\*perms: Dict[:class:`str`, :class:`bool`] + An argument list of permissions to check for. + """ + + def predicate(func: Callable | ApplicationCommand): + from ..commands import has_permissions + + func = has_permissions(**perms)(func) + Permissions(**perms) + if isinstance(func, ApplicationCommand): + func.default_member_permissions = perms + else: + func.__default_member_permissions__ = perms + + return perms + + return predicate + + +class MentionableConverter(Converter): + """A converter that can convert a mention to a user or a role.""" + + async def convert(self, ctx, argument): + try: + return await RoleConverter().convert(ctx, argument) + except BadArgument: + return await UserConverter().convert(ctx, argument) + + +class AttachmentConverter(Converter): + async def convert(self, ctx: Context, arg: str): + try: + attach = ctx.message.attachments[0] + except IndexError: + raise BadArgument("At least 1 attachment is needed") + else: + return attach + + +BRIDGE_CONVERTER_MAPPING = { + SlashCommandOptionType.string: str, + SlashCommandOptionType.integer: int, + SlashCommandOptionType.boolean: lambda val: _convert_to_bool(str(val)), + SlashCommandOptionType.user: UserConverter, + SlashCommandOptionType.channel: GuildChannelConverter, + SlashCommandOptionType.role: RoleConverter, + SlashCommandOptionType.mentionable: MentionableConverter, + SlashCommandOptionType.number: float, + SlashCommandOptionType.attachment: AttachmentConverter, +} + + +class BridgeOption(Option, Converter): + async def convert(self, ctx, argument: str) -> Any: + try: + if self.converter is not None: + converted = await self.converter.convert(ctx, argument) + else: + converter = BRIDGE_CONVERTER_MAPPING[self.input_type] + if issubclass(converter, Converter): + converted = await converter().convert(ctx, argument) # type: ignore # protocol class + else: + converted = converter(argument) + + if self.choices: + choices_names: list[str | int | float] = [ + choice.name for choice in self.choices + ] + if converted in choices_names and ( + choice := get(self.choices, name=converted) + ): + converted = choice.value + else: + choices = [choice.value for choice in self.choices] + if converted not in choices: + raise ValueError( + f"{argument} is not a valid choice. Valid choices: {list(set(choices_names + choices))}" + ) + + return converted + except ValueError as exc: + raise BadArgument() from exc + + +discord.commands.options.Option = BridgeOption diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py new file mode 100644 index 0000000..b13b484 --- /dev/null +++ b/discord/ext/commands/__init__.py @@ -0,0 +1,19 @@ +""" +discord.ext.commands +~~~~~~~~~~~~~~~~~~~~~ + +An extension module to facilitate creation of bot commands. + +:copyright: (c) 2015-2021 Rapptz & (c) 2021-present Pycord Development +:license: MIT, see LICENSE for more details. +""" + +from .bot import * +from .cog import * +from .context import * +from .converter import * +from .cooldowns import * +from .core import * +from .errors import * +from .flags import * +from .help import * diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py new file mode 100644 index 0000000..7f86ac6 --- /dev/null +++ b/discord/ext/commands/_types.py @@ -0,0 +1,50 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, Union + +if TYPE_CHECKING: + from .cog import Cog + from .context import Context + from .errors import CommandError + +T = TypeVar("T") + +Coro = Coroutine[Any, Any, T] +MaybeCoro = Union[T, Coro[T]] +CoroFunc = Callable[..., Coro[Any]] + +Check = Union[ + Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], + Callable[["Context[Any]"], MaybeCoro[bool]], +] +Hook = Union[ + Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]] +] +Error = Union[ + Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], + Callable[["Context[Any]", "CommandError"], Coro[Any]], +] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py new file mode 100644 index 0000000..5c3cb24 --- /dev/null +++ b/discord/ext/commands/bot.py @@ -0,0 +1,454 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import collections +import collections.abc +import sys +import traceback +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +import discord + +from . import errors +from .context import Context +from .core import GroupMixin +from .help import DefaultHelpCommand, HelpCommand +from .view import StringView + +if TYPE_CHECKING: + from discord.message import Message + + from ._types import CoroFunc + +__all__ = ( + "when_mentioned", + "when_mentioned_or", + "Bot", + "AutoShardedBot", +) + +MISSING: Any = discord.utils.MISSING + +T = TypeVar("T") +CFT = TypeVar("CFT", bound="CoroFunc") +CXT = TypeVar("CXT", bound="Context") + + +def when_mentioned(bot: Bot | AutoShardedBot, msg: Message) -> list[str]: + """A callable that implements a command prefix equivalent to being mentioned. + + These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. + """ + # bot.user will never be None when this is called + return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore + + +def when_mentioned_or( + *prefixes: str, +) -> Callable[[Bot | AutoShardedBot, Message], list[str]]: + """A callable that implements when mentioned or other prefixes provided. + + These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. + + See Also + -------- + :func:`.when_mentioned` + + Example + ------- + + .. code-block:: python3 + + bot = commands.Bot(command_prefix=commands.when_mentioned_or('!')) + + .. note:: + + This callable returns another callable, so if this is done inside a custom + callable, you must call the returned callable, for example: + + .. code-block:: python3 + + async def get_prefix(bot, message): + extras = await prefixes_for(message.guild) # returns a list + return commands.when_mentioned_or(*extras)(bot, message) + """ + + def inner(bot, msg): + r = list(prefixes) + r = when_mentioned(bot, msg) + r + return r + + return inner + + +def _is_submodule(parent: str, child: str) -> bool: + return parent == child or child.startswith(f"{parent}.") + + +class _DefaultRepr: + def __repr__(self): + return "" + + +_default = _DefaultRepr() + + +class BotBase(GroupMixin, discord.cog.CogMixin): + _supports_prefixed_commands = True + + def __init__(self, command_prefix=when_mentioned, help_command=_default, **options): + super().__init__(**options) + self.command_prefix = command_prefix + self._help_command = None + self.strip_after_prefix = options.get("strip_after_prefix", False) + + if help_command is _default: + self.help_command = DefaultHelpCommand() + else: + self.help_command = help_command + + @discord.utils.copy_doc(discord.Client.close) + async def close(self) -> None: + for extension in tuple(self.__extensions): + try: + self.unload_extension(extension) + except Exception: + pass + + for cog in tuple(self.__cogs): + try: + self.remove_cog(cog) + except Exception: + pass + + await super().close() # type: ignore + + async def on_command_error( + self, context: Context, exception: errors.CommandError + ) -> None: + """|coro| + + The default command error handler provided by the bot. + + By default, this prints to :data:`sys.stderr` however it could be + overridden to have a different implementation. + + This only fires if you do not specify any listeners for command error. + """ + if self.extra_events.get("on_command_error", None): + return + + command = context.command + if command and command.has_error_handler(): + return + + cog = context.cog + if cog and cog.has_error_handler(): + return + + print(f"Ignoring exception in command {context.command}:", file=sys.stderr) + traceback.print_exception( + type(exception), exception, exception.__traceback__, file=sys.stderr + ) + + async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: + data = self._check_once if call_once else self._checks + + if len(data) == 0: + return True + + # type-checker doesn't distinguish between functions and methods + return await discord.utils.async_all(f(ctx) for f in data) # type: ignore + + # help command stuff + + @property + def help_command(self) -> HelpCommand | None: + return self._help_command + + @help_command.setter + def help_command(self, value: HelpCommand | None) -> None: + if value is not None: + if not isinstance(value, HelpCommand): + raise TypeError("help_command must be a subclass of HelpCommand") + if self._help_command is not None: + self._help_command._remove_from_bot(self) + self._help_command = value + value._add_to_bot(self) + elif self._help_command is not None: + self._help_command._remove_from_bot(self) + self._help_command = None + else: + self._help_command = None + + # command processing + + async def get_prefix(self, message: Message) -> list[str] | str: + """|coro| + + Retrieves the prefix the bot is listening to + with the message as a context. + + Parameters + ---------- + message: :class:`discord.Message` + The message context to get the prefix of. + + Returns + ------- + Union[List[:class:`str`], :class:`str`] + A list of prefixes or a single prefix that the bot is + listening for. + """ + prefix = ret = self.command_prefix + if callable(prefix): + ret = await discord.utils.maybe_coroutine(prefix, self, message) + + if not isinstance(ret, str): + try: + ret = list(ret) + except TypeError: + # It's possible that a generator raised this exception. Don't + # replace it with our own error if that's the case. + if isinstance(ret, collections.abc.Iterable): + raise + + raise TypeError( + "command_prefix must be plain string, iterable of strings, or callable " + f"returning either of these, not {ret.__class__.__name__}" + ) + + if not ret: + raise ValueError( + "Iterable command_prefix must contain at least one prefix" + ) + + return ret + + async def get_context(self, message: Message, *, cls: type[CXT] = Context) -> CXT: + r"""|coro| + + Returns the invocation context from the message. + + This is a more low-level counter-part for :meth:`.process_commands` + to allow users more fine-grained control over the processing. + + The returned context is not guaranteed to be a valid invocation + context, :attr:`.Context.valid` must be checked to make sure it is. + If the context is not valid then it is not a valid candidate to be + invoked under :meth:`~.Bot.invoke`. + + Parameters + ----------- + message: :class:`discord.Message` + The message to get the invocation context from. + cls + The factory class that will be used to create the context. + By default, this is :class:`.Context`. Should a custom + class be provided, it must be similar enough to :class:`.Context`\'s + interface. + + Returns + -------- + :class:`.Context` + The invocation context. The type of this can change via the + ``cls`` parameter. + """ + + view = StringView(message.content) + ctx = cls(prefix=None, view=view, bot=self, message=message) + + if message.author.id == self.user.id: # type: ignore + return ctx + + prefix = await self.get_prefix(message) + invoked_prefix = prefix + + if isinstance(prefix, str): + if not view.skip_string(prefix): + return ctx + else: + try: + # if the context class' __init__ consumes something from the view this + # will be wrong. That seems unreasonable though. + if message.content.startswith(tuple(prefix)): + invoked_prefix = discord.utils.find(view.skip_string, prefix) + else: + return ctx + + except TypeError: + if not isinstance(prefix, list): + raise TypeError( + "get_prefix must return either a string or a list of string, " + f"not {prefix.__class__.__name__}" + ) + + # It's possible a bad command_prefix got us here. + for value in prefix: + if not isinstance(value, str): + raise TypeError( + "Iterable command_prefix or list returned from get_prefix must " + f"contain only strings, not {value.__class__.__name__}" + ) + + # Getting here shouldn't happen + raise + + if self.strip_after_prefix: + view.skip_ws() + + invoker = view.get_word() + ctx.invoked_with = invoker + # type-checker fails to narrow invoked_prefix type. + ctx.prefix = invoked_prefix # type: ignore + ctx.command = self.prefixed_commands.get(invoker) + return ctx + + async def invoke(self, ctx: Context) -> None: + """|coro| + + Invokes the command given under the invocation context and + handles all the internal event dispatch mechanisms. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context to invoke. + """ + if ctx.command is not None: + self.dispatch("command", ctx) + try: + if await self.can_run(ctx, call_once=True): + await ctx.command.invoke(ctx) + else: + raise errors.CheckFailure("The global check once functions failed.") + except errors.CommandError as exc: + await ctx.command.dispatch_error(ctx, exc) + else: + self.dispatch("command_completion", ctx) + elif ctx.invoked_with: + exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') + self.dispatch("command_error", ctx, exc) + + async def process_commands(self, message: Message) -> None: + """|coro| + + This function processes the commands that have been registered + to the bot and other groups. Without this coroutine, none of the + commands will be triggered. + + By default, this coroutine is called inside the :func:`.on_message` + event. If you choose to override the :func:`.on_message` event, then + you should invoke this coroutine as well. + + This is built using other low level tools, and is equivalent to a + call to :meth:`~.Bot.get_context` followed by a call to :meth:`~.Bot.invoke`. + + This also checks if the message's author is a bot and doesn't + call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so. + + Parameters + ---------- + message: :class:`discord.Message` + The message to process commands for. + """ + if message.author.bot: + return + + ctx = await self.get_context(message) + await self.invoke(ctx) + + async def on_message(self, message): + await self.process_commands(message) + + +class Bot(BotBase, discord.Bot): + """Represents a discord bot. + + This class is a subclass of :class:`discord.Bot` and as a result + anything that you can do with a :class:`discord.Bot` you can do with + this bot. + + This class also subclasses :class:`.GroupMixin` to provide the functionality + to manage commands. + + .. note:: + + Using prefixed commands requires :attr:`discord.Intents.message_content` to be enabled. + + Attributes + ---------- + command_prefix + The command prefix is what the message content must contain initially + to have a command invoked. This prefix could either be a string to + indicate what the prefix should be, or a callable that takes in the bot + as its first parameter and :class:`discord.Message` as its second + parameter and returns the prefix. This is to facilitate "dynamic" + command prefixes. This callable can be either a regular function or + a coroutine. + + An empty string as the prefix always matches, enabling prefix-less + command invocation. While this may be useful in DMs it should be avoided + in servers, as it's likely to cause performance issues and unintended + command invocations. + + The command prefix could also be an iterable of strings indicating that + multiple checks for the prefix should be used and the first one to + match will be the invocation prefix. You can get this prefix via + :attr:`.Context.prefix`. To avoid confusion empty iterables are not + allowed. + + .. note:: + + When passing multiple prefixes be careful to not pass a prefix + that matches a longer prefix occurring later in the sequence. For + example, if the command prefix is ``('!', '!?')`` the ``'!?'`` + prefix will never be matched to any message as the previous one + matches messages starting with ``!?``. This is especially important + when passing an empty string, it should always be last as no prefix + after it will be matched. + case_insensitive: :class:`bool` + Whether the commands should be case-insensitive. Defaults to ``False``. This + attribute does not carry over to groups. You must set it to every group if + you require group commands to be case-insensitive as well. + help_command: Optional[:class:`.HelpCommand`] + The help command implementation to use. This can be dynamically + set at runtime. To remove the help command pass ``None``. For more + information on implementing a help command, see :ref:`ext_commands_help_command`. + strip_after_prefix: :class:`bool` + Whether to strip whitespace characters after encountering the command + prefix. This allows for ``! hello`` and ``!hello`` to both work if + the ``command_prefix`` is set to ``!``. Defaults to ``False``. + + .. versionadded:: 1.7 + """ + + +class AutoShardedBot(BotBase, discord.AutoShardedBot): + """This is similar to :class:`.Bot` except that it is inherited from + :class:`discord.AutoShardedBot` instead. + """ diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py new file mode 100644 index 0000000..aa61fb2 --- /dev/null +++ b/discord/ext/commands/cog.py @@ -0,0 +1,84 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar + +import discord + +from ...cog import Cog +from ...commands import ApplicationCommand, SlashCommandGroup + +if TYPE_CHECKING: + from .core import Command + +__all__ = ("Cog",) + +CogT = TypeVar("CogT", bound="Cog") +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + +MISSING: Any = discord.utils.MISSING + + +class Cog(Cog): + def __new__(cls: type[CogT], *args: Any, **kwargs: Any) -> CogT: + # For issue 426, we need to store a copy of the command objects + # since we modify them to inject `self` to them. + # To do this, we need to interfere with the Cog creation process. + return super().__new__(cls) + + def walk_commands(self) -> Generator[Command, None, None]: + """An iterator that recursively walks through this cog's commands and subcommands. + + Yields + ------ + Union[:class:`.Command`, :class:`.Group`] + A command or group from the cog. + """ + from .core import GroupMixin + + for command in self.__cog_commands__: + if not isinstance(command, ApplicationCommand): + if command.parent is None: + yield command + if isinstance(command, GroupMixin): + yield from command.walk_commands() + elif isinstance(command, SlashCommandGroup): + yield from command.walk_commands() + else: + yield command + + def get_commands(self) -> list[ApplicationCommand | Command]: + r""" + Returns + -------- + List[Union[:class:`~discord.ApplicationCommand`, :class:`.Command`]] + A :class:`list` of commands that are defined inside this cog. + + .. note:: + + This does not include subcommands. + """ + return [c for c in self.__cog_commands__ if c.parent is None] diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py new file mode 100644 index 0000000..f6a9669 --- /dev/null +++ b/discord/ext/commands/context.py @@ -0,0 +1,404 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import inspect +import re +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union + +import discord.abc +import discord.utils +from discord.message import Message + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord.abc import MessageableChannel + from discord.guild import Guild + from discord.member import Member + from discord.state import ConnectionState + from discord.user import ClientUser, User + from discord.voice_client import VoiceProtocol + + from .bot import AutoShardedBot, Bot + from .cog import Cog + from .core import Command + from .view import StringView + +__all__ = ("Context",) + +MISSING: Any = discord.utils.MISSING + + +T = TypeVar("T") +BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar("CogT", bound="Cog") + +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") + + +class Context(discord.abc.Messageable, Generic[BotT]): + r"""Represents the context in which a command is being invoked under. + + This class contains a lot of metadata to help you understand more about + the invocation context. This class is not created manually and is instead + passed around to commands as the first parameter. + + This class implements the :class:`~discord.abc.Messageable` ABC. + + Attributes + ----------- + message: :class:`.Message` + The message that triggered the command being executed. + bot: :class:`.Bot` + The bot that contains the command being executed. + args: :class:`list` + The list of transformed arguments that were passed into the command. + If this is accessed during the :func:`.on_command_error` event + then this list could be incomplete. + kwargs: :class:`dict` + A dictionary of transformed arguments that were passed into the command. + Similar to :attr:`args`\, if this is accessed in the + :func:`.on_command_error` event then this dict could be incomplete. + current_parameter: Optional[:class:`inspect.Parameter`] + The parameter that is currently being inspected and converted. + This is only of use for within converters. + + .. versionadded:: 2.0 + prefix: Optional[:class:`str`] + The prefix that was used to invoke the command. + command: Optional[:class:`Command`] + The command that is being invoked currently. + invoked_with: Optional[:class:`str`] + The command name that triggered this invocation. Useful for finding out + which alias called the command. + invoked_parents: List[:class:`str`] + The command names of the parents that triggered this invocation. Useful for + finding out which aliases called the command. + + For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``. + + .. versionadded:: 1.7 + + invoked_subcommand: Optional[:class:`Command`] + The subcommand that was invoked. + If no valid subcommand was invoked then this is equal to ``None``. + subcommand_passed: Optional[:class:`str`] + The string that was attempted to call a subcommand. This does not have + to point to a valid registered subcommand and could just point to a + nonsense string. If nothing was passed to attempt a call to a + subcommand then this is set to ``None``. + command_failed: :class:`bool` + A boolean that indicates if the command failed to be parsed, checked, + or invoked. + """ + + def __init__( + self, + *, + message: Message, + bot: BotT, + view: StringView, + args: list[Any] = MISSING, + kwargs: dict[str, Any] = MISSING, + prefix: str | None = None, + command: Command | None = None, + invoked_with: str | None = None, + invoked_parents: list[str] = MISSING, + invoked_subcommand: Command | None = None, + subcommand_passed: str | None = None, + command_failed: bool = False, + current_parameter: inspect.Parameter | None = None, + ): + self.message: Message = message + self.bot: BotT = bot + self.args: list[Any] = args or [] + self.kwargs: dict[str, Any] = kwargs or {} + self.prefix: str | None = prefix + self.command: Command | None = command + self.view: StringView = view + self.invoked_with: str | None = invoked_with + self.invoked_parents: list[str] = invoked_parents or [] + self.invoked_subcommand: Command | None = invoked_subcommand + self.subcommand_passed: str | None = subcommand_passed + self.command_failed: bool = command_failed + self.current_parameter: inspect.Parameter | None = current_parameter + self._state: ConnectionState = self.message._state + + async def invoke( + self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs + ) -> T: + r"""|coro| + + Calls a command with the arguments given. + + This is useful if you want to just call the callback that a + :class:`.Command` holds internally. + + .. note:: + + This does not handle converters, checks, cooldowns, pre-invoke, + or after-invoke hooks in any matter. It calls the internal callback + directly as-if it was a regular function. + + You must take care in passing the proper arguments when + using this function. + + Parameters + ----------- + command: :class:`.Command` + The command that is going to be called. + \*args + The arguments to use. + \*\*kwargs + The keyword arguments to use. + + Raises + ------- + TypeError + The command argument to invoke is missing. + """ + return await command(self, *args, **kwargs) + + async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: + """|coro| + + Calls the command again. + + This is similar to :meth:`~.Context.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + .. note:: + + If you want to bypass :exc:`.UserInputError` derived exceptions, + it is recommended to use the regular :meth:`~.Context.invoke` + as it will work more naturally. After all, this will end up + using the old arguments the user has used and will thus just + fail again. + + Parameters + ---------- + call_hooks: :class:`bool` + Whether to call the before and after invoke hooks. + restart: :class:`bool` + Whether to start the call chain from the very beginning + or where we left off (i.e. the command that caused the error). + The default is to start where we left off. + + Raises + ------ + ValueError + The context to reinvoke is not valid. + """ + cmd = self.command + view = self.view + if cmd is None: + raise ValueError("This context is not valid.") + + # some state to revert to when we're done + index, previous = view.index, view.previous + invoked_with = self.invoked_with + invoked_subcommand = self.invoked_subcommand + invoked_parents = self.invoked_parents + subcommand_passed = self.subcommand_passed + + if restart: + to_call = cmd.root_parent or cmd + view.index = len(self.prefix or "") + view.previous = 0 + self.invoked_parents = [] + self.invoked_with = view.get_word() # advance to get the root command + else: + to_call = cmd + + try: + await to_call.reinvoke(self, call_hooks=call_hooks) + finally: + self.command = cmd + view.index = index + view.previous = previous + self.invoked_with = invoked_with + self.invoked_subcommand = invoked_subcommand + self.invoked_parents = invoked_parents + self.subcommand_passed = subcommand_passed + + @property + def valid(self) -> bool: + """:class:`bool`: Checks if the invocation context is valid to be invoked with.""" + return self.prefix is not None and self.command is not None + + async def _get_channel(self) -> discord.abc.Messageable: + return self.channel + + @property + def clean_prefix(self) -> str: + """:class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``. + + .. versionadded:: 2.0 + """ + if self.prefix is None: + return "" + + user = self.me + # this breaks if the prefix mention is not the bot itself, but I + # consider this to be an *incredibly* strange use case. I'd rather go + # for this common use case rather than waste performance for the + # odd one. + pattern = re.compile(r"<@!?%s>" % user.id) + return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix) + + @property + def cog(self) -> Cog | None: + """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. + None if it does not exist. + """ + + if self.command is None: + return None + return self.command.cog + + @discord.utils.cached_property + def guild(self) -> Guild | None: + """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. + None if not available. + """ + return self.message.guild + + @discord.utils.cached_property + def channel(self) -> MessageableChannel: + """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. + Shorthand for :attr:`.Message.channel`. + """ + return self.message.channel + + @discord.utils.cached_property + def author(self) -> User | Member: + """Union[:class:`~discord.User`, :class:`.Member`]: + Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` + """ + return self.message.author + + @discord.utils.cached_property + def me(self) -> Member | ClientUser: + """Union[:class:`.Member`, :class:`.ClientUser`]: + Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message + message contexts, or when :meth:`Intents.guilds` is absent. + """ + # bot.user will never be None at this point. + return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore + + @property + def voice_client(self) -> VoiceProtocol | None: + r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + g = self.guild + return g.voice_client if g else None + + async def send_help(self, *args: Any) -> Any: + """send_help(entity=) + + |coro| + + Shows the help command for the specified entity if given. + The entity can be a command or a cog. + + If no entity is given, then it'll show help for the + entire bot. + + If the entity is a string, then it looks up whether it's a + :class:`Cog` or a :class:`Command`. + + .. note:: + + Due to the way this function works, instead of returning + something similar to :meth:`~.commands.HelpCommand.command_not_found` + this returns :class:`None` on bad input or no help command. + + Parameters + ---------- + entity: Optional[Union[:class:`Command`, :class:`Cog`, :class:`str`]] + The entity to show help for. + + Returns + ------- + Any + The result of the help command, if any. + """ + from .core import Command, Group, wrap_callback + from .errors import CommandError + + bot = self.bot + cmd = bot.help_command + + if cmd is None: + return None + + cmd = cmd.copy() + cmd.context = self + if len(args) == 0: + await cmd.prepare_help_command(self, None) + mapping = cmd.get_bot_mapping() + injected = wrap_callback(cmd.send_bot_help) + try: + return await injected(mapping) + except CommandError as e: + await cmd.on_help_command_error(self, e) + return None + + entity = args[0] + if isinstance(entity, str): + entity = bot.get_cog(entity) or bot.get_command(entity) + + if entity is None: + return None + + try: + entity.qualified_name + except AttributeError: + # if we're here then it's not a cog, group, or command. + return None + + await cmd.prepare_help_command(self, entity.qualified_name) + + try: + if hasattr(entity, "__cog_commands__"): + injected = wrap_callback(cmd.send_cog_help) + return await injected(entity) + elif isinstance(entity, Group): + injected = wrap_callback(cmd.send_group_help) + return await injected(entity) + elif isinstance(entity, Command): + injected = wrap_callback(cmd.send_command_help) + return await injected(entity) + else: + return None + except CommandError as e: + await cmd.on_help_command_error(self, e) + + @discord.utils.copy_doc(Message.reply) + async def reply(self, content: str | None = None, **kwargs: Any) -> Message: + return await self.message.reply(content, **kwargs) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py new file mode 100644 index 0000000..ee131f0 --- /dev/null +++ b/discord/ext/commands/converter.py @@ -0,0 +1,1255 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import inspect +import re +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + List, + Literal, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + +import discord + +from .errors import * + +if TYPE_CHECKING: + from discord.message import PartialMessageableChannel + + from .context import Context + + +__all__ = ( + "Converter", + "ObjectConverter", + "MemberConverter", + "UserConverter", + "MessageConverter", + "PartialMessageConverter", + "TextChannelConverter", + "ForumChannelConverter", + "InviteConverter", + "GuildConverter", + "RoleConverter", + "GameConverter", + "ColourConverter", + "ColorConverter", + "VoiceChannelConverter", + "StageChannelConverter", + "EmojiConverter", + "PartialEmojiConverter", + "CategoryChannelConverter", + "IDConverter", + "ThreadConverter", + "GuildChannelConverter", + "GuildStickerConverter", + "clean_content", + "Greedy", + "run_converters", +) + + +def _get_from_guilds(bot, getter, argument): + result = None + for guild in bot.guilds: + result = getattr(guild, getter)(argument) + if result: + return result + return result + + +_utils_get = discord.utils.get +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +CT = TypeVar("CT", bound=discord.abc.GuildChannel) +TT = TypeVar("TT", bound=discord.Thread) + + +@runtime_checkable +class Converter(Protocol[T_co]): + """The base class of custom converters that require the :class:`.Context` + to be passed to be useful. + + This allows you to implement converters that function similar to the + special cased ``discord`` classes. + + Classes that derive from this should override the :meth:`~.Converter.convert` + method to do its conversion logic. This method must be a :ref:`coroutine `. + """ + + async def convert(self, ctx: Context, argument: str) -> T_co: + """|coro| + + The method to override to do conversion logic. + + If an error is found while converting, it is recommended to + raise a :exc:`.CommandError` derived exception as it will + properly propagate to the error handlers. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context that the argument is being used in. + argument: :class:`str` + The argument that is being converted. + + Raises + ------ + :exc:`.CommandError` + A generic exception occurred when converting the argument. + :exc:`.BadArgument` + The converter failed to convert the argument. + """ + raise NotImplementedError("Derived classes need to implement this.") + + +_ID_REGEX = re.compile(r"([0-9]{15,20})$") + + +class IDConverter(Converter[T_co]): + @staticmethod + def _get_id_match(argument): + return _ID_REGEX.match(argument) + + +class ObjectConverter(IDConverter[discord.Object]): + """Converts to a :class:`~discord.Object`. + + The argument must follow the valid ID or mention formats (e.g. `<@80088516616269824>`). + + .. versionadded:: 2.0 + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by member, role, or channel mention. + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Object: + match = self._get_id_match(argument) or re.match( + r"<(?:@[!&]?|#)([0-9]{15,20})>$", argument + ) + + if match is None: + raise ObjectNotFound(argument) + + result = int(match.group(1)) + + return discord.Object(id=result) + + +class MemberConverter(IDConverter[discord.Member]): + """Converts to a :class:`~discord.Member`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name#discrim + 4. Lookup by name + 5. Lookup by nickname + + .. versionchanged:: 1.5 + Raise :exc:`.MemberNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.5.1 + This converter now lazily fetches members from the gateway and HTTP APIs, + optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled. + """ + + async def query_member_named(self, guild, argument): + cache = guild._state.member_cache_flags.joined + if len(argument) > 5 and argument[-5] == "#": + username, _, discriminator = argument.rpartition("#") + members = await guild.query_members(username, limit=100, cache=cache) + return discord.utils.get( + members, name=username, discriminator=discriminator + ) + else: + members = await guild.query_members(argument, limit=100, cache=cache) + return discord.utils.find( + lambda m: m.name == argument or m.nick == argument, members + ) + + async def query_member_by_id(self, bot, guild, user_id): + ws = bot._get_websocket(shard_id=guild.shard_id) + cache = guild._state.member_cache_flags.joined + if ws.is_ratelimited(): + # If we're being rate limited on the WS, then fall back to using the HTTP API + # So we don't have to wait ~60 seconds for the query to finish + try: + member = await guild.fetch_member(user_id) + except discord.HTTPException: + return None + + if cache: + guild._add_member(member) + return member + + # If we're not being rate limited then we can use the websocket to actually query + members = await guild.query_members(limit=1, user_ids=[user_id], cache=cache) + if not members: + return None + return members[0] + + async def convert(self, ctx: Context, argument: str) -> discord.Member: + bot = ctx.bot + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) + guild = ctx.guild + result = None + user_id = None + if match is None: + # not a mention... + if guild: + result = guild.get_member_named(argument) + else: + result = _get_from_guilds(bot, "get_member_named", argument) + else: + user_id = int(match.group(1)) + if guild: + result = guild.get_member(user_id) + if ctx.message is not None and result is None: + result = _utils_get(ctx.message.mentions, id=user_id) + else: + result = _get_from_guilds(bot, "get_member", user_id) + + if result is None: + if guild is None: + raise MemberNotFound(argument) + + if user_id is not None: + result = await self.query_member_by_id(bot, guild, user_id) + else: + result = await self.query_member_named(guild, argument) + + if not result: + raise MemberNotFound(argument) + + return result + + +class UserConverter(IDConverter[discord.User]): + """Converts to a :class:`~discord.User`. + + All lookups are via the global user cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name#discrim + 4. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.UserNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.6 + This converter now lazily fetches users from the HTTP APIs if an ID is + passed, and it's not available in cache. + """ + + async def convert(self, ctx: Context, argument: str) -> discord.User: + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) + result = None + state = ctx._state + + if match is not None: + user_id = int(match.group(1)) + result = ctx.bot.get_user(user_id) + if ctx.message is not None and result is None: + result = _utils_get(ctx.message.mentions, id=user_id) + if result is None: + try: + result = await ctx.bot.fetch_user(user_id) + except discord.HTTPException: + raise UserNotFound(argument) from None + + return result + + arg = argument + + # Remove the '@' character if this is the first character from the argument + if arg[0] == "@": + # Remove first character + arg = arg[1:] + + # check for discriminator if it exists, + if len(arg) > 5 and arg[-5] == "#": + discrim = arg[-4:] + name = arg[:-5] + predicate = lambda u: u.name == name and u.discriminator == discrim + result = discord.utils.find(predicate, state._users.values()) + if result is not None: + return result + + predicate = lambda u: u.name == arg + result = discord.utils.find(predicate, state._users.values()) + + if result is None: + raise UserNotFound(argument) + + return result + + +class PartialMessageConverter(Converter[discord.PartialMessage]): + """Converts to a :class:`discord.PartialMessage`. + + .. versionadded:: 1.7 + + The creation strategy is as follows (in order): + + 1. By "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID") + 2. By message ID (The message is assumed to be in the context channel.) + 3. By message URL + """ + + @staticmethod + def _get_id_matches(ctx, argument): + id_regex = re.compile( + r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" + ) + link_regex = re.compile( + r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" + r"(?P[0-9]{15,20}|@me)" + r"/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$" + ) + match = id_regex.match(argument) or link_regex.match(argument) + if not match: + raise MessageNotFound(argument) + data = match.groupdict() + channel_id = data.get("channel_id") + if channel_id is None: + channel_id = ctx.channel and ctx.channel.id + else: + channel_id = int(channel_id) + message_id = int(data["message_id"]) + guild_id = data.get("guild_id") + if guild_id is None: + guild_id = ctx.guild and ctx.guild.id + elif guild_id == "@me": + guild_id = None + else: + guild_id = int(guild_id) + return guild_id, message_id, channel_id + + @staticmethod + def _resolve_channel(ctx, guild_id, channel_id) -> PartialMessageableChannel | None: + if guild_id is not None: + guild = ctx.bot.get_guild(guild_id) + if guild is not None and channel_id is not None: + return guild._resolve_channel(channel_id) # type: ignore + else: + return None + else: + return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel + + async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: + guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) + channel = self._resolve_channel(ctx, guild_id, channel_id) + if not channel: + raise ChannelNotFound(channel_id) + return discord.PartialMessage(channel=channel, id=message_id) + + +class MessageConverter(IDConverter[discord.Message]): + """Converts to a :class:`discord.Message`. + + .. versionadded:: 1.1 + + The lookup strategy is as follows (in order): + + 1. Lookup by "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID") + 2. Lookup by message ID (the message **must** be in the context channel) + 3. Lookup by message URL + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` + instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Message: + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( + ctx, argument + ) + message = ctx.bot._connection._get_message(message_id) + if message: + return message + channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id) + if not channel: + raise ChannelNotFound(channel_id) + try: + return await channel.fetch_message(message_id) + except discord.NotFound: + raise MessageNotFound(argument) + except discord.Forbidden: + raise ChannelNotReadable(channel) + + +class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): + """Converts to a :class:`~discord.abc.GuildChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name. + + .. versionadded:: 2.0 + """ + + async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: + return self._resolve_channel( + ctx, argument, "channels", discord.abc.GuildChannel + ) + + @staticmethod + def _resolve_channel( + ctx: Context, argument: str, attribute: str, type: type[CT] + ) -> CT: + bot = ctx.bot + + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + iterable: Iterable[CT] = getattr(guild, attribute) + result: CT | None = discord.utils.get(iterable, name=argument) + else: + + def check(c): + return isinstance(c, type) and c.name == argument + + result = discord.utils.find(check, bot.get_all_channels()) + else: + channel_id = int(match.group(1)) + if guild: + result = guild.get_channel(channel_id) + else: + result = _get_from_guilds(bot, "get_channel", channel_id) + + if not isinstance(result, type): + raise ChannelNotFound(argument) + + return result + + @staticmethod + def _resolve_thread( + ctx: Context, argument: str, attribute: str, type: type[TT] + ) -> TT: + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + iterable: Iterable[TT] = getattr(guild, attribute) + result: TT | None = discord.utils.get(iterable, name=argument) + else: + thread_id = int(match.group(1)) + if guild: + result = guild.get_thread(thread_id) + + if not result or not isinstance(result, type): + raise ThreadNotFound(argument) + + return result + + +class TextChannelConverter(IDConverter[discord.TextChannel]): + """Converts to a :class:`~discord.TextChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: + return GuildChannelConverter._resolve_channel( + ctx, argument, "text_channels", discord.TextChannel + ) + + +class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): + """Converts to a :class:`~discord.VoiceChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: + return GuildChannelConverter._resolve_channel( + ctx, argument, "voice_channels", discord.VoiceChannel + ) + + +class StageChannelConverter(IDConverter[discord.StageChannel]): + """Converts to a :class:`~discord.StageChannel`. + + .. versionadded:: 1.7 + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + """ + + async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: + return GuildChannelConverter._resolve_channel( + ctx, argument, "stage_channels", discord.StageChannel + ) + + +class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): + """Converts to a :class:`~discord.CategoryChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: + return GuildChannelConverter._resolve_channel( + ctx, argument, "categories", discord.CategoryChannel + ) + + +class ForumChannelConverter(IDConverter[discord.ForumChannel]): + """Converts to a :class:`~discord.ForumChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + + .. versionadded:: 2.0 + """ + + async def convert(self, ctx: Context, argument: str) -> discord.ForumChannel: + return GuildChannelConverter._resolve_channel( + ctx, argument, "forum_channels", discord.ForumChannel + ) + + +class ThreadConverter(IDConverter[discord.Thread]): + """Coverts to a :class:`~discord.Thread`. + + All lookups are via the local guild. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name. + + .. versionadded: 2.0 + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Thread: + return GuildChannelConverter._resolve_thread( + ctx, argument, "threads", discord.Thread + ) + + +class ColourConverter(Converter[discord.Colour]): + """Converts to a :class:`~discord.Colour`. + + .. versionchanged:: 1.5 + Add an alias named ColorConverter + + The following formats are accepted: + + - ``0x`` + - ``#`` + - ``0x#`` + - ``rgb(, , )`` + - Any of the ``classmethod`` in :class:`~discord.Colour` + + - The ``_`` in the name can be optionally replaced with spaces. + + Like CSS, ```` can be either 0-255 or 0-100% and ```` can be + either a 6 digit hex number or a 3 digit hex shortcut (e.g. #fff). + + .. versionchanged:: 1.5 + Raise :exc:`.BadColourArgument` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.7 + Added support for ``rgb`` function and 3-digit hex shortcuts + """ + + RGB_REGEX = re.compile( + r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" + ) + + def parse_hex_number(self, argument): + arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument + try: + value = int(arg, base=16) + if not (0 <= value <= 0xFFFFFF): + raise BadColourArgument(argument) + except ValueError: + raise BadColourArgument(argument) + else: + return discord.Color(value=value) + + def parse_rgb_number(self, argument, number): + if number[-1] == "%": + value = int(number[:-1]) + if not (0 <= value <= 100): + raise BadColourArgument(argument) + return round(255 * (value / 100)) + + value = int(number) + if not (0 <= value <= 255): + raise BadColourArgument(argument) + return value + + def parse_rgb(self, argument, *, regex=RGB_REGEX): + match = regex.match(argument) + if match is None: + raise BadColourArgument(argument) + + red = self.parse_rgb_number(argument, match.group("r")) + green = self.parse_rgb_number(argument, match.group("g")) + blue = self.parse_rgb_number(argument, match.group("b")) + return discord.Color.from_rgb(red, green, blue) + + async def convert(self, ctx: Context, argument: str) -> discord.Colour: + if argument[0] == "#": + return self.parse_hex_number(argument[1:]) + + if argument[0:2] == "0x": + rest = argument[2:] + # Legacy backwards compatible syntax + if rest.startswith("#"): + return self.parse_hex_number(rest[1:]) + return self.parse_hex_number(rest) + + arg = argument.lower() + if arg[0:3] == "rgb": + return self.parse_rgb(arg) + + arg = arg.replace(" ", "_") + method = getattr(discord.Colour, arg, None) + if arg.startswith("from_") or method is None or not inspect.ismethod(method): + raise BadColourArgument(arg) + return method() + + +ColorConverter = ColourConverter + + +class RoleConverter(IDConverter[discord.Role]): + """Converts to a :class:`~discord.Role`. + + All lookups are via the local guild. If in a DM context, the converter raises + :exc:`.NoPrivateMessage` exception. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Role: + guild = ctx.guild + if not guild: + raise NoPrivateMessage() + + match = self._get_id_match(argument) or re.match( + r"<@&([0-9]{15,20})>$", argument + ) + if match: + result = guild.get_role(int(match.group(1))) + else: + result = discord.utils.get(guild._roles.values(), name=argument) + + if result is None: + raise RoleNotFound(argument) + return result + + +class GameConverter(Converter[discord.Game]): + """Converts to :class:`~discord.Game`.""" + + async def convert(self, ctx: Context, argument: str) -> discord.Game: + return discord.Game(name=argument) + + +class InviteConverter(Converter[discord.Invite]): + """Converts to a :class:`~discord.Invite`. + + This is done via an HTTP request using :meth:`.Bot.fetch_invite`. + + .. versionchanged:: 1.5 + Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Invite: + try: + invite = await ctx.bot.fetch_invite(argument) + return invite + except Exception as exc: + raise BadInviteArgument(argument) from exc + + +class GuildConverter(IDConverter[discord.Guild]): + """Converts to a :class:`~discord.Guild`. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by name. (There is no disambiguation for Guilds with multiple matching names). + + .. versionadded:: 1.7 + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Guild: + match = self._get_id_match(argument) + result = None + + if match is not None: + guild_id = int(match.group(1)) + result = ctx.bot.get_guild(guild_id) + + if result is None: + result = discord.utils.get(ctx.bot.guilds, name=argument) + + if result is None: + raise GuildNotFound(argument) + return result + + +class EmojiConverter(IDConverter[discord.Emoji]): + """Converts to a :class:`~discord.Emoji`. + + All lookups are done for the local guild first, if available. If that lookup + fails, then it checks the client's global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by extracting ID from the emoji. + 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.Emoji: + match = self._get_id_match(argument) or re.match( + r"$", argument + ) + result = None + bot = ctx.bot + guild = ctx.guild + + if match is None: + # Try to get the emoji by name. Try local guild first. + if guild: + result = discord.utils.get(guild.emojis, name=argument) + + if result is None: + result = discord.utils.get(bot.emojis, name=argument) + else: + emoji_id = int(match.group(1)) + + # Try to look up emoji by id. + result = bot.get_emoji(emoji_id) + + if result is None: + raise EmojiNotFound(argument) + + return result + + +class PartialEmojiConverter(Converter[discord.PartialEmoji]): + """Converts to a :class:`~discord.PartialEmoji`. + + This is done by extracting the animated flag, name and ID from the emoji. + + .. versionchanged:: 1.5 + Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` + """ + + async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: + match = re.match(r"<(a?):(\w{1,32}):([0-9]{15,20})>$", argument) + + if match: + emoji_animated = bool(match.group(1)) + emoji_name = match.group(2) + emoji_id = int(match.group(3)) + + return discord.PartialEmoji.with_state( + ctx.bot._connection, + animated=emoji_animated, + name=emoji_name, + id=emoji_id, + ) + + raise PartialEmojiConversionFailure(argument) + + +class GuildStickerConverter(IDConverter[discord.GuildSticker]): + """Converts to a :class:`~discord.GuildSticker`. + + All lookups are done for the local guild first, if available. If that lookup + fails, then it checks the client's global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 3. Lookup by name + + .. versionadded:: 2.0 + """ + + async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker: + match = self._get_id_match(argument) + result = None + bot = ctx.bot + guild = ctx.guild + + if match is None: + # Try to get the sticker by name. Try local guild first. + if guild: + result = discord.utils.get(guild.stickers, name=argument) + + if result is None: + result = discord.utils.get(bot.stickers, name=argument) + else: + sticker_id = int(match.group(1)) + + # Try to look up sticker by id. + result = bot.get_sticker(sticker_id) + + if result is None: + raise GuildStickerNotFound(argument) + + return result + + +class clean_content(Converter[str]): + """Converts the argument to mention scrubbed version of + said content. + + This behaves similarly to :attr:`~discord.Message.clean_content`. + + Attributes + ---------- + fix_channel_mentions: :class:`bool` + Whether to clean channel mentions. + use_nicknames: :class:`bool` + Whether to use nicknames when transforming mentions. + escape_markdown: :class:`bool` + Whether to also escape special markdown characters. + remove_markdown: :class:`bool` + Whether to also remove special markdown characters. This option is not supported with ``escape_markdown`` + + .. versionadded:: 1.7 + """ + + def __init__( + self, + *, + fix_channel_mentions: bool = False, + use_nicknames: bool = True, + escape_markdown: bool = False, + remove_markdown: bool = False, + ) -> None: + self.fix_channel_mentions = fix_channel_mentions + self.use_nicknames = use_nicknames + self.escape_markdown = escape_markdown + self.remove_markdown = remove_markdown + + async def convert(self, ctx: Context, argument: str) -> str: + msg = ctx.message + + if ctx.guild: + + def resolve_member(id: int) -> str: + m = ( + None if msg is None else _utils_get(msg.mentions, id=id) + ) or ctx.guild.get_member(id) + return ( + f"@{m.display_name if self.use_nicknames else m.name}" + if m + else "@deleted-user" + ) + + def resolve_role(id: int) -> str: + r = ( + None if msg is None else _utils_get(msg.mentions, id=id) + ) or ctx.guild.get_role(id) + return f"@{r.name}" if r else "@deleted-role" + + else: + + def resolve_member(id: int) -> str: + m = ( + None if msg is None else _utils_get(msg.mentions, id=id) + ) or ctx.bot.get_user(id) + return f"@{m.name}" if m else "@deleted-user" + + def resolve_role(id: int) -> str: + return "@deleted-role" + + if self.fix_channel_mentions and ctx.guild: + + def resolve_channel(id: int) -> str: + c = ctx.guild.get_channel(id) + return f"#{c.name}" if c else "#deleted-channel" + + else: + + def resolve_channel(id: int) -> str: + return f"<#{id}>" + + transforms = { + "@": resolve_member, + "@!": resolve_member, + "#": resolve_channel, + "@&": resolve_role, + } + + def repl(match: re.Match) -> str: + type = match[1] + id = int(match[2]) + transformed = transforms[type](id) + return transformed + + result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument) + if self.escape_markdown: + result = discord.utils.escape_markdown(result) + elif self.remove_markdown: + result = discord.utils.remove_markdown(result) + + # Completely ensure no mentions escape: + return discord.utils.escape_mentions(result) + + +class Greedy(List[T]): + r"""A special converter that greedily consumes arguments until it can't. + As a consequence of this behaviour, most input errors are silently discarded, + since it is used as an indicator of when to stop parsing. + + When a parser error is met the greedy converter stops converting, undoes the + internal string parsing routine, and continues parsing regularly. + + For example, in the following code: + + .. code-block:: python3 + + @commands.command() + async def test(ctx, numbers: Greedy[int], reason: str): + await ctx.send("numbers: {}, reason: {}".format(numbers, reason)) + + An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with + ``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\. + + For more information, check :ref:`ext_commands_special_converters`. + """ + + __slots__ = ("converter",) + + def __init__(self, *, converter: T): + self.converter = converter + + def __repr__(self): + converter = getattr(self.converter, "__name__", repr(self.converter)) + return f"Greedy[{converter}]" + + def __class_getitem__(cls, params: tuple[T] | T) -> Greedy[T]: + if not isinstance(params, tuple): + params = (params,) + if len(params) != 1: + raise TypeError("Greedy[...] only takes a single argument") + converter = params[0] + + origin = getattr(converter, "__origin__", None) + args = getattr(converter, "__args__", ()) + + if not ( + callable(converter) + or isinstance(converter, Converter) + or origin is not None + ): + raise TypeError("Greedy[...] expects a type or a Converter instance.") + + if converter in (str, type(None)) or origin is Greedy: + raise TypeError(f"Greedy[{converter.__name__}] is invalid.") + + if origin is Union and type(None) in args: + raise TypeError(f"Greedy[{converter!r}] is invalid.") + + return cls(converter=converter) + + +def _convert_to_bool(argument: str) -> bool: + lowered = argument.lower() + if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): + return True + elif lowered in ("no", "n", "false", "f", "0", "disable", "off"): + return False + else: + raise BadBoolArgument(lowered) + + +def get_converter(param: inspect.Parameter) -> Any: + converter = param.annotation + if converter is param.empty: + if param.default is not param.empty: + converter = str if param.default is None else type(param.default) + else: + converter = str + return converter + + +_GenericAlias = type(List[T]) + + +def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool: + return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore + + +CONVERTER_MAPPING: dict[type[Any], Any] = { + discord.Object: ObjectConverter, + discord.Member: MemberConverter, + discord.User: UserConverter, + discord.Message: MessageConverter, + discord.PartialMessage: PartialMessageConverter, + discord.TextChannel: TextChannelConverter, + discord.Invite: InviteConverter, + discord.Guild: GuildConverter, + discord.Role: RoleConverter, + discord.Game: GameConverter, + discord.Colour: ColourConverter, + discord.VoiceChannel: VoiceChannelConverter, + discord.StageChannel: StageChannelConverter, + discord.Emoji: EmojiConverter, + discord.PartialEmoji: PartialEmojiConverter, + discord.CategoryChannel: CategoryChannelConverter, + discord.ForumChannel: ForumChannelConverter, + discord.Thread: ThreadConverter, + discord.abc.GuildChannel: GuildChannelConverter, + discord.GuildSticker: GuildStickerConverter, +} + + +async def _actual_conversion( + ctx: Context, converter, argument: str, param: inspect.Parameter +): + if converter is bool: + return _convert_to_bool(argument) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and ( + module.startswith("discord.") and not module.endswith("converter") + ): + converter = CONVERTER_MAPPING.get(converter, converter) + + try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if inspect.ismethod(converter.convert): + return await converter.convert(ctx, argument) + else: + return await converter().convert(ctx, argument) + elif isinstance(converter, Converter): + return await converter.convert(ctx, argument) + except CommandError: + raise + except Exception as exc: + raise ConversionError(converter, exc) from exc + + try: + return converter(argument) + except CommandError: + raise + except Exception as exc: + try: + name = converter.__name__ + except AttributeError: + name = converter.__class__.__name__ + + raise BadArgument( + f'Converting to "{name}" failed for parameter "{param.name}".' + ) from exc + + +async def run_converters( + ctx: Context, converter, argument: str, param: inspect.Parameter +): + """|coro| + + Runs converters for a given converter, argument, and parameter. + + This function does the same work that the library does under the hood. + + .. versionadded:: 2.0 + + Parameters + ---------- + ctx: :class:`Context` + The invocation context to run the converters under. + converter: Any + The converter to run, this corresponds to the annotation in the function. + argument: :class:`str` + The argument to convert to. + param: :class:`inspect.Parameter` + The parameter being converted. This is mainly for error reporting. + + Returns + ------- + Any + The resulting conversion. + + Raises + ------ + CommandError + The converter failed to convert. + """ + origin = getattr(converter, "__origin__", None) + + if origin is Union: + errors = [] + _NoneType = type(None) + union_args = converter.__args__ + for conv in union_args: + # if we got to this part in the code, then the previous conversions have failed, so + # we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.default is param.empty else param.default + + try: + value = await run_converters(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, union_args, errors) + + if origin is Literal: + errors = [] + conversions = {} + literal_args = converter.__args__ + for literal in literal_args: + literal_type = type(literal) + try: + value = conversions[literal_type] + except KeyError: + try: + value = await _actual_conversion(ctx, literal_type, argument, param) + except CommandError as exc: + errors.append(exc) + conversions[literal_type] = object() + continue + else: + conversions[literal_type] = value + + if value == literal: + return value + + # if we're here, then we failed to match all the literals + raise BadLiteralArgument(param, literal_args, errors) + + # This must be the last if-clause in the chain of origin checking + # Nearly every type is a generic type within the typing library + # So care must be taken to make sure a more specialised origin handle + # isn't overwritten by the widest if clause + if origin is not None and is_generic_type(converter): + converter = origin + + return await _actual_conversion(ctx, converter, argument, param) diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py new file mode 100644 index 0000000..90f387e --- /dev/null +++ b/discord/ext/commands/cooldowns.py @@ -0,0 +1,399 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque +from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar + +from discord.enums import Enum + +from ...abc import PrivateChannel +from .errors import MaxConcurrencyReached + +if TYPE_CHECKING: + from ...message import Message + +__all__ = ( + "BucketType", + "Cooldown", + "CooldownMapping", + "DynamicCooldownMapping", + "MaxConcurrency", +) + +C = TypeVar("C", bound="CooldownMapping") +MC = TypeVar("MC", bound="MaxConcurrency") + + +class BucketType(Enum): + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 + category = 5 + role = 6 + + def get_key(self, msg: Message) -> Any: + if self is BucketType.user: + return msg.author.id + elif self is BucketType.guild: + return (msg.guild or msg.author).id + elif self is BucketType.channel: + return msg.channel.id + elif self is BucketType.member: + return (msg.guild and msg.guild.id), msg.author.id + elif self is BucketType.category: + return (msg.channel.category or msg.channel).id # type: ignore + elif self is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are + # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore + + def __call__(self, msg: Message) -> Any: + return self.get_key(msg) + + +class Cooldown: + """Represents a cooldown for a command. + + Attributes + ---------- + rate: :class:`int` + The total number of tokens available per :attr:`per` seconds. + per: :class:`float` + The length of the cooldown period in seconds. + """ + + __slots__ = ("rate", "per", "_window", "_tokens", "_last") + + def __init__(self, rate: float, per: float) -> None: + self.rate: int = int(rate) + self.per: float = float(per) + self._window: float = 0.0 + self._tokens: int = self.rate + self._last: float = 0.0 + + def get_tokens(self, current: float | None = None) -> int: + """Returns the number of available tokens before rate limiting is applied. + + Parameters + ---------- + current: Optional[:class:`float`] + The time in seconds since Unix epoch to calculate tokens at. + If not supplied then :func:`time.time()` is used. + + Returns + ------- + :class:`int` + The number of tokens available before the cooldown is to be applied. + """ + if not current: + current = time.time() + + tokens = self._tokens + + if current > self._window + self.per: + tokens = self.rate + return tokens + + def get_retry_after(self, current: float | None = None) -> float: + """Returns the time in seconds until the cooldown will be reset. + + Parameters + ---------- + current: Optional[:class:`float`] + The current time in seconds since Unix epoch. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + :class:`float` + The number of seconds to wait before this cooldown will be reset. + """ + current = current or time.time() + tokens = self.get_tokens(current) + + if tokens == 0: + return self.per - (current - self._window) + + return 0.0 + + def update_rate_limit(self, current: float | None = None) -> float | None: + """Updates the cooldown rate limit. + + Parameters + ---------- + current: Optional[:class:`float`] + The time in seconds since Unix epoch to update the rate limit at. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + Optional[:class:`float`] + The retry-after time in seconds if rate limited. + """ + current = current or time.time() + self._last = current + + self._tokens = self.get_tokens(current) + + # first token used means that we start a new rate limit window + if self._tokens == self.rate: + self._window = current + + # check if we are rate limited + if self._tokens == 0: + return self.per - (current - self._window) + + # we're not so decrement our tokens + self._tokens -= 1 + + def reset(self) -> None: + """Reset the cooldown to its initial state.""" + self._tokens = self.rate + self._last = 0.0 + + def copy(self) -> Cooldown: + """Creates a copy of this cooldown. + + Returns + ------- + :class:`Cooldown` + A new instance of this cooldown. + """ + return Cooldown(self.rate, self.per) + + def __repr__(self) -> str: + return f"" + + +class CooldownMapping: + def __init__( + self, + original: Cooldown | None, + type: Callable[[Message], Any], + ) -> None: + if not callable(type): + raise TypeError("Cooldown type must be a BucketType or callable") + + self._cache: dict[Any, Cooldown] = {} + self._cooldown: Cooldown | None = original + self._type: Callable[[Message], Any] = type + + def copy(self) -> CooldownMapping: + ret = CooldownMapping(self._cooldown, self._type) + ret._cache = self._cache.copy() + return ret + + @property + def valid(self) -> bool: + return self._cooldown is not None + + @property + def type(self) -> Callable[[Message], Any]: + return self._type + + @classmethod + def from_cooldown(cls: type[C], rate, per, type) -> C: + return cls(Cooldown(rate, per), type) + + def _bucket_key(self, msg: Message) -> Any: + return self._type(msg) + + def _verify_cache_integrity(self, current: float | None = None) -> None: + # we want to delete all cache objects that haven't been used + # in a cooldown window. e.g. if we have a command that has a + # cooldown of 60s, and it has not been used in 60s then that key should be deleted + current = current or time.time() + dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] + for k in dead_keys: + del self._cache[k] + + def create_bucket(self, message: Message) -> Cooldown: + return self._cooldown.copy() # type: ignore + + def get_bucket(self, message: Message, current: float | None = None) -> Cooldown: + if self._type is BucketType.default: + return self._cooldown # type: ignore + + self._verify_cache_integrity(current) + key = self._bucket_key(message) + if key not in self._cache: + bucket = self.create_bucket(message) + if bucket is not None: + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket + + def update_rate_limit( + self, message: Message, current: float | None = None + ) -> float | None: + bucket = self.get_bucket(message, current) + return bucket.update_rate_limit(current) + + +class DynamicCooldownMapping(CooldownMapping): + def __init__( + self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] + ) -> None: + super().__init__(None, type) + self._factory: Callable[[Message], Cooldown] = factory + + def copy(self) -> DynamicCooldownMapping: + ret = DynamicCooldownMapping(self._factory, self._type) + ret._cache = self._cache.copy() + return ret + + @property + def valid(self) -> bool: + return True + + def create_bucket(self, message: Message) -> Cooldown: + return self._factory(message) + + +class _Semaphore: + """This class is a version of a semaphore. + + If you're wondering why asyncio.Semaphore isn't being used, + it's because it doesn't expose the internal value. This internal + value is necessary because I need to support both `wait=True` and + `wait=False`. + + An asyncio.Queue could have been used to do this as well -- but it is + not as inefficient since internally that uses two queues and is a bit + overkill for what is basically a counter. + """ + + __slots__ = ("value", "loop", "_waiters") + + def __init__(self, number: int) -> None: + self.value: int = number + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._waiters: Deque[asyncio.Future] = deque() + + def __repr__(self) -> str: + return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" + + def locked(self) -> bool: + return self.value == 0 + + def is_active(self) -> bool: + return len(self._waiters) > 0 + + def wake_up(self) -> None: + while self._waiters: + future = self._waiters.popleft() + if not future.done(): + future.set_result(None) + return + + async def acquire(self, *, wait: bool = False) -> bool: + if not wait and self.value <= 0: + # signal that we're not acquiring + return False + + while self.value <= 0: + future = self.loop.create_future() + self._waiters.append(future) + try: + await future + except: + future.cancel() + if self.value > 0 and not future.cancelled(): + self.wake_up() + raise + + self.value -= 1 + return True + + def release(self) -> None: + self.value += 1 + self.wake_up() + + +class MaxConcurrency: + __slots__ = ("number", "per", "wait", "_mapping") + + def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: + self._mapping: dict[Any, _Semaphore] = {} + self.per: BucketType = per + self.number: int = number + self.wait: bool = wait + + if number <= 0: + raise ValueError("max_concurrency 'number' cannot be less than 1") + + if not isinstance(per, BucketType): + raise TypeError( + f"max_concurrency 'per' must be of type BucketType not {type(per)!r}" + ) + + def copy(self: MC) -> MC: + return self.__class__(self.number, per=self.per, wait=self.wait) + + def __repr__(self) -> str: + return ( + f"" + ) + + def get_key(self, message: Message) -> Any: + return self.per.get_key(message) + + async def acquire(self, message: Message) -> None: + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + self._mapping[key] = sem = _Semaphore(self.number) + + acquired = await sem.acquire(wait=self.wait) + if not acquired: + raise MaxConcurrencyReached(self.number, self.per) + + async def release(self, message: Message) -> None: + # Technically there's no reason for this function to be async + # But it might be more useful in the future + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + # ...? peculiar + return + else: + sem.release() + + if sem.value >= self.number and not sem.is_active(): + del self._mapping[key] diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py new file mode 100644 index 0000000..eee8d8e --- /dev/null +++ b/discord/ext/commands/core.py @@ -0,0 +1,2497 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import asyncio +import datetime +import functools +import inspect +import types +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Literal, + TypeVar, + Union, + overload, +) + +import discord + +from ...commands import ( + ApplicationCommand, + _BaseCommand, + message_command, + slash_command, + user_command, +) +from ...enums import ChannelType +from ...errors import * +from .cog import Cog +from .context import Context +from .converter import Greedy, get_converter, run_converters +from .cooldowns import ( + BucketType, + Cooldown, + CooldownMapping, + DynamicCooldownMapping, + MaxConcurrency, +) +from .errors import * + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec, TypeGuard + + from discord.message import Message + + from ._types import Check, Coro, CoroFunc, Error, Hook + + +__all__ = ( + "Command", + "Group", + "GroupMixin", + "command", + "group", + "has_role", + "has_permissions", + "has_any_role", + "check", + "check_any", + "before_invoke", + "after_invoke", + "bot_has_role", + "bot_has_permissions", + "bot_has_any_role", + "cooldown", + "dynamic_cooldown", + "max_concurrency", + "dm_only", + "guild_only", + "is_owner", + "is_nsfw", + "has_guild_permissions", + "bot_has_guild_permissions", + "slash_command", + "user_command", + "message_command", +) + +MISSING: Any = discord.utils.MISSING + +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") +CommandT = TypeVar("CommandT", bound="Command") +ContextT = TypeVar("ContextT", bound="Context") +# CHT = TypeVar('CHT', bound='Check') +GroupT = TypeVar("GroupT", bound="Group") +HookT = TypeVar("HookT", bound="Hook") +ErrorT = TypeVar("ErrorT", bound="Error") + +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") + + +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, "__wrapped__"): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def get_signature_parameters( + function: Callable[..., Any], globalns: dict[str, Any] +) -> dict[str, inspect.Parameter]: + signature = inspect.signature(function) + params = {} + cache: dict[str, Any] = {} + eval_annotation = discord.utils.evaluate_annotation + for name, parameter in signature.parameters.items(): + annotation = parameter.annotation + if annotation is parameter.empty: + params[name] = parameter + continue + if annotation is None: + params[name] = parameter.replace(annotation=type(None)) + continue + + annotation = eval_annotation(annotation, globalns, globalns, cache) + if annotation is Greedy: + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") + + params[name] = parameter.replace(annotation=annotation) + + return params + + +def wrap_callback(coro): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except CommandError: + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise CommandInvokeError(exc) from exc + return ret + + return wrapped + + +def hooked_wrapped_callback(command, ctx, coro): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except CommandError: + ctx.command_failed = True + raise + except asyncio.CancelledError: + ctx.command_failed = True + return + except Exception as exc: + ctx.command_failed = True + raise CommandInvokeError(exc) from exc + finally: + if command._max_concurrency is not None: + await command._max_concurrency.release(ctx) + + await command.call_after_hooks(ctx) + return ret + + return wrapped + + +class _CaseInsensitiveDict(dict): + def __contains__(self, k): + return super().__contains__(k.casefold()) + + def __delitem__(self, k): + return super().__delitem__(k.casefold()) + + def __getitem__(self, k): + return super().__getitem__(k.casefold()) + + def get(self, k, default=None): + return super().get(k.casefold(), default) + + def pop(self, k, default=None): + return super().pop(k.casefold(), default) + + def __setitem__(self, k, v): + super().__setitem__(k.casefold(), v) + + +class Command(_BaseCommand, Generic[CogT, P, T]): + r"""A class that implements the protocol for a bot text command. + + These are not created manually, instead they are created via the + decorator or functional interface. + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + help: Optional[:class:`str`] + The long help text for the command. + brief: Optional[:class:`str`] + The short help text for the command. + usage: Optional[:class:`str`] + A replacement for arguments in the default help text. + aliases: Union[List[:class:`str`], Tuple[:class:`str`]] + The list of aliases the command can be invoked under. + enabled: :class:`bool` + A boolean that indicates if the command is currently enabled. + If the command is invoked while it is disabled, then + :exc:`.DisabledCommand` is raised to the :func:`.on_command_error` + event. Defaults to ``True``. + parent: Optional[:class:`Group`] + The parent group that this command belongs to. ``None`` if there + isn't one. + cog: Optional[:class:`Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.Context`], :class:`bool`]] + A list of predicates that verifies if the command could be executed + with the given :class:`.Context` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`.CommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_command_error` + event. + description: :class:`str` + The message prefixed into the default help command. + hidden: :class:`bool` + If ``True``\, the default help command does not show this in the + help output. + rest_is_raw: :class:`bool` + If ``False`` and a keyword-only argument is provided then the keyword + only argument is stripped and handled as if it was a regular argument + that handles :exc:`.MissingRequiredArgument` and default values in a + regular matter rather than passing the rest completely raw. If ``True`` + then the keyword-only argument will pass in the rest of the arguments + in a completely raw matter. Defaults to ``False``. + invoked_subcommand: Optional[:class:`Command`] + The subcommand that was invoked, if any. + require_var_positional: :class:`bool` + If ``True`` and a variadic positional argument is specified, requires + the user to specify at least one argument. Defaults to ``False``. + + .. versionadded:: 1.5 + + ignore_extra: :class:`bool` + If ``True``\, ignores extraneous strings passed to a command if all its + requirements are met (e.g. ``?foo a b c`` when only expecting ``a`` + and ``b``). Otherwise :func:`.on_command_error` and local error handlers + are called with :exc:`.TooManyArguments`. Defaults to ``True``. + cooldown_after_parsing: :class:`bool` + If ``True``\, cooldown processing is done after argument parsing, + which calls converters. If ``False`` then cooldown processing is done + first and then the converters are called second. Defaults to ``False``. + extras: :class:`dict` + A dict of user provided extras to attach to the Command. + + .. note:: + This object may be copied by the library. + + + .. versionadded:: 2.0 + + cooldown: Optional[:class:`Cooldown`] + The cooldown applied when the command is invoked. ``None`` if the command + doesn't have a cooldown. + + .. versionadded:: 2.0 + """ + __original_kwargs__: dict[str, Any] + + def __new__(cls: type[CommandT], *args: Any, **kwargs: Any) -> CommandT: + # if you're wondering why this is done, it's because we need to ensure + # we have a complete original copy of **kwargs even for classes that + # mess with it by popping before delegating to the subclass __init__. + # In order to do this, we need to control the instance creation and + # inject the original kwargs through __new__ rather than doing it + # inside __init__. + self = super().__new__(cls) + + # we do a shallow copy because it's probably the most common use case. + # this could potentially break if someone modifies a list or something + # while it's in movement, but for now this is the cheapest and + # fastest way to do what we want. + self.__original_kwargs__ = kwargs.copy() + return self + + def __init__( + self, + func: ( + Callable[Concatenate[CogT, ContextT, P], Coro[T]] + | Callable[Concatenate[ContextT, P], Coro[T]] + ), + **kwargs: Any, + ): + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + + name = kwargs.get("name") or func.__name__ + if not isinstance(name, str): + raise TypeError("Name of a command must be a string.") + self.name: str = name + + self.callback = func + self.enabled: bool = kwargs.get("enabled", True) + + help_doc = kwargs.get("help") + if help_doc is not None: + help_doc = inspect.cleandoc(help_doc) + else: + help_doc = inspect.getdoc(func) + if isinstance(help_doc, bytes): + help_doc = help_doc.decode("utf-8") + + self.help: str | None = help_doc + + self.brief: str | None = kwargs.get("brief") + self.usage: str | None = kwargs.get("usage") + self.rest_is_raw: bool = kwargs.get("rest_is_raw", False) + self.aliases: list[str] | tuple[str] = kwargs.get("aliases", []) + self.extras: dict[str, Any] = kwargs.get("extras", {}) + + if not isinstance(self.aliases, (list, tuple)): + raise TypeError( + "Aliases of a command must be a list or a tuple of strings." + ) + + self.description: str = inspect.cleandoc(kwargs.get("description", "")) + self.hidden: bool = kwargs.get("hidden", False) + + try: + checks = func.__commands_checks__ + checks.reverse() + except AttributeError: + checks = kwargs.get("checks", []) + + self.checks: list[Check] = checks + + try: + cooldown = func.__commands_cooldown__ + except AttributeError: + cooldown = kwargs.get("cooldown") + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) + self._buckets: CooldownMapping = buckets + + try: + max_concurrency = func.__commands_max_concurrency__ + except AttributeError: + max_concurrency = kwargs.get("max_concurrency") + + self._max_concurrency: MaxConcurrency | None = max_concurrency + + self.require_var_positional: bool = kwargs.get("require_var_positional", False) + self.ignore_extra: bool = kwargs.get("ignore_extra", True) + self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) + self.cog: CogT | None = None + + # bandaid for the fact that sometimes parent can be the bot instance + parent = kwargs.get("parent") + self.parent: GroupMixin | None = parent if isinstance(parent, _BaseCommand) else None # type: ignore + + self._before_invoke: Hook | None = None + try: + before_invoke = func.__before_invoke__ + except AttributeError: + pass + else: + self.before_invoke(before_invoke) + + self._after_invoke: Hook | None = None + try: + after_invoke = func.__after_invoke__ + except AttributeError: + pass + else: + self.after_invoke(after_invoke) + + @property + def callback( + self, + ) -> ( + Callable[Concatenate[CogT, Context, P], Coro[T]] + | Callable[Concatenate[Context, P], Coro[T]] + ): + return self._callback + + @callback.setter + def callback( + self, + function: ( + Callable[Concatenate[CogT, Context, P], Coro[T]] + | Callable[Concatenate[Context, P], Coro[T]] + ), + ) -> None: + self._callback = function + unwrap = unwrap_function(function) + self.module = unwrap.__module__ + + try: + globalns = unwrap.__globals__ + except AttributeError: + globalns = {} + + self.params = get_signature_parameters(function, globalns) + + def add_check(self, func: Check) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`.check`. + + .. versionadded:: 1.3 + + Parameters + ---------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + .. versionadded:: 1.3 + + Parameters + ---------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + + def update(self, **kwargs: Any) -> None: + """Updates :class:`Command` instance with updated attribute. + + This works similarly to the :func:`.command` decorator in terms + of parameters in that they are passed to the :class:`Command` or + subclass constructors, sans the name and callback. + """ + self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) + + async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: + """|coro| + + Calls the internal callback that the command holds. + + .. note:: + + This bypasses all mechanisms -- including checks, converters, + invoke hooks, cooldowns, etc. You must take care to pass + the proper arguments and types to this function. + + .. versionadded:: 1.3 + """ + if self.cog is not None: + return await self.callback(self.cog, context, *args, **kwargs) # type: ignore + else: + return await self.callback(context, *args, **kwargs) # type: ignore + + def _ensure_assignment_on_copy(self, other: CommandT) -> CommandT: + other._before_invoke = self._before_invoke + other._after_invoke = self._after_invoke + if self.checks != other.checks: + other.checks = self.checks.copy() + if self._buckets.valid and not other._buckets.valid: + other._buckets = self._buckets.copy() + if self._max_concurrency != other._max_concurrency: + # _max_concurrency won't be None at this point + other._max_concurrency = self._max_concurrency.copy() # type: ignore + + try: + other.on_error = self.on_error + except AttributeError: + pass + return other + + def copy(self: CommandT) -> CommandT: + """Creates a copy of this command. + + Returns + ------- + :class:`Command` + A new instance of this command. + """ + ret = self.__class__(self.callback, **self.__original_kwargs__) + return self._ensure_assignment_on_copy(ret) + + def _update_copy(self: CommandT, kwargs: dict[str, Any]) -> CommandT: + if kwargs: + kw = kwargs.copy() + kw.update(self.__original_kwargs__) + copy = self.__class__(self.callback, **kw) + return self._ensure_assignment_on_copy(copy) + else: + return self.copy() + + async def dispatch_error(self, ctx: Context, error: Exception) -> None: + ctx.command_failed = True + cog = self.cog + try: + coro = self.on_error + except AttributeError: + pass + else: + injected = wrap_callback(coro) + if cog is not None: + await injected(cog, ctx, error) + else: + await injected(ctx, error) + + try: + if cog is not None: + local = Cog._get_overridden_method(cog.cog_command_error) + if local is not None: + wrapped = wrap_callback(local) + await wrapped(ctx, error) + finally: + ctx.bot.dispatch("command_error", ctx, error) + + async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: + required = param.default is param.empty + converter = get_converter(param) + consume_rest_is_special = ( + param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + ) + view = ctx.view + view.skip_ws() + + # The greedy converter is simple -- it keeps going until it fails in which case, + # it undoes the view ready for the next parameter to use instead + if isinstance(converter, Greedy): + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): + return await self._transform_greedy_pos( + ctx, param, required, converter.converter + ) + elif param.kind == param.VAR_POSITIONAL: + return await self._transform_greedy_var_pos( + ctx, param, converter.converter + ) + else: + # if we're here, then it's a KEYWORD_ONLY param type + # since this is mostly useless, we'll helpfully transform Greedy[X] + # into just X and do the parsing that way. + converter = converter.converter + + if view.eof: + if param.kind == param.VAR_POSITIONAL: + raise RuntimeError() # break the loop + if required: + if self._is_typing_optional(param.annotation): + return None + if ( + hasattr(converter, "__commands_is_flag__") + and converter._can_be_constructible() + ): + return await converter._construct_default(ctx) + raise MissingRequiredArgument(param) + return param.default + + previous = view.index + if consume_rest_is_special: + argument = view.read_rest().strip() + else: + try: + argument = view.get_quoted_word() + except ArgumentParsingError as exc: + if not self._is_typing_optional(param.annotation): + raise exc + view.index = previous + return None + view.previous = previous + + # type-checker fails to narrow argument + return await run_converters(ctx, converter, argument, param) # type: ignore + + async def _transform_greedy_pos( + self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any + ) -> Any: + view = ctx.view + result = [] + while not view.eof: + # for use with a manual undo + previous = view.index + + view.skip_ws() + try: + argument = view.get_quoted_word() + value = await run_converters(ctx, converter, argument, param) # type: ignore + except (CommandError, ArgumentParsingError): + view.index = previous + break + else: + result.append(value) + + if not result and not required: + return param.default + return result + + async def _transform_greedy_var_pos( + self, ctx: Context, param: inspect.Parameter, converter: Any + ) -> Any: + view = ctx.view + previous = view.index + try: + argument = view.get_quoted_word() + value = await run_converters(ctx, converter, argument, param) # type: ignore + except (CommandError, ArgumentParsingError): + view.index = previous + raise RuntimeError() from None # break loop + else: + return value + + @property + def clean_params(self) -> dict[str, inspect.Parameter]: + """Dict[:class:`str`, :class:`inspect.Parameter`]: + Retrieves the parameter dictionary without the context or self parameters. + + Useful for inspecting signature. + """ + result = self.params.copy() + if self.cog is not None: + # first parameter is self + try: + del result[next(iter(result))] + except StopIteration: + raise ValueError("missing 'self' parameter") from None + + try: + # first/second parameter is context + del result[next(iter(result))] + except StopIteration: + raise ValueError("missing 'context' parameter") from None + + return result + + @property + def full_parent_name(self) -> str: + """:class:`str`: Retrieves the fully qualified parent command name. + + This the base command name required to execute it. For example, + in ``?one two three`` the parent name would be ``one two``. + """ + entries = [] + command = self + # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore + + return " ".join(reversed(entries)) + + @property + def parents(self) -> list[Group]: + """List[:class:`Group`]: Retrieves the parents of this command. + + If the command has no parents then it returns an empty :class:`list`. + + For example in commands ``?a b c test``, the parents are ``[c, b, a]``. + + .. versionadded:: 1.1 + """ + entries = [] + command = self + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command) + + return entries + + @property + def root_parent(self) -> Group | None: + """Optional[:class:`Group`]: Retrieves the root parent of this command. + + If the command has no parents then it returns ``None``. + + For example in commands ``?a b c test``, the root parent is ``a``. + """ + if not self.parent: + return None + return self.parents[-1] + + @property + def qualified_name(self) -> str: + """:class:`str`: Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + + parent = self.full_parent_name + if parent: + return f"{parent} {self.name}" + else: + return self.name + + def __str__(self) -> str: + return self.qualified_name + + async def _parse_arguments(self, ctx: Context) -> None: + ctx.args = [ctx] if self.cog is None else [self.cog, ctx] + ctx.kwargs = {} + args = ctx.args + kwargs = ctx.kwargs + + view = ctx.view + iterator = iter(self.params.items()) + + if self.cog is not None: + # we have 'self' as the first parameter so just advance + # the iterator and resume parsing + try: + next(iterator) + except StopIteration: + raise discord.ClientException( + f'Callback for {self.name} command is missing "self" parameter.' + ) + + # next we have the 'ctx' as the next parameter + try: + next(iterator) + except StopIteration: + raise discord.ClientException( + f'Callback for {self.name} command is missing "ctx" parameter.' + ) + + for name, param in iterator: + ctx.current_parameter = param + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): + transformed = await self.transform(ctx, param) + args.append(transformed) + elif param.kind == param.KEYWORD_ONLY: + # kwarg only param denotes "consume rest" semantics + if self.rest_is_raw: + converter = get_converter(param) + argument = view.read_rest() + kwargs[name] = await run_converters(ctx, converter, argument, param) + else: + kwargs[name] = await self.transform(ctx, param) + break + elif param.kind == param.VAR_POSITIONAL: + if view.eof and self.require_var_positional: + raise MissingRequiredArgument(param) + while not view.eof: + try: + transformed = await self.transform(ctx, param) + args.append(transformed) + except RuntimeError: + break + + if not self.ignore_extra and not view.eof: + raise TooManyArguments( + f"Too many arguments passed to {self.qualified_name}" + ) + + async def call_before_hooks(self, ctx: Context) -> None: + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.cog + if self._before_invoke is not None: + # should be cog if @commands.before_invoke is used + instance = getattr(self._before_invoke, "__self__", cog) + # __self__ only exists for methods, not functions + # however, if @command.before_invoke is used, it will be a function + if instance: + await self._before_invoke(instance, ctx) # type: ignore + else: + await self._before_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = Cog._get_overridden_method(cog.cog_before_invoke) + if hook is not None: + await hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + await hook(ctx) + + async def call_after_hooks(self, ctx: Context) -> None: + cog = self.cog + if self._after_invoke is not None: + instance = getattr(self._after_invoke, "__self__", cog) + if instance: + await self._after_invoke(instance, ctx) # type: ignore + else: + await self._after_invoke(ctx) # type: ignore + + # call the cog local hook if applicable: + if cog is not None: + hook = Cog._get_overridden_method(cog.cog_after_invoke) + if hook is not None: + await hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + await hook(ctx) + + def _prepare_cooldowns(self, ctx: Context) -> None: + if self._buckets.valid: + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + bucket = self._buckets.get_bucket(ctx.message, current) + if bucket is not None: + retry_after = bucket.update_rate_limit(current) + if retry_after: + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + + async def prepare(self, ctx: Context) -> None: + ctx.command = self + + if not await self.can_run(ctx): + raise CheckFailure( + f"The check functions for command {self.qualified_name} failed." + ) + + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore + + try: + if self.cooldown_after_parsing: + await self._parse_arguments(ctx) + self._prepare_cooldowns(ctx) + else: + self._prepare_cooldowns(ctx) + await self._parse_arguments(ctx) + + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) # type: ignore + raise + + @property + def cooldown(self) -> Cooldown | None: + return self._buckets._cooldown + + def is_on_cooldown(self, ctx: Context) -> bool: + """Checks whether the command is currently on cooldown. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context to use when checking the command's cooldown status. + + Returns + ------- + :class:`bool` + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False + + bucket = self._buckets.get_bucket(ctx.message) + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_tokens(current) == 0 + + def reset_cooldown(self, ctx: Context) -> None: + """Resets the cooldown on this command. + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + bucket.reset() + + def get_cooldown_retry_after(self, ctx: Context) -> float: + """Retrieves the amount of seconds before this command can be tried again. + + .. versionadded:: 1.4 + + Parameters + ---------- + ctx: :class:`.Context` + The invocation context to retrieve the cooldown from. + + Returns + ------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_retry_after(current) + + return 0.0 + + async def invoke(self, ctx: Context) -> None: + await self.prepare(ctx) + + # terminate the invoked_subcommand chain. + # since we're in a regular command (and not a group) then + # the invoked subcommand is None. + ctx.invoked_subcommand = None + ctx.subcommand_passed = None + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) + + async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: + ctx.command = self + await self._parse_arguments(ctx) + + if call_hooks: + await self.call_before_hooks(ctx) + + ctx.invoked_subcommand = None + try: + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + await self.call_after_hooks(ctx) + + def error(self, coro: ErrorT) -> ErrorT: + """A decorator that registers a coroutine as a local error handler. + + A local error handler is an :func:`.on_command_error` event limited to + a single command. However, the :func:`.on_command_error` is still + invoked afterwards as the catch-all. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The error handler must be a coroutine.") + + self.on_error: Error = coro + return coro + + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the command has an error handler registered. + + .. versionadded:: 1.7 + """ + return hasattr(self, "on_error") + + def before_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.before_invoke` for more info. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the pre-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro: HookT) -> HookT: + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.after_invoke` for more info. + + Parameters + ---------- + coro: :ref:`coroutine ` + The coroutine to register as the post-invoke hook. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise TypeError("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + @property + def cog_name(self) -> str | None: + """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" + return type(self.cog).__cog_name__ if self.cog is not None else None + + @property + def short_doc(self) -> str: + """:class:`str`: Gets the "short" documentation of a command. + + By default, this is the :attr:`.brief` attribute. + If that lookup leads to an empty string then the first line of the + :attr:`.help` attribute is used instead. + """ + if self.brief is not None: + return self.brief + if self.help is not None: + return self.help.split("\n", 1)[0] + return "" + + def _is_typing_optional(self, annotation: T | T | None) -> TypeGuard[T | None]: + return ( + getattr(annotation, "__origin__", None) is Union + or type(annotation) is getattr(types, "UnionType", Union) + ) and type( + None + ) in annotation.__args__ # type: ignore + + @property + def signature(self) -> str: + """:class:`str`: Returns a POSIX-like signature useful for help command output.""" + if self.usage is not None: + return self.usage + + params = self.clean_params + if not params: + return "" + + result = [] + for name, param in params.items(): + greedy = isinstance(param.annotation, Greedy) + optional = False # postpone evaluation of if it's an optional argument + + # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the + # parameter signature is a literal list of it's values + annotation = param.annotation.converter if greedy else param.annotation + origin = getattr(annotation, "__origin__", None) + if not greedy and origin is Union: + none_cls = type(None) + union_args = annotation.__args__ + optional = union_args[-1] is none_cls + if len(union_args) == 2 and optional: + annotation = union_args[0] + origin = getattr(annotation, "__origin__", None) + + if origin is Literal: + name = "|".join( + f'"{v}"' if isinstance(v, str) else str(v) + for v in annotation.__args__ + ) + if param.default is not param.empty: + # We don't want None or '' to trigger the [name=value] case, and instead it should + # do [name] since [name=None] or [name=] are not exactly useful for the user. + should_print = ( + param.default + if isinstance(param.default, str) + else param.default is not None + ) + if should_print: + result.append( + f"[{name}={param.default}]" + if not greedy + else f"[{name}={param.default}]..." + ) + continue + else: + result.append(f"[{name}]") + + elif param.kind == param.VAR_POSITIONAL: + if self.require_var_positional: + result.append(f"<{name}...>") + else: + result.append(f"[{name}...]") + elif greedy: + result.append(f"[{name}]...") + elif optional: + result.append(f"[{name}]") + else: + result.append(f"<{name}>") + + return " ".join(result) + + async def can_run(self, ctx: Context) -> bool: + """|coro| + + Checks if the command can be executed by checking all the predicates + inside the :attr:`~Command.checks` attribute. This also checks whether the + command is disabled. + + .. versionchanged:: 1.3 + Checks whether the command is disabled or not + + Parameters + ---------- + ctx: :class:`.Context` + The ctx of the command currently being invoked. + + Returns + ------- + :class:`bool` + A boolean indicating if the command can be invoked. + + Raises + ------ + :class:`CommandError` + Any command error that was raised during a check call will be propagated + by this function. + """ + + if not self.enabled: + raise DisabledCommand(f"{self.name} command is disabled") + + original = ctx.command + ctx.command = self + + try: + if not await ctx.bot.can_run(ctx): + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) + + cog = self.cog + if cog is not None: + local_check = Cog._get_overridden_method(cog.cog_check) + if local_check is not None: + ret = await discord.utils.maybe_coroutine(local_check, ctx) + if not ret: + return False + + predicates = self.checks + if not predicates: + # since we have no checks, then we just return True. + return True + + return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore + finally: + ctx.command = original + + def _set_cog(self, cog): + self.cog = cog + + +class GroupMixin(Generic[CogT]): + """A mixin that implements common functionality for classes that behave + similar to :class:`.Group` and are allowed to register commands. + + Attributes + ---------- + all_commands: :class:`dict` + A mapping of command name to :class:`.Command` + objects. + case_insensitive: :class:`bool` + Whether the commands should be case-insensitive. Defaults to ``False``. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + case_insensitive = kwargs.get("case_insensitive", False) + self.prefixed_commands: dict[str, Command[CogT, Any, Any]] = ( + _CaseInsensitiveDict() if case_insensitive else {} + ) + self.case_insensitive: bool = case_insensitive + super().__init__(*args, **kwargs) + + @property + def all_commands(self): + # merge app and prefixed commands + if hasattr(self, "_application_commands"): + return {**self._application_commands, **self.prefixed_commands} + return self.prefixed_commands + + @property + def commands(self) -> set[Command[CogT, Any, Any]]: + """Set[:class:`.Command`]: A unique set of commands without aliases that are registered.""" + return set(self.prefixed_commands.values()) + + def recursively_remove_all_commands(self) -> None: + for command in self.prefixed_commands.copy().values(): + if isinstance(command, GroupMixin): + command.recursively_remove_all_commands() + self.remove_command(command.name) + + def add_command(self, command: Command[CogT, Any, Any]) -> None: + """Adds a :class:`.Command` into the internal list of commands. + + This is usually not called, instead the :meth:`~.GroupMixin.command` or + :meth:`~.GroupMixin.group` shortcut decorators are used instead. + + .. versionchanged:: 1.4 + Raise :exc:`.CommandRegistrationError` instead of generic :exc:`.ClientException` + + Parameters + ---------- + command: :class:`Command` + The command to add. + + Raises + ------ + :exc:`.CommandRegistrationError` + If the command or its alias is already registered by different command. + TypeError + If the command passed is not a subclass of :class:`.Command`. + """ + + if not isinstance(command, Command): + raise TypeError("The command passed must be a subclass of Command") + + if isinstance(self, Command): + command.parent = self + + if command.name in self.prefixed_commands: + raise CommandRegistrationError(command.name) + + self.prefixed_commands[command.name] = command + for alias in command.aliases: + if alias in self.prefixed_commands: + self.remove_command(command.name) + raise CommandRegistrationError(alias, alias_conflict=True) + self.prefixed_commands[alias] = command + + def remove_command(self, name: str) -> Command[CogT, Any, Any] | None: + """Remove a :class:`.Command` from the internal list + of commands. + + This could also be used as a way to remove aliases. + + Parameters + ---------- + name: :class:`str` + The name of the command to remove. + + Returns + ------- + Optional[:class:`.Command`] + The command that was removed. If the name is not valid then + ``None`` is returned instead. + """ + command = self.prefixed_commands.pop(name, None) + + # does not exist + if command is None: + return None + + if name in command.aliases: + # we're removing an alias, so we don't want to remove the rest + return command + + # we're not removing the alias so let's delete the rest of them. + for alias in command.aliases: + cmd = self.prefixed_commands.pop(alias, None) + # in the case of a CommandRegistrationError, an alias might conflict + # with an already existing command. If this is the case, we want to + # make sure the pre-existing command is not removed. + if cmd is not None and cmd != command: + self.prefixed_commands[alias] = cmd + return command + + def walk_commands(self) -> Generator[Command[CogT, Any, Any], None, None]: + """An iterator that recursively walks through all commands and subcommands. + + .. versionchanged:: 1.4 + Duplicates due to aliases are no longer returned + + Yields + ------ + Union[:class:`.Command`, :class:`.Group`] + A command or group from the internal list of commands. + """ + for command in self.commands: + yield command + if isinstance(command, GroupMixin): + yield from command.walk_commands() + + def get_command(self, name: str) -> Command[CogT, Any, Any] | None: + """Get a :class:`.Command` from the internal list + of commands. + + This could also be used as a way to get aliases. + + The name could be fully qualified (e.g. ``'foo bar'``) will get + the subcommand ``bar`` of the group command ``foo``. If a + subcommand is not found then ``None`` is returned just as usual. + + Parameters + ---------- + name: :class:`str` + The name of the command to get. + + Returns + ------- + Optional[:class:`Command`] + The command that was requested. If not found, returns ``None``. + """ + + # fast path, no space in name. + if " " not in name: + return self.prefixed_commands.get(name) + + names = name.split() + if not names: + return None + obj = self.prefixed_commands.get(names[0]) + if not isinstance(obj, GroupMixin): + return obj + + for name in names[1:]: + try: + obj = obj.prefixed_commands[name] # type: ignore + except (AttributeError, KeyError): + return None + + return obj + + @overload + def command( + self, + name: str = ..., + cls: type[Command[CogT, P, T]] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[T]] + | Callable[Concatenate[ContextT, P], Coro[T]] + ) + ], + Command[CogT, P, T], + ]: + ... + + @overload + def command( + self, + name: str = ..., + cls: type[CommandT] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: + ... + + def command( + self, + name: str = MISSING, + cls: type[CommandT] = MISSING, + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]: + """A shortcut decorator that invokes :func:`.command` and adds it to + the internal command list via :meth:`~.GroupMixin.add_command`. + + Returns + ------- + Callable[..., :class:`Command`] + A decorator that converts the provided method into a Command, adds it to the bot, then returns it. + """ + + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: + kwargs.setdefault("parent", self) + result = command(name=name, cls=cls, *args, **kwargs)(func) + self.add_command(result) + return result + + return decorator + + @overload + def group( + self, + name: str = ..., + cls: type[Group[CogT, P, T]] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[T]] + | Callable[Concatenate[ContextT, P], Coro[T]] + ) + ], + Group[CogT, P, T], + ]: + ... + + @overload + def group( + self, + name: str = ..., + cls: type[GroupT] = ..., + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: + ... + + def group( + self, + name: str = MISSING, + cls: type[GroupT] = MISSING, + *args: Any, + **kwargs: Any, + ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]: + """A shortcut decorator that invokes :func:`.group` and adds it to + the internal command list via :meth:`~.GroupMixin.add_command`. + + Returns + ------- + Callable[..., :class:`Group`] + A decorator that converts the provided method into a Group, adds it to the bot, then returns it. + """ + + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: + kwargs.setdefault("parent", self) + result = group(name=name, cls=cls, *args, **kwargs)(func) + self.add_command(result) + return result + + return decorator + + +class Group(GroupMixin[CogT], Command[CogT, P, T]): + """A class that implements a grouping protocol for commands to be + executed as subcommands. + + This class is a subclass of :class:`.Command` and thus all options + valid in :class:`.Command` are valid in here as well. + + Attributes + ---------- + invoke_without_command: :class:`bool` + Indicates if the group callback should begin parsing and + invocation only if no subcommand was found. Useful for + making it an error handling function to tell the user that + no subcommand was found or to have different functionality + in case no subcommand was found. If this is ``False``, then + the group callback will always be invoked first. This means + that the checks and the parsing dictated by its parameters + will be executed. Defaults to ``False``. + case_insensitive: :class:`bool` + Indicates if the group's commands should be case-insensitive. + Defaults to ``False``. + """ + + def __init__(self, *args: Any, **attrs: Any) -> None: + self.invoke_without_command: bool = attrs.pop("invoke_without_command", False) + super().__init__(*args, **attrs) + + def copy(self: GroupT) -> GroupT: + """Creates a copy of this :class:`Group`. + + Returns + ------- + :class:`Group` + A new instance of this group. + """ + ret = super().copy() + for cmd in self.commands: + ret.add_command(cmd.copy()) + return ret # type: ignore + + async def invoke(self, ctx: Context) -> None: + ctx.invoked_subcommand = None + ctx.subcommand_passed = None + early_invoke = not self.invoke_without_command + if early_invoke: + await self.prepare(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None) + + if early_invoke: + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) + + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + await ctx.invoked_subcommand.invoke(ctx) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + await super().invoke(ctx) + + async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: + ctx.invoked_subcommand = None + early_invoke = not self.invoke_without_command + if early_invoke: + ctx.command = self + await self._parse_arguments(ctx) + + if call_hooks: + await self.call_before_hooks(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None) + + if early_invoke: + try: + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + await self.call_after_hooks(ctx) + + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + await ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + await super().reinvoke(ctx, call_hooks=call_hooks) + + +# Decorators + + +@overload # for py 3.10 +def command( + name: str = ..., + cls: type[Command[CogT, P, T]] = ..., + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P]] + | Coro[T] + | Callable[Concatenate[ContextT, P]] + | Coro[T] + ) + ], + Command[CogT, P, T], +]: + ... + + +@overload +def command( + name: str = ..., + cls: type[Command[CogT, P, T]] = ..., + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[T]] + | Callable[Concatenate[ContextT, P], Coro[T]] + ) + ], + Command[CogT, P, T], +]: + ... + + +@overload +def command( + name: str = ..., + cls: type[CommandT] = ..., + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[Any]] + | Callable[Concatenate[ContextT, P], Coro[Any]] + ) + ], + CommandT, +]: + ... + + +def command( + name: str = MISSING, cls: type[CommandT] = MISSING, **attrs: Any +) -> Callable[ + [ + ( + Callable[Concatenate[ContextT, P], Coro[Any]] + | Callable[Concatenate[CogT, ContextT, P], Coro[T]] + ) + ], + Command[CogT, P, T] | CommandT, +]: + """A decorator that transforms a function into a :class:`.Command` + or if called with :func:`.group`, :class:`.Group`. + + By default the ``help`` attribute is received automatically from the + docstring of the function and is cleaned up with the use of + ``inspect.cleandoc``. If the docstring is ``bytes``, then it is decoded + into :class:`str` using utf-8 encoding. + + All checks added using the :func:`.check` & co. decorators are added into + the function. There is no way to supply your own checks through this + decorator. + + Parameters + ---------- + name: :class:`str` + The name to create the command with. By default, this uses the + function name unchanged. + cls + The class to construct with. By default, this is :class:`.Command`. + You usually do not change this. + attrs + Keyword arguments to pass into the construction of the class denoted + by ``cls``. + + Raises + ------ + TypeError + If the function is not a coroutine or is already a command. + """ + if cls is MISSING: + cls = Command # type: ignore + + def decorator( + func: ( + Callable[Concatenate[ContextT, P], Coro[Any]] + | Callable[Concatenate[CogT, ContextT, P], Coro[Any]] + ) + ) -> CommandT: + if isinstance(func, Command): + raise TypeError("Callback is already a command.") + return cls(func, name=name, **attrs) + + return decorator + + +@overload +def group( + name: str = ..., + cls: type[Group[CogT, P, T]] = ..., + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[T]] + | Callable[Concatenate[ContextT, P], Coro[T]] + ) + ], + Group[CogT, P, T], +]: + ... + + +@overload +def group( + name: str = ..., + cls: type[GroupT] = ..., + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[CogT, ContextT, P], Coro[Any]] + | Callable[Concatenate[ContextT, P], Coro[Any]] + ) + ], + GroupT, +]: + ... + + +def group( + name: str = MISSING, + cls: type[GroupT] = MISSING, + **attrs: Any, +) -> Callable[ + [ + ( + Callable[Concatenate[ContextT, P], Coro[Any]] + | Callable[Concatenate[CogT, ContextT, P], Coro[T]] + ) + ], + Group[CogT, P, T] | GroupT, +]: + """A decorator that transforms a function into a :class:`.Group`. + + This is similar to the :func:`.command` decorator but the ``cls`` + parameter is set to :class:`Group` by default. + + .. versionchanged:: 1.1 + The ``cls`` parameter can now be passed. + """ + if cls is MISSING: + cls = Group # type: ignore + return command(name=name, cls=cls, **attrs) # type: ignore + + +def check(predicate: Check) -> Callable[[T], T]: + r"""A decorator that adds a check to the :class:`.Command` or its + subclasses. These checks could be accessed via :attr:`.Command.checks`. + + These checks should be predicates that take in a single parameter taking + a :class:`.Context`. If the check returns a ``False``\-like value then + during invocation a :exc:`.CheckFailure` exception is raised and sent to + the :func:`.on_command_error` event. + + If an exception should be thrown in the predicate then it should be a + subclass of :exc:`.CommandError`. Any exception not subclassed from it + will be propagated while those subclassed will be sent to + :func:`.on_command_error`. + + A special attribute named ``predicate`` is bound to the value + returned by this decorator to retrieve the predicate passed to the + decorator. This allows the following introspection and chaining to be done: + + .. code-block:: python3 + + def owner_or_permissions(**perms): + original = commands.has_permissions(**perms).predicate + async def extended_check(ctx): + if ctx.guild is None: + return False + return ctx.guild.owner_id == ctx.author.id or await original(ctx) + return commands.check(extended_check) + + .. note:: + + The function returned by ``predicate`` is **always** a coroutine, + even if the original function was not a coroutine. + + .. versionchanged:: 1.3 + The ``predicate`` attribute was added. + + Examples + --------- + + Creating a basic check to see if the command invoker is you. + + .. code-block:: python3 + + def check_if_it_is_me(ctx): + return ctx.message.author.id == 85309593344815104 + + @bot.command() + @commands.check(check_if_it_is_me) + async def only_for_me(ctx): + await ctx.send('I know you!') + + Transforming common checks into its own decorator: + + .. code-block:: python3 + + def is_me(): + def predicate(ctx): + return ctx.message.author.id == 85309593344815104 + return commands.check(predicate) + + @bot.command() + @is_me() + async def only_me(ctx): + await ctx.send('Only you!') + + Parameters + ----------- + predicate: Callable[[:class:`Context`], :class:`bool`] + The predicate to check if the command should be invoked. + """ + + def decorator(func: Command | CoroFunc) -> Command | CoroFunc: + if isinstance(func, _BaseCommand): + func.checks.append(predicate) + else: + if not hasattr(func, "__commands_checks__"): + func.__commands_checks__ = [] + + func.__commands_checks__.append(predicate) + + return func + + if inspect.iscoroutinefunction(predicate): + decorator.predicate = predicate + else: + + @functools.wraps(predicate) + async def wrapper(ctx): + return predicate(ctx) # type: ignore + + decorator.predicate = wrapper + + return decorator # type: ignore + + +def check_any(*checks: Check) -> Callable[[T], T]: + r"""A :func:`check` that is added that checks if any of the checks passed + will pass, i.e. using logical OR. + + If all checks fail then :exc:`.CheckAnyFailure` is raised to signal the failure. + It inherits from :exc:`.CheckFailure`. + + .. note:: + + The ``predicate`` attribute for this function **is** a coroutine. + + .. versionadded:: 1.3 + + Parameters + ------------ + \*checks: Callable[[:class:`Context`], :class:`bool`] + An argument list of checks that have been decorated with + the :func:`check` decorator. + + Raises + ------- + TypeError + A check passed has not been decorated with the :func:`check` + decorator. + + Examples + --------- + + Creating a basic check to see if it's the bot owner or + the server owner: + + .. code-block:: python3 + + def is_guild_owner(): + def predicate(ctx): + return ctx.guild is not None and ctx.guild.owner_id == ctx.author.id + return commands.check(predicate) + + @bot.command() + @commands.check_any(commands.is_owner(), is_guild_owner()) + async def only_for_owners(ctx): + await ctx.send('Hello mister owner!') + """ + + unwrapped = [] + for wrapped in checks: + try: + pred = wrapped.predicate + except AttributeError: + raise TypeError( + f"{wrapped!r} must be wrapped by commands.check decorator" + ) from None + else: + unwrapped.append(pred) + + async def predicate(ctx: Context) -> bool: + errors = [] + for func in unwrapped: + try: + value = await func(ctx) + except CheckFailure as e: + errors.append(e) + else: + if value: + return True + # if we're here, all checks failed + raise CheckAnyFailure(unwrapped, errors) + + return check(predicate) + + +def has_role(item: int | str) -> Callable[[T], T]: + """A :func:`.check` that is added that checks if the member invoking the + command has the role specified via the name or ID specified. + + If a string is specified, you must give the exact name of the role, including + caps and spelling. + + If an integer is specified, you must give the exact snowflake ID of the role. + + If the message is invoked in a private message context then the check will + return ``False``. + + This check raises one of two special exceptions, :exc:`.MissingRole` if the user + is missing a role, or :exc:`.NoPrivateMessage` if it is used in a private message. + Both inherit from :exc:`.CheckFailure`. + + .. versionchanged:: 1.1 + + Raise :exc:`.MissingRole` or :exc:`.NoPrivateMessage` + instead of generic :exc:`.CheckFailure` + + Parameters + ---------- + item: Union[:class:`int`, :class:`str`] + The name or ID of the role to check. + """ + + def predicate(ctx: Context) -> bool: + if ctx.guild is None: + raise NoPrivateMessage() + + # ctx.guild is None doesn't narrow ctx.author to Member + if isinstance(item, int): + role = discord.utils.get(ctx.author.roles, id=item) # type: ignore + else: + role = discord.utils.get(ctx.author.roles, name=item) # type: ignore + if role is None: + raise MissingRole(item) + return True + + return check(predicate) + + +def has_any_role(*items: int | str) -> Callable[[T], T]: + r"""A :func:`.check` that is added that checks if the member invoking the + command has **any** of the roles specified. This means that if they have + one out of the three roles specified, then this check will return `True`. + + Similar to :func:`.has_role`\, the names or IDs passed in must be exact. + + This check raises one of two special exceptions, :exc:`.MissingAnyRole` if the user + is missing all roles, or :exc:`.NoPrivateMessage` if it is used in a private message. + Both inherit from :exc:`.CheckFailure`. + + .. versionchanged:: 1.1 + + Raise :exc:`.MissingAnyRole` or :exc:`.NoPrivateMessage` + instead of generic :exc:`.CheckFailure` + + Parameters + ----------- + items: List[Union[:class:`str`, :class:`int`]] + An argument list of names or IDs to check that the member has roles wise. + + Example + -------- + + .. code-block:: python3 + + @bot.command() + @commands.has_any_role('Library Devs', 'Moderators', 492212595072434186) + async def cool(ctx): + await ctx.send('You are cool indeed') + """ + + def predicate(ctx): + if ctx.guild is None: + raise NoPrivateMessage() + + # ctx.guild is None doesn't narrow ctx.author to Member + getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore + if any( + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items + ): + return True + raise MissingAnyRole(list(items)) + + return check(predicate) + + +def bot_has_role(item: int) -> Callable[[T], T]: + """Similar to :func:`.has_role` except checks if the bot itself has the + role. + + This check raises one of two special exceptions, :exc:`.BotMissingRole` if the bot + is missing the role, or :exc:`.NoPrivateMessage` if it is used in a private message. + Both inherit from :exc:`.CheckFailure`. + + .. versionchanged:: 1.1 + + Raise :exc:`.BotMissingRole` or :exc:`.NoPrivateMessage` + instead of generic :exc:`.CheckFailure` + """ + + def predicate(ctx): + if ctx.guild is None: + raise NoPrivateMessage() + + me = ctx.me + if isinstance(item, int): + role = discord.utils.get(me.roles, id=item) + else: + role = discord.utils.get(me.roles, name=item) + if role is None: + raise BotMissingRole(item) + return True + + return check(predicate) + + +def bot_has_any_role(*items: int) -> Callable[[T], T]: + """Similar to :func:`.has_any_role` except checks if the bot itself has + any of the roles listed. + + This check raises one of two special exceptions, :exc:`.BotMissingAnyRole` if the bot + is missing all roles, or :exc:`.NoPrivateMessage` if it is used in a private message. + Both inherit from :exc:`.CheckFailure`. + + .. versionchanged:: 1.1 + + Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` + instead of generic :exc:`.CheckFailure`. + """ + + def predicate(ctx): + if ctx.guild is None: + raise NoPrivateMessage() + + me = ctx.me + getter = functools.partial(discord.utils.get, me.roles) + if any( + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items + ): + return True + raise BotMissingAnyRole(list(items)) + + return check(predicate) + + +def has_permissions(**perms: bool) -> Callable[[T], T]: + r"""A :func:`.check` that is added that checks if the member has all of + the permissions necessary. + + Note that this check operates on the current channel permissions, not the + guild wide permissions. + + The permissions passed in must be exactly like the properties shown under + :class:`.discord.Permissions`. + + This check raises a special exception, :exc:`.MissingPermissions` + that is inherited from :exc:`.CheckFailure`. + + If the command is executed within a DM, it returns ``True``. + + Parameters + ------------ + \*\*perms: Dict[:class:`str`, :class:`bool`] + An argument list of permissions to check for. + + Example + --------- + + .. code-block:: python3 + + @bot.command() + @commands.has_permissions(manage_messages=True) + async def test(ctx): + await ctx.send('You can manage messages.') + + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") + + def predicate(ctx: Context) -> bool: + if ctx.channel.type == ChannelType.private: + return True + permissions = ctx.channel.permissions_for(ctx.author) # type: ignore + + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] + + if not missing: + return True + + raise MissingPermissions(missing) + + return check(predicate) + + +def bot_has_permissions(**perms: bool) -> Callable[[T], T]: + """Similar to :func:`.has_permissions` except checks if the bot itself has + the permissions listed. + + This check raises a special exception, :exc:`.BotMissingPermissions` + that is inherited from :exc:`.CheckFailure`. + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") + + def predicate(ctx: Context) -> bool: + guild = ctx.guild + me = guild.me if guild is not None else ctx.bot.user + if ctx.channel.type == ChannelType.private: + return True + + if hasattr(ctx, "app_permissions"): + permissions = ctx.app_permissions + else: + permissions = ctx.channel.permissions_for(me) # type: ignore + + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return check(predicate) + + +def has_guild_permissions(**perms: bool) -> Callable[[T], T]: + """Similar to :func:`.has_permissions`, but operates on guild wide + permissions instead of the current channel permissions. + + If this check is called in a DM context, it will raise an + exception, :exc:`.NoPrivateMessage`. + + .. versionadded:: 1.3 + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") + + def predicate(ctx: Context) -> bool: + if not ctx.guild: + raise NoPrivateMessage + + permissions = ctx.author.guild_permissions # type: ignore + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] + + if not missing: + return True + + raise MissingPermissions(missing) + + return check(predicate) + + +def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: + """Similar to :func:`.has_guild_permissions`, but checks the bot + members guild permissions. + + .. versionadded:: 1.3 + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") + + def predicate(ctx: Context) -> bool: + if not ctx.guild: + raise NoPrivateMessage + + permissions = ctx.me.guild_permissions # type: ignore + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return check(predicate) + + +def dm_only() -> Callable[[T], T]: + """A :func:`.check` that indicates this command must only be used in a + DM context. Only private messages are allowed when + using the command. + + This check raises a special exception, :exc:`.PrivateMessageOnly` + that is inherited from :exc:`.CheckFailure`. + + .. versionadded:: 1.1 + """ + + def predicate(ctx: Context) -> bool: + if ctx.guild is not None: + raise PrivateMessageOnly() + return True + + return check(predicate) + + +def guild_only() -> Callable[[T], T]: + """A :func:`.check` that indicates this command must only be used in a + guild context only. Basically, no private messages are allowed when + using the command. + + This check raises a special exception, :exc:`.NoPrivateMessage` + that is inherited from :exc:`.CheckFailure`. + """ + + def predicate(ctx: Context) -> bool: + if ctx.guild is None: + raise NoPrivateMessage() + return True + + return check(predicate) + + +def is_owner() -> Callable[[T], T]: + """A :func:`.check` that checks if the person invoking this command is the + owner of the bot. + + This is powered by :meth:`.Bot.is_owner`. + + This check raises a special exception, :exc:`.NotOwner` that is derived + from :exc:`.CheckFailure`. + """ + + async def predicate(ctx: Context) -> bool: + if not await ctx.bot.is_owner(ctx.author): + raise NotOwner("You do not own this bot.") + return True + + return check(predicate) + + +def is_nsfw() -> Callable[[T], T]: + """A :func:`.check` that checks if the channel is a NSFW channel. + + This check raises a special exception, :exc:`.NSFWChannelRequired` + that is derived from :exc:`.CheckFailure`. + + .. versionchanged:: 1.1 + + Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. + DM channels will also now pass this check. + """ + + def pred(ctx: Context) -> bool: + ch = ctx.channel + if ctx.guild is None or ( + isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() + ): + return True + raise NSFWChannelRequired(ch) # type: ignore + + return check(pred) + + +def cooldown( + rate: int, + per: float, + type: BucketType | Callable[[Message], Any] = BucketType.default, +) -> Callable[[T], T]: + """A decorator that adds a cooldown to a command + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-guild, per-channel, per-user, per-role or global basis. + Denoted by the third argument of ``type`` which must be of enum + type :class:`.BucketType`. + + If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in + :func:`.on_command_error` and the local error handler. + + A command can only have a single cooldown. + + Parameters + ---------- + rate: :class:`int` + The number of times a command can be used before triggering a cooldown. + per: :class:`float` + The amount of seconds to wait for a cooldown when it's been triggered. + type: Union[:class:`.BucketType`, Callable[[:class:`.Message`], Any]] + The type of cooldown to have. If callable, should return a key for the mapping. + + .. versionchanged:: 1.7 + Callables are now supported for custom bucket types. + """ + + def decorator(func: Command | CoroFunc) -> Command | CoroFunc: + if isinstance(func, (Command, ApplicationCommand)): + func._buckets = CooldownMapping(Cooldown(rate, per), type) + else: + func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) + return func + + return decorator # type: ignore + + +def dynamic_cooldown( + cooldown: BucketType | Callable[[Message], Any], + type: BucketType = BucketType.default, +) -> Callable[[T], T]: + """A decorator that adds a dynamic cooldown to a command + + This differs from :func:`.cooldown` in that it takes a function that + accepts a single parameter of type :class:`.discord.Message` and must + return a :class:`.Cooldown` or ``None``. If ``None`` is returned then + that cooldown is effectively bypassed. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-guild, per-channel, per-user, per-role or global basis. + Denoted by the third argument of ``type`` which must be of enum + type :class:`.BucketType`. + + If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in + :func:`.on_command_error` and the local error handler. + + A command can only have a single cooldown. + + .. versionadded:: 2.0 + + Parameters + ---------- + cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`.Cooldown`]] + A function that takes a message and returns a cooldown that will + apply to this invocation or ``None`` if the cooldown should be bypassed. + type: :class:`.BucketType` + The type of cooldown to have. + """ + if not callable(cooldown): + raise TypeError("A callable must be provided") + + def decorator(func: Command | CoroFunc) -> Command | CoroFunc: + if isinstance(func, Command): + func._buckets = DynamicCooldownMapping(cooldown, type) + else: + func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) + return func + + return decorator # type: ignore + + +def max_concurrency( + number: int, per: BucketType = BucketType.default, *, wait: bool = False +) -> Callable[[T], T]: + """A decorator that adds a maximum concurrency to a command + + This enables you to only allow a certain number of command invocations at the same time, + for example if a command takes too long or if only one user can use it at a time. This + differs from a cooldown in that there is no set waiting period or token bucket -- only + a set number of people can run the command. + + .. versionadded:: 1.3 + + Parameters + ---------- + number: :class:`int` + The maximum number of invocations of this command that can be running at the same time. + per: :class:`.BucketType` + The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow + it to be used up to ``number`` times per guild. + wait: :class:`bool` + Whether the command should wait for the queue to be over. If this is set to ``False`` + then instead of waiting until the command can run again, the command raises + :exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True`` + then the command waits until it can be executed. + """ + + def decorator(func: Command | CoroFunc) -> Command | CoroFunc: + value = MaxConcurrency(number, per=per, wait=wait) + if isinstance(func, (Command, ApplicationCommand)): + func._max_concurrency = value + else: + func.__commands_max_concurrency__ = value + return func + + return decorator # type: ignore + + +def before_invoke(coro) -> Callable[[T], T]: + """A decorator that registers a coroutine as a pre-invoke hook. + + This allows you to refer to one before invoke hook for several commands that + do not have to be within the same cog. + + .. versionadded:: 1.4 + + Example + ------- + + .. code-block:: python3 + + async def record_usage(ctx): + print(ctx.author, 'used', ctx.command, 'at', ctx.message.created_at) + + @bot.command() + @commands.before_invoke(record_usage) + async def who(ctx): # Output: used who at