Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #106

Merged
merged 2 commits into from
Jul 31, 2024
Merged

Dev #106

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions amiyabot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def __init__(
if not appid:
appid = random_code(10)

super().__init__(appid, token, adapter)
super().__init__(appid, token, adapter, private)

self.private = private
self.send_message = self.instance.send_message

self.__closed = False
Expand All @@ -71,7 +70,7 @@ async def start(self, launch_browser: typing.Union[bool, BrowserLaunchConfig] =
await basic_browser_service.launch(BrowserLaunchConfig() if launch_browser is True else launch_browser)

self.run_timed_tasks()
await self.instance.start(self.private, self.__message_handler)
await self.instance.start(self.__message_handler)

async def close(self):
if not self.__closed:
Expand Down
9 changes: 6 additions & 3 deletions amiyabot/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@


class BotAdapterProtocol:
def __init__(self, appid: str, token: str):
def __init__(self, appid: str, token: str, private: bool = False):
self.appid = appid
self.token = token
self.alive = False
self.keep_run = True

self.private = private

# 适配器实例连接信息
self.host: Optional[str] = None
self.ws_port: Optional[int] = None
self.http_port: Optional[int] = None
self.session: Optional[str] = None
self.headers: Optional[dict] = None

self.bot_name = ''

self.log = LoggerManager(self.__str__())
self.bot: Optional[T_BotHandlerFactory] = None

Expand Down Expand Up @@ -73,11 +77,10 @@ async def close(self):
raise NotImplementedError

@abc.abstractmethod
async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
"""
启动实例,执行 handler 方法处理消息

:param private: 是否私域机器人
:param handler: 消息处理方法
"""
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/comwechat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ComWeChatBotInstance(OneBot12Instance):
def __str__(self):
return 'ComWeChat'

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler, package_method=package_com_wechat_message)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/kook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def api(self):
def __still_alive(self):
return self.keep_run and self.connection

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
me_req = await self.api.get_me()
if me_req:
self.appid = me_req.json['data']['id']
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/mirai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/onebot/v11/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/onebot/v12/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
16 changes: 12 additions & 4 deletions amiyabot/adapters/tencent/qqGlobal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@


class QQGlobalBotInstance(QQGroupBotInstance):
def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder):
super().__init__(appid, token, client_secret, default_chain_builder)

self.guild = QQGuildBotInstance(appid, token)
def __init__(
self,
appid: str,
token: str,
client_secret: str,
default_chain_builder: ChainBuilder,
shard_index: int,
shards: int,
):
super().__init__(appid, token, client_secret, default_chain_builder, shard_index, shards)

self.guild = QQGuildBotInstance(appid, token, shard_index, shards)

def __str__(self):
return 'QQGlobal'
Expand Down
20 changes: 15 additions & 5 deletions amiyabot/adapters/tencent/qqGroup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@


class QQGroupBotInstance(QQGuildBotInstance):
def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder):
super().__init__(appid, token)
def __init__(
self,
appid: str,
token: str,
client_secret: str,
default_chain_builder: ChainBuilder,
shard_index: int,
shards: int,
):
super().__init__(appid, token, shard_index, shards)

self.__access_token_api = QQGroupAPI(self.appid, self.token, client_secret)
self.__default_chain_builder = default_chain_builder
Expand All @@ -33,14 +41,16 @@ def build_adapter(
client_secret: str,
default_chain_builder: Optional[ChainBuilder] = None,
default_chain_builder_options: QQGroupChainBuilderOptions = QQGroupChainBuilderOptions(),
shard_index: int = 0,
shards: int = 1,
):
def adapter(appid: str, token: str):
if default_chain_builder:
cb = default_chain_builder
else:
cb = QQGroupChainBuilder(default_chain_builder_options)

return cls(appid, token, client_secret, cb)
return cls(appid, token, client_secret, cb, shard_index, shards)

return adapter

Expand All @@ -52,14 +62,14 @@ def api(self):
def package_method(self):
return package_qq_group_message

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
if hasattr(self.__default_chain_builder, 'start'):
self.__default_chain_builder.start()

if not self.__seq_service.alive:
asyncio.create_task(self.__seq_service.run())

await super().start(private, handler)
await super().start(handler)

async def close(self):
await self.__seq_service.stop()
Expand Down
95 changes: 53 additions & 42 deletions amiyabot/adapters/tencent/qqGuild/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,28 @@
from amiyabot.adapters.tencent.intents import get_intents

from .api import QQGuildAPI, log
from .model import GateWay, Payload, ShardsRecord, ConnectionHandler
from .model import GateWay, Payload, ConnectionModel, ConnectionHandler
from .package import package_qq_guild_message
from .builder import build_message_send, QQGuildMessageCallback


def qq_guild_shards(shard_index: int, shards: int):
def adapter(appid: str, token: str):
return QQGuildBotInstance(appid, token, shard_index, shards)

return adapter


class QQGuildBotInstance(BotAdapterProtocol):
def __init__(self, appid: str, token: str):
def __init__(self, appid: str, token: str, shard_index: int = 0, shards: int = 1):
super().__init__(appid, token)

self.appid = appid
self.token = token
self.bot_name = ''
self.shard_index = shard_index
self.shards = shards

self.shards_record: Dict[int, ShardsRecord] = {}
self.model: Optional[ConnectionModel] = None

def __str__(self):
return 'QQGuild'
Expand All @@ -37,49 +45,61 @@ def api(self):
def package_method(self):
return package_qq_guild_message

def __create_heartbeat(self, websocket, interval: int, record: ShardsRecord):
def __create_heartbeat(self, websocket, interval: int):
heartbeat_key = random_code(10)
record.heartbeat_key = heartbeat_key
asyncio.create_task(self.heartbeat_interval(websocket, interval, record.shards_index, heartbeat_key))
self.model.heartbeat_key = heartbeat_key
asyncio.create_task(
self.heartbeat_interval(
websocket,
interval,
heartbeat_key,
)
)

async def close(self):
log.info(f'closing {self}(appid {self.appid})...')
self.keep_run = False

for _, item in self.shards_record.items():
if item.connection:
await item.connection.close()
if self.model:
await self.model.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
log.info(f'requesting appid {self.appid} gateway')

resp = await self.api.gateway_bot()

if not resp or 'url' not in resp.json:
if self.keep_run:
await asyncio.sleep(10)
asyncio.create_task(self.start(private, handler))
asyncio.create_task(self.start(handler))
return False

gateway = GateWay(**resp.json)

log.info(
f'appid {self.appid} gateway resp: shards {gateway.shards}, remaining %d/%d'
f'appid {self.appid} gateway resp: shards {gateway.shards}, max_concurrency %d, remaining %d/%d'
% (
gateway.session_start_limit['max_concurrency'],
gateway.session_start_limit['remaining'],
gateway.session_start_limit['total'],
)
)

await self.create_connection(ConnectionHandler(private=private, gateway=gateway, message_handler=handler))
await self.create_connection(
ConnectionHandler(
private=self.private,
gateway=gateway,
message_handler=handler,
)
)

async def create_connection(self, handler: ConnectionHandler, shards_index: int = 0):
async def create_connection(self, handler: ConnectionHandler):
gateway = handler.gateway
sign = f'{self.appid} {shards_index + 1}/{gateway.shards}'
sign = f'{self.appid} {self.shard_index + 1}/{self.shards}'

async with self.get_websocket_connection(sign, gateway.url) as websocket:
if websocket:
self.shards_record[shards_index] = ShardsRecord(shards_index, connection=websocket)
self.model = ConnectionModel(connection=websocket)

while self.keep_run:
await asyncio.sleep(0)
Expand All @@ -94,19 +114,15 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int
f'connected({sign}): {self.bot_name}({self}-%s)'
% ('private' if handler.private else 'public')
)
self.shards_record[shards_index].session_id = payload.d['session_id']

if shards_index == 0 and gateway.shards > 1:
for n in range(gateway.shards - 1):
asyncio.create_task(self.create_connection(handler, n + 1))
self.model.session_id = payload.d['session_id']
else:
await self.create_package_task(handler, payload)

if payload.op == 10:
create_token = {
'token': f'Bot {self.appid}.{self.token}',
'intents': get_intents(handler.private, self.__str__()),
'shard': [shards_index, gateway.shards],
'shard': [self.shard_index, self.shards],
'properties': {
'$os': sys.platform,
'$browser': '',
Expand All @@ -115,25 +131,21 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int
}
await websocket.send(Payload(op=2, d=create_token).to_json())

self.__create_heartbeat(
websocket,
payload.d['heartbeat_interval'],
self.shards_record[shards_index],
)
self.__create_heartbeat(websocket, payload.d['heartbeat_interval'])

if payload.s:
self.shards_record[shards_index].last_s = payload.s
self.model.last_s = payload.s

while self.keep_run and self.shards_record[shards_index].reconnect_limit > 0:
await self.reconnect(handler, self.shards_record[shards_index], sign)
while self.keep_run and self.model.reconnect_limit > 0:
await self.reconnect(handler, sign)
await asyncio.sleep(1)

async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign: str):
async def reconnect(self, handler: ConnectionHandler, sign: str):
log.info(f'reconnecting({sign})...')

async with self.get_websocket_connection(sign, handler.gateway.url) as websocket:
if websocket:
record.connection = websocket
self.model.connection = websocket

while self.keep_run:
await asyncio.sleep(0)
Expand All @@ -150,34 +162,33 @@ async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign
if payload.op == 10:
reconnect_token = {
'token': f'Bot {self.appid}.{self.token}',
'session_id': record.session_id,
'seq': record.last_s,
'session_id': self.model.session_id,
'seq': self.model.last_s,
}
await websocket.send(Payload(op=6, d=reconnect_token).to_json())

self.__create_heartbeat(websocket, payload.d['heartbeat_interval'], record)
self.__create_heartbeat(websocket, payload.d['heartbeat_interval'])

record.reconnect_limit = 3
self.model.reconnect_limit = 3

if payload.s:
record.last_s = payload.s
self.model.last_s = payload.s

record.reconnect_limit -= 1
self.model.reconnect_limit -= 1

async def heartbeat_interval(
self,
websocket: WebSocketClientProtocol,
interval: int,
shards_index: int,
heartbeat_key: str,
):
sec = 0
while self.keep_run and self.shards_record[shards_index].heartbeat_key == heartbeat_key:
while self.keep_run and self.model.heartbeat_key == heartbeat_key:
await asyncio.sleep(1)
sec += 1
if sec >= interval / 1000:
sec = 0
await websocket.send(Payload(op=1, d=self.shards_record[shards_index].last_s).to_json())
await websocket.send(Payload(op=1, d=self.model.last_s).to_json())

async def create_package_task(self, handler: ConnectionHandler, payload: Payload):
asyncio.create_task(
Expand Down
3 changes: 1 addition & 2 deletions amiyabot/adapters/tencent/qqGuild/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class ConnectionHandler:


@dataclass
class ShardsRecord:
shards_index: int
class ConnectionModel:
session_id: Optional[str] = None
last_s: Optional[int] = None
reconnect_limit: int = 3
Expand Down
Loading
Loading