From a281421a933e6eaaff9ed0987ac1b7968992313b Mon Sep 17 00:00:00 2001 From: Vivien <34387011+vivien8261@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:26:31 +0800 Subject: [PATCH] =?UTF-8?q?update:=20=E9=87=8D=E5=86=99=E4=BA=86=E5=AE=98?= =?UTF-8?q?=E6=96=B9qq=E6=8E=A5=E5=8F=A3=E7=9A=84=E5=88=86=E7=89=87?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- amiyabot/__init__.py | 5 +- amiyabot/adapters/__init__.py | 9 +- amiyabot/adapters/comwechat/__init__.py | 2 +- amiyabot/adapters/kook/__init__.py | 2 +- amiyabot/adapters/mirai/__init__.py | 2 +- amiyabot/adapters/onebot/v11/__init__.py | 2 +- amiyabot/adapters/onebot/v12/__init__.py | 2 +- .../adapters/tencent/qqGlobal/__init__.py | 16 +++- amiyabot/adapters/tencent/qqGroup/__init__.py | 20 +++- amiyabot/adapters/tencent/qqGuild/__init__.py | 95 +++++++++++-------- amiyabot/adapters/tencent/qqGuild/model.py | 3 +- amiyabot/adapters/test/__init__.py | 2 +- amiyabot/factory/__init__.py | 2 + 13 files changed, 97 insertions(+), 65 deletions(-) diff --git a/amiyabot/__init__.py b/amiyabot/__init__.py index c62dca7..24a6dcd 100644 --- a/amiyabot/__init__.py +++ b/amiyabot/__init__.py @@ -55,9 +55,8 @@ def __init__( if not appid: appid = random_code(10) - super().__init__(appid, token, adapter) + super().__init__(appid, token, adapter, private) - self.private = private self.send_message = self.instance.send_message self.__closed = False @@ -71,7 +70,7 @@ async def start(self, launch_browser: typing.Union[bool, BrowserLaunchConfig] = await basic_browser_service.launch(BrowserLaunchConfig() if launch_browser is True else launch_browser) self.run_timed_tasks() - await self.instance.start(self.private, self.__message_handler) + await self.instance.start(self.__message_handler) async def close(self): if not self.__closed: diff --git a/amiyabot/adapters/__init__.py b/amiyabot/adapters/__init__.py index 042b879..c6c0dc0 100644 --- a/amiyabot/adapters/__init__.py +++ b/amiyabot/adapters/__init__.py @@ -17,12 +17,14 @@ class BotAdapterProtocol: - def __init__(self, appid: str, token: str): + def __init__(self, appid: str, token: str, private: bool = False): self.appid = appid self.token = token self.alive = False self.keep_run = True + self.private = private + # 适配器实例连接信息 self.host: Optional[str] = None self.ws_port: Optional[int] = None @@ -30,6 +32,8 @@ def __init__(self, appid: str, token: str): self.session: Optional[str] = None self.headers: Optional[dict] = None + self.bot_name = '' + self.log = LoggerManager(self.__str__()) self.bot: Optional[T_BotHandlerFactory] = None @@ -73,11 +77,10 @@ async def close(self): raise NotImplementedError @abc.abstractmethod - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): """ 启动实例,执行 handler 方法处理消息 - :param private: 是否私域机器人 :param handler: 消息处理方法 """ raise NotImplementedError diff --git a/amiyabot/adapters/comwechat/__init__.py b/amiyabot/adapters/comwechat/__init__.py index a66e152..e60340d 100644 --- a/amiyabot/adapters/comwechat/__init__.py +++ b/amiyabot/adapters/comwechat/__init__.py @@ -19,7 +19,7 @@ class ComWeChatBotInstance(OneBot12Instance): def __str__(self): return 'ComWeChat' - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): while self.keep_run: await self.keep_connect(handler, package_method=package_com_wechat_message) await asyncio.sleep(10) diff --git a/amiyabot/adapters/kook/__init__.py b/amiyabot/adapters/kook/__init__.py index 901b8ce..9448d80 100644 --- a/amiyabot/adapters/kook/__init__.py +++ b/amiyabot/adapters/kook/__init__.py @@ -37,7 +37,7 @@ def api(self): def __still_alive(self): return self.keep_run and self.connection - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): me_req = await self.api.get_me() if me_req: self.appid = me_req.json['data']['id'] diff --git a/amiyabot/adapters/mirai/__init__.py b/amiyabot/adapters/mirai/__init__.py index ed55ca0..a558b29 100644 --- a/amiyabot/adapters/mirai/__init__.py +++ b/amiyabot/adapters/mirai/__init__.py @@ -51,7 +51,7 @@ async def close(self): if self.connection: await self.connection.close() - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): while self.keep_run: await self.keep_connect(handler) await asyncio.sleep(10) diff --git a/amiyabot/adapters/onebot/v11/__init__.py b/amiyabot/adapters/onebot/v11/__init__.py index 1418c4e..bef29a6 100644 --- a/amiyabot/adapters/onebot/v11/__init__.py +++ b/amiyabot/adapters/onebot/v11/__init__.py @@ -56,7 +56,7 @@ async def close(self): if self.connection: await self.connection.close() - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): while self.keep_run: await self.keep_connect(handler) await asyncio.sleep(10) diff --git a/amiyabot/adapters/onebot/v12/__init__.py b/amiyabot/adapters/onebot/v12/__init__.py index 33559d8..1542960 100644 --- a/amiyabot/adapters/onebot/v12/__init__.py +++ b/amiyabot/adapters/onebot/v12/__init__.py @@ -56,7 +56,7 @@ async def close(self): if self.connection: await self.connection.close() - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): while self.keep_run: await self.keep_connect(handler) await asyncio.sleep(10) diff --git a/amiyabot/adapters/tencent/qqGlobal/__init__.py b/amiyabot/adapters/tencent/qqGlobal/__init__.py index bcbea5a..1ec50db 100644 --- a/amiyabot/adapters/tencent/qqGlobal/__init__.py +++ b/amiyabot/adapters/tencent/qqGlobal/__init__.py @@ -6,10 +6,18 @@ class QQGlobalBotInstance(QQGroupBotInstance): - def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder): - super().__init__(appid, token, client_secret, default_chain_builder) - - self.guild = QQGuildBotInstance(appid, token) + def __init__( + self, + appid: str, + token: str, + client_secret: str, + default_chain_builder: ChainBuilder, + shard_index: int, + shards: int, + ): + super().__init__(appid, token, client_secret, default_chain_builder, shard_index, shards) + + self.guild = QQGuildBotInstance(appid, token, shard_index, shards) def __str__(self): return 'QQGlobal' diff --git a/amiyabot/adapters/tencent/qqGroup/__init__.py b/amiyabot/adapters/tencent/qqGroup/__init__.py index 94aa8a1..7185513 100644 --- a/amiyabot/adapters/tencent/qqGroup/__init__.py +++ b/amiyabot/adapters/tencent/qqGroup/__init__.py @@ -17,8 +17,16 @@ class QQGroupBotInstance(QQGuildBotInstance): - def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder): - super().__init__(appid, token) + def __init__( + self, + appid: str, + token: str, + client_secret: str, + default_chain_builder: ChainBuilder, + shard_index: int, + shards: int, + ): + super().__init__(appid, token, shard_index, shards) self.__access_token_api = QQGroupAPI(self.appid, self.token, client_secret) self.__default_chain_builder = default_chain_builder @@ -33,6 +41,8 @@ def build_adapter( client_secret: str, default_chain_builder: Optional[ChainBuilder] = None, default_chain_builder_options: QQGroupChainBuilderOptions = QQGroupChainBuilderOptions(), + shard_index: int = 0, + shards: int = 1, ): def adapter(appid: str, token: str): if default_chain_builder: @@ -40,7 +50,7 @@ def adapter(appid: str, token: str): else: cb = QQGroupChainBuilder(default_chain_builder_options) - return cls(appid, token, client_secret, cb) + return cls(appid, token, client_secret, cb, shard_index, shards) return adapter @@ -52,14 +62,14 @@ def api(self): def package_method(self): return package_qq_group_message - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): if hasattr(self.__default_chain_builder, 'start'): self.__default_chain_builder.start() if not self.__seq_service.alive: asyncio.create_task(self.__seq_service.run()) - await super().start(private, handler) + await super().start(handler) async def close(self): await self.__seq_service.stop() diff --git a/amiyabot/adapters/tencent/qqGuild/__init__.py b/amiyabot/adapters/tencent/qqGuild/__init__.py index 7b67f17..93a5423 100644 --- a/amiyabot/adapters/tencent/qqGuild/__init__.py +++ b/amiyabot/adapters/tencent/qqGuild/__init__.py @@ -11,20 +11,28 @@ from amiyabot.adapters.tencent.intents import get_intents from .api import QQGuildAPI, log -from .model import GateWay, Payload, ShardsRecord, ConnectionHandler +from .model import GateWay, Payload, ConnectionModel, ConnectionHandler from .package import package_qq_guild_message from .builder import build_message_send, QQGuildMessageCallback +def qq_guild_shards(shard_index: int, shards: int): + def adapter(appid: str, token: str): + return QQGuildBotInstance(appid, token, shard_index, shards) + + return adapter + + class QQGuildBotInstance(BotAdapterProtocol): - def __init__(self, appid: str, token: str): + def __init__(self, appid: str, token: str, shard_index: int = 0, shards: int = 1): super().__init__(appid, token) self.appid = appid self.token = token - self.bot_name = '' + self.shard_index = shard_index + self.shards = shards - self.shards_record: Dict[int, ShardsRecord] = {} + self.model: Optional[ConnectionModel] = None def __str__(self): return 'QQGuild' @@ -37,20 +45,25 @@ def api(self): def package_method(self): return package_qq_guild_message - def __create_heartbeat(self, websocket, interval: int, record: ShardsRecord): + def __create_heartbeat(self, websocket, interval: int): heartbeat_key = random_code(10) - record.heartbeat_key = heartbeat_key - asyncio.create_task(self.heartbeat_interval(websocket, interval, record.shards_index, heartbeat_key)) + self.model.heartbeat_key = heartbeat_key + asyncio.create_task( + self.heartbeat_interval( + websocket, + interval, + heartbeat_key, + ) + ) async def close(self): log.info(f'closing {self}(appid {self.appid})...') self.keep_run = False - for _, item in self.shards_record.items(): - if item.connection: - await item.connection.close() + if self.model: + await self.model.connection.close() - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): log.info(f'requesting appid {self.appid} gateway') resp = await self.api.gateway_bot() @@ -58,28 +71,35 @@ async def start(self, private: bool, handler: HANDLER_TYPE): if not resp or 'url' not in resp.json: if self.keep_run: await asyncio.sleep(10) - asyncio.create_task(self.start(private, handler)) + asyncio.create_task(self.start(handler)) return False gateway = GateWay(**resp.json) log.info( - f'appid {self.appid} gateway resp: shards {gateway.shards}, remaining %d/%d' + f'appid {self.appid} gateway resp: shards {gateway.shards}, max_concurrency %d, remaining %d/%d' % ( + gateway.session_start_limit['max_concurrency'], gateway.session_start_limit['remaining'], gateway.session_start_limit['total'], ) ) - await self.create_connection(ConnectionHandler(private=private, gateway=gateway, message_handler=handler)) + await self.create_connection( + ConnectionHandler( + private=self.private, + gateway=gateway, + message_handler=handler, + ) + ) - async def create_connection(self, handler: ConnectionHandler, shards_index: int = 0): + async def create_connection(self, handler: ConnectionHandler): gateway = handler.gateway - sign = f'{self.appid} {shards_index + 1}/{gateway.shards}' + sign = f'{self.appid} {self.shard_index + 1}/{self.shards}' async with self.get_websocket_connection(sign, gateway.url) as websocket: if websocket: - self.shards_record[shards_index] = ShardsRecord(shards_index, connection=websocket) + self.model = ConnectionModel(connection=websocket) while self.keep_run: await asyncio.sleep(0) @@ -94,11 +114,7 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int f'connected({sign}): {self.bot_name}({self}-%s)' % ('private' if handler.private else 'public') ) - self.shards_record[shards_index].session_id = payload.d['session_id'] - - if shards_index == 0 and gateway.shards > 1: - for n in range(gateway.shards - 1): - asyncio.create_task(self.create_connection(handler, n + 1)) + self.model.session_id = payload.d['session_id'] else: await self.create_package_task(handler, payload) @@ -106,7 +122,7 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int create_token = { 'token': f'Bot {self.appid}.{self.token}', 'intents': get_intents(handler.private, self.__str__()), - 'shard': [shards_index, gateway.shards], + 'shard': [self.shard_index, self.shards], 'properties': { '$os': sys.platform, '$browser': '', @@ -115,25 +131,21 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int } await websocket.send(Payload(op=2, d=create_token).to_json()) - self.__create_heartbeat( - websocket, - payload.d['heartbeat_interval'], - self.shards_record[shards_index], - ) + self.__create_heartbeat(websocket, payload.d['heartbeat_interval']) if payload.s: - self.shards_record[shards_index].last_s = payload.s + self.model.last_s = payload.s - while self.keep_run and self.shards_record[shards_index].reconnect_limit > 0: - await self.reconnect(handler, self.shards_record[shards_index], sign) + while self.keep_run and self.model.reconnect_limit > 0: + await self.reconnect(handler, sign) await asyncio.sleep(1) - async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign: str): + async def reconnect(self, handler: ConnectionHandler, sign: str): log.info(f'reconnecting({sign})...') async with self.get_websocket_connection(sign, handler.gateway.url) as websocket: if websocket: - record.connection = websocket + self.model.connection = websocket while self.keep_run: await asyncio.sleep(0) @@ -150,34 +162,33 @@ async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign if payload.op == 10: reconnect_token = { 'token': f'Bot {self.appid}.{self.token}', - 'session_id': record.session_id, - 'seq': record.last_s, + 'session_id': self.model.session_id, + 'seq': self.model.last_s, } await websocket.send(Payload(op=6, d=reconnect_token).to_json()) - self.__create_heartbeat(websocket, payload.d['heartbeat_interval'], record) + self.__create_heartbeat(websocket, payload.d['heartbeat_interval']) - record.reconnect_limit = 3 + self.model.reconnect_limit = 3 if payload.s: - record.last_s = payload.s + self.model.last_s = payload.s - record.reconnect_limit -= 1 + self.model.reconnect_limit -= 1 async def heartbeat_interval( self, websocket: WebSocketClientProtocol, interval: int, - shards_index: int, heartbeat_key: str, ): sec = 0 - while self.keep_run and self.shards_record[shards_index].heartbeat_key == heartbeat_key: + while self.keep_run and self.model.heartbeat_key == heartbeat_key: await asyncio.sleep(1) sec += 1 if sec >= interval / 1000: sec = 0 - await websocket.send(Payload(op=1, d=self.shards_record[shards_index].last_s).to_json()) + await websocket.send(Payload(op=1, d=self.model.last_s).to_json()) async def create_package_task(self, handler: ConnectionHandler, payload: Payload): asyncio.create_task( diff --git a/amiyabot/adapters/tencent/qqGuild/model.py b/amiyabot/adapters/tencent/qqGuild/model.py index 1ce5391..98a876c 100644 --- a/amiyabot/adapters/tencent/qqGuild/model.py +++ b/amiyabot/adapters/tencent/qqGuild/model.py @@ -22,8 +22,7 @@ class ConnectionHandler: @dataclass -class ShardsRecord: - shards_index: int +class ConnectionModel: session_id: Optional[str] = None last_s: Optional[int] = None reconnect_limit: int = 3 diff --git a/amiyabot/adapters/test/__init__.py b/amiyabot/adapters/test/__init__.py index 2444b60..3267bb1 100644 --- a/amiyabot/adapters/test/__init__.py +++ b/amiyabot/adapters/test/__init__.py @@ -48,7 +48,7 @@ def __str__(self): async def close(self): ... - async def start(self, private: bool, handler: HANDLER_TYPE): + async def start(self, handler: HANDLER_TYPE): await self.server.run(handler) async def build_active_message_chain(self, chain: Chain, user_id: str, channel_id: str, direct_src_guild_id: str): diff --git a/amiyabot/factory/__init__.py b/amiyabot/factory/__init__.py index 32600d4..4982919 100644 --- a/amiyabot/factory/__init__.py +++ b/amiyabot/factory/__init__.py @@ -20,6 +20,7 @@ def __init__( appid: Optional[str] = None, token: Optional[str] = None, adapter: Optional[Type[BotAdapterProtocol]] = None, + private: bool = False, ): super().__init__() @@ -31,6 +32,7 @@ def __init__( if adapter: self.instance = adapter(appid, token) self.instance.bot = self + self.instance.private = private @property def prefix_keywords(self) -> PrefixKeywords: