Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
DoroWolf committed Dec 8, 2024
2 parents cc84bd1 + b2e7c98 commit c41d2a3
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 237 deletions.
22 changes: 12 additions & 10 deletions core/builtins/message/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
from datetime import datetime, UTC as datetimeUTC
from typing import Any, Coroutine, Dict, List, Optional, Union
Expand All @@ -22,18 +24,18 @@ class ExecutionLockList:
_list = set()

@staticmethod
def add(msg: 'MessageSession'):
def add(msg: MessageSession):
target_id = msg.target.sender_id
ExecutionLockList._list.add(target_id)

@staticmethod
def remove(msg: 'MessageSession'):
def remove(msg: MessageSession):
target_id = msg.target.sender_id
if target_id in ExecutionLockList._list:
ExecutionLockList._list.remove(target_id)

@staticmethod
def check(msg: 'MessageSession'):
def check(msg: MessageSession):
target_id = msg.target.sender_id
return target_id in ExecutionLockList._list

Expand All @@ -49,7 +51,7 @@ class MessageTaskManager:
@classmethod
def add_task(
cls,
session: 'MessageSession',
session: MessageSession,
flag: asyncio.Event,
all_: bool = False,
reply: Optional[Union[List[int], List[str], int, str]] = None,
Expand All @@ -73,7 +75,7 @@ def add_callback(cls, message_id: Union[List[int], List[str], int, str], callbac
cls._callback_list[message_id] = {'callback': callback, 'ts': datetime.now().timestamp()}

@classmethod
def get_result(cls, session: 'MessageSession'):
def get_result(cls, session: MessageSession):
if 'result' in cls._task_list[session.target.target_id][session.target.sender_id][session]:
return cls._task_list[session.target.target_id][session.target.sender_id][session]['result']
else:
Expand All @@ -98,7 +100,7 @@ async def bg_check(cls):
del cls._callback_list[message_id]

@classmethod
async def check(cls, session: 'MessageSession'):
async def check(cls, session: MessageSession):
if session.target.target_id in cls._task_list:
senders = []
if session.target.sender_id in cls._task_list[session.target.target_id]:
Expand Down Expand Up @@ -284,7 +286,7 @@ async def get_text_channel_list(self) -> List[str]:
raise NotImplementedError

class Typing:
def __init__(self, msg: 'MessageSession'):
def __init__(self, msg: MessageSession):
"""
:param msg: 本条消息,由于此class需要被一同传入下游方便调用,所以作为子class存在,将来可能会有其他的解决办法。
"""
Expand Down Expand Up @@ -346,7 +348,7 @@ async def wait_next_message(self,
quote: bool = True,
delete: bool = False,
timeout: Optional[float] = 120,
append_instruction: bool = True) -> 'MessageSession':
append_instruction: bool = True) -> MessageSession:
"""
一次性模板,用于等待对象的下一条消息。
Expand Down Expand Up @@ -386,7 +388,7 @@ async def wait_reply(self,
delete: bool = False,
timeout: Optional[float] = 120,
all_: bool = False,
append_instruction: bool = True) -> 'MessageSession':
append_instruction: bool = True) -> MessageSession:
"""
一次性模板,用于等待触发对象回复消息。
Expand Down Expand Up @@ -425,7 +427,7 @@ async def wait_anyone(self,
message_chain: Optional[Union[MessageChain, str, list, MessageElement]] = None,
quote: bool = False,
delete: bool = False,
timeout: Optional[float] = 120) -> 'MessageSession':
timeout: Optional[float] = 120) -> MessageSession:
"""
一次性模板,用于等待触发对象所属对话内任意成员确认。
Expand Down
6 changes: 4 additions & 2 deletions core/builtins/message/chain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import re
from typing import List, Optional, Tuple, Union, Any
Expand Down Expand Up @@ -32,7 +34,7 @@ def __init__(
List[MessageElement],
Tuple[MessageElement],
MessageElement,
'MessageChain'
MessageChain
]] = None,
):
"""
Expand Down Expand Up @@ -146,7 +148,7 @@ def unsafeprompt(name, secret, text):
return False
return True

def as_sendable(self, msg: 'MessageSession' = None, embed: bool = True) -> list:
def as_sendable(self, msg: MessageSession = None, embed: bool = True) -> list:
"""
将消息链转换为可发送的格式。
"""
Expand Down
30 changes: 17 additions & 13 deletions core/builtins/message/elements.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import base64
import os
import random
import re
from datetime import datetime, timezone
from typing import Tuple, Optional, TYPE_CHECKING, Dict, Any, Union, List
from typing import Optional, TYPE_CHECKING, Dict, Any, Union, List
from urllib import parse

import aiohttp
Expand All @@ -19,6 +21,8 @@
from core.utils.cache import random_cache_path
from core.utils.i18n import Locale

from copy import deepcopy

if TYPE_CHECKING:
from core.builtins import MessageSession

Expand Down Expand Up @@ -50,7 +54,7 @@ def assign(cls,
text = ''.join([str(x) for x in texts])
if not disable_joke:
text = joke(text)
return cls(text=text)
return deepcopy(cls(text=text))


@define
Expand All @@ -76,7 +80,7 @@ def assign(cls, url: str, use_mm: bool = False):
"nopqrstuvwxyzabcdefghijklmNOPQRSTUVWXYZABCDEFGHIJKLM")
url = mm_url % parse.quote(parse.unquote(url).translate(rot13))

return cls(url=url)
return deepcopy(cls(url=url))

def __str__(self):
if self.md_format:
Expand All @@ -98,7 +102,7 @@ class FormattedTimeElement(MessageElement):
seconds: bool = True
timezone: bool = True

def to_str(self, msg: Optional['MessageSession'] = None):
def to_str(self, msg: Optional[MessageSession] = None):
ftime_template = []
if msg:
if self.date:
Expand Down Expand Up @@ -145,7 +149,7 @@ def assign(cls, timestamp: float,
:param seconds: 是否显示秒。(默认为True)
:param timezone: 是否显示时区。(默认为True)
"""
return cls(timestamp=timestamp, date=date, iso=iso, time=time, seconds=seconds, timezone=timezone)
return deepcopy(cls(timestamp=timestamp, date=date, iso=iso, time=time, seconds=seconds, timezone=timezone))


@define
Expand All @@ -164,7 +168,7 @@ def assign(cls,
:param key: 多语言的键名。
:param kwargs: 多语言中的变量。
"""
return cls(key=key, kwargs=kwargs)
return deepcopy(cls(key=key, kwargs=kwargs))


@define
Expand Down Expand Up @@ -203,7 +207,7 @@ def assign(cls,
error_message += '\n' + \
locale.t('error.prompt.address', url=str(report_url))

return cls(error_message)
return deepcopy(cls(error_message))

def __str__(self):
return self.error_message
Expand Down Expand Up @@ -234,7 +238,7 @@ def assign(cls, path: Union[str, PILImage.Image],
path = save
elif re.match('^https?://.*', path):
need_get = True
return cls(path, need_get, headers)
return deepcopy(cls(path, need_get, headers))

async def get(self):
"""
Expand Down Expand Up @@ -294,7 +298,7 @@ def assign(cls, path: str):
"""
:param path: 语音路径。
"""
return cls(path)
return deepcopy(cls(path))


@define
Expand All @@ -317,7 +321,7 @@ def assign(cls, name: str, value: str, inline: bool = False):
:param value: 字段值。
:param inline: 是否内联。(默认为False)
"""
return cls(name=name, value=value, inline=inline)
return deepcopy(cls(name=name, value=value, inline=inline))


@define
Expand Down Expand Up @@ -355,7 +359,7 @@ def assign(cls, title: Optional[str] = None,
author: Optional[str] = None,
footer: Optional[str] = None,
fields: Optional[List[EmbedFieldElement]] = None):
return cls(
return deepcopy(cls(
title=title,
description=description,
url=url,
Expand All @@ -365,9 +369,9 @@ def assign(cls, title: Optional[str] = None,
thumbnail=thumbnail,
author=author,
footer=footer,
fields=fields)
fields=fields))

def to_message_chain(self, msg: Optional['MessageSession'] = None):
def to_message_chain(self, msg: Optional[MessageSession] = None):
"""
将Embed转换为消息链。
"""
Expand Down
32 changes: 16 additions & 16 deletions core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,22 +184,22 @@ def module(
:param exclude_from: 此命令排除的平台列表。
:param support_languages: 此命令支持的语言列表。
"""
module = Module(alias=alias,
bind_prefix=bind_prefix,
desc=desc,
recommend_modules=recommend_modules,
developers=developers,
base=base,
doc=doc,
hidden=hidden,
load=load,
rss=rss,
required_admin=required_admin,
required_superuser=required_superuser,
required_base_superuser=required_base_superuser,
available_for=available_for,
exclude_from=exclude_from,
support_languages=support_languages)
module = Module.assign(alias=alias,
bind_prefix=bind_prefix,
desc=desc,
recommend_modules=recommend_modules,
developers=developers,
base=base,
doc=doc,
hidden=hidden,
load=load,
rss=rss,
required_admin=required_admin,
required_superuser=required_superuser,
required_base_superuser=required_base_superuser,
available_for=available_for,
exclude_from=exclude_from,
support_languages=support_languages)
frame = inspect.currentframe()
ModulesManager.add_module(module, frame.f_back.f_globals["__name__"])
return Bind.Module(bind_prefix)
8 changes: 4 additions & 4 deletions core/parser/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,10 @@ async def execute_submodule(msg: Bot.MessageSession, command_first_word, command

if Config('bug_report_url', bug_report_url_default, cfg_type=str):
errmsg += '\n' + msg.locale.t('error.prompt.address',
url=str(Url(Config('bug_report_url',
bug_report_url_default,
cfg_type=str),
use_mm=False)))
url=Url(Config('bug_report_url',
bug_report_url_default,
cfg_type=str),
use_mm=False))
await msg.send_message(errmsg)

if not timeout and report_targets:
Expand Down
4 changes: 2 additions & 2 deletions core/types/message/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

@define
class MsgInfo:
target_id: Union[int, str]
sender_id: Union[int, str]
target_id: str
sender_id: str
sender_prefix: str
target_from: str
sender_from: str
Expand Down
84 changes: 37 additions & 47 deletions core/types/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,47 @@

from .component_matches import *

from .utils import convert2lst

def convert2lst(elements: Union[str, list, tuple]) -> list:
if isinstance(elements, str):
return [elements]
elif isinstance(elements, tuple):
return list(elements)
return elements
from attrs import define, field, Converter

from copy import deepcopy


def alias_converter(value, _self) -> dict:
if isinstance(value, str):
return {value: _self.bind_prefix}
elif isinstance(value, (tuple, list)):
return {x: _self.bind_prefix for x in value}
return value


@define
class Module:
def __init__(self,
bind_prefix: str,
alias: Union[str, list, tuple, dict, None] = None,
desc: str = None,
recommend_modules: Union[str, list, tuple, None] = None,
developers: Union[str, list, tuple, None] = None,
required_admin: bool = False,
base: bool = False,
doc: bool = False,
hidden: bool = False,
load: bool = True,
rss: bool = False,
required_superuser: bool = False,
required_base_superuser: bool = False,
available_for: Union[str, list, tuple, None] = '*',
exclude_from: Union[str, list, tuple, None] = '',
support_languages: Union[str, list, tuple, None] = None):
self.bind_prefix: str = bind_prefix
if isinstance(alias, str):
alias = {alias: bind_prefix}
elif isinstance(alias, (tuple, list)):
alias = {x: bind_prefix for x in alias}
self.alias: Dict[str, str] = alias
self.desc: str = desc
self.recommend_modules: List[str] = convert2lst(recommend_modules)
self.developers: List[str] = convert2lst(developers)
self.required_admin: bool = required_admin
self.base: bool = base
self.doc: bool = doc
self.hidden: bool = hidden
self.load: bool = load
self.rss: bool = rss
self.required_superuser: bool = required_superuser
self.required_base_superuser: bool = required_base_superuser
self.available_for: List[str] = convert2lst(available_for)
self.exclude_from: List[str] = convert2lst(exclude_from)
self.support_languages: List[str] = convert2lst(support_languages)
self.command_list = CommandMatches()
self.regex_list = RegexMatches()
self.schedule_list = ScheduleMatches()
self.hooks_list = HookMatches()
bind_prefix: str
alias: dict = field(converter=Converter(alias_converter, takes_self=True))
recommend_modules: list = field(converter=convert2lst)
developers: list = field(converter=convert2lst)
available_for: list = field(default=['*'], converter=convert2lst)
exclude_from: list = field(default=[], converter=convert2lst)
support_languages: list = field(default=None, converter=convert2lst)
desc: Union[str] = ''
required_admin: bool = False
base: bool = False
doc: bool = False
hidden: bool = False
load: bool = True
rss: bool = False
required_superuser: bool = False
required_base_superuser: bool = False
command_list: CommandMatches = CommandMatches.init()
regex_list: RegexMatches = RegexMatches.init()
schedule_list: ScheduleMatches = ScheduleMatches.init()
hooks_list: HookMatches = HookMatches.init()

@classmethod
def assign(cls, **kwargs):
return deepcopy(cls(**kwargs))


__all__ = ["Module", "AndTrigger", "OrTrigger", "DateTrigger",
Expand Down
Loading

0 comments on commit c41d2a3

Please sign in to comment.