-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcore.py
234 lines (202 loc) · 8.93 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import pkgutil
import traceback
from contextvars import ContextVar
from pathlib import Path
from typing import Literal, Union
from arknights_toolkit.update.main import fetch
from avilla.core import Context
from creart import it
from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.entities.signatures import Force
from graia.broadcast.exceptions import PropagationCancelled, RequirementCrashed
from graia.broadcast.interfaces.dispatcher import DispatcherInterface
from graia.saya import Saya
from launart import Launart, Service
from loguru import logger
from .config import BasePluginConfig, BotConfig, RaianConfig, extract_plugin_config
from .cos import CosConfig, put_object
from .database import DatabaseService
from .statistic import Statistic, commit
# Context variable holding the RaianBotService that is currently running.
# RaianBotService.launch sets it (and resets it afterwards); RaianBotService.current() reads it.
# It has no default, so reading it outside a running service raises LookupError.
BotServiceCtx: ContextVar["RaianBotService"] = ContextVar("bot_service")
class RaianBotService(Service):
    """Core launart service of the bot.

    Owns the ``RaianConfig``, registers a ``DatabaseService`` with the launart manager,
    fetches Arknights data and loads plugins through Saya during the ``preparing`` stage,
    and keeps an in-memory ``cache`` that also backs the function record/disabled registries.
    """

    id = "raian.core.service"
    config: RaianConfig
    db: DatabaseService

    def __init__(self, config: RaianConfig):
        super().__init__()
        self.config = config
        # Create the data directory up front; for sqlite the database file is placed inside it.
        (Path.cwd() / self.config.data_dir).mkdir(parents=True, exist_ok=True)
        # Generic in-memory store; cleared during the cleanup stage.
        self.cache = {}

    def ensure_manager(self, manager: Launart):
        """Attach to ``manager`` and register a ``DatabaseService`` built from the database config."""
        super().ensure_manager(manager)
        if self.config.database.type == "sqlite":
            # Put the sqlite file under data_dir and force a ``.db`` suffix.
            # NOTE(review): this rewrites config.database.name in place, so a second call
            # would prefix data_dir twice — confirm ensure_manager only runs once.
            self.config.database.name = f"{self.config.data_dir}/{self.config.database.name}"
            if not self.config.database.name.endswith(".db"):
                self.config.database.name = f"{self.config.database.name}.db"
        manager.add_component(
            db := DatabaseService(
                self.config.database.url,
                {"echo": None, "pool_pre_ping": True},
            )
        )
        self.db = db

    @property
    def required(self) -> set[str]:
        """This service declares no dependencies on other launart components."""
        return set()

    @property
    def stages(self) -> set[Literal["preparing", "blocking", "cleanup"]]:
        """Lifecycle stages this service takes part in (it has no ``blocking`` stage)."""
        return {"preparing", "cleanup"}

    @classmethod
    def current(cls):
        """Return the bot service bound to the current context.

        Raises:
            LookupError: if no service has been bound (``BotServiceCtx`` has no default).
        """
        return BotServiceCtx.get()

    async def launch(self, manager: Launart):
        """Run the service lifecycle.

        preparing: fetch Arknights data, then load all enabled plugins. If the fetch fails,
        the manager is asked to exit and the method returns without entering cleanup.
        cleanup: clear the in-memory cache.

        ``BotServiceCtx`` points at this service while the lifecycle runs and is reset on
        every exit path, including the early return and exceptions.
        """
        token = BotServiceCtx.set(self)
        try:
            async with self.stage("preparing"):
                logger.success("机器人数据加载完毕")
                if not await fetch(proxy=self.config.proxy):
                    logger.error("方舟数据获取失败")
                    # Ask launart to shut everything down; the cleanup stage below is skipped.
                    manager.status.exiting = True
                    return
                self._load_plugins()
            async with self.stage("cleanup"):
                self.cache.clear()
                logger.success("机器人数据保存完毕")
        finally:
            # Previously only the normal path reset the token; do it unconditionally.
            BotServiceCtx.reset(token)

    def _load_plugins(self):
        """Import every enabled plugin found under ``config.plugin.paths`` via Saya.

        For each plugin package ``<path>.<name>`` it first extracts the plugin's config
        model (if any) into ``config.plugin.configs``, then requires ``<path>.<name>.main``.
        A plugin that fails to load is logged and skipped so the others still load.
        """
        saya = it(Saya)
        with saya.module_context():
            for module_info in pkgutil.iter_modules(self.config.plugin.paths):
                path = Path(module_info.module_finder.path).stem  # noqa # type: ignore
                name = module_info.name
                # Skip the shared ``config`` module, private modules and explicitly disabled plugins.
                if name == "config" or name.startswith("_") or f"{path}.{name}" in self.config.plugin.disabled:
                    continue
                try:
                    if model := extract_plugin_config(self.config, path, name):
                        self.config.plugin.configs[type(model)] = model
                    saya.require(f"{path}.{name}.main")
                except BaseException as e:
                    # NOTE(review): BaseException also swallows KeyboardInterrupt/SystemExit raised
                    # while importing a plugin — confirm that isolating those is intended.
                    logger.warning(
                        f"fail to load {path}.{name}, caused by "
                        f"{traceback.format_exception(BaseException, e, e.__traceback__, 1)[-1]}"
                    )
                    traceback.print_exc()

    def record(self, name: str, disable: bool = False):
        """Decorator factory that registers the decorated function under ``name``.

        The first function registered for a given ``name`` is kept; later ones are not
        stored but still get ``__record__`` set. With ``disable=True`` the name is also
        added to the disabled set. The decorated function is returned unchanged.
        """

        def __wrapper__(func):
            record = self.cache.setdefault("function::record", {})
            disables = self.cache.setdefault("function::disables", set())
            record.setdefault(name, func)
            func.__record__ = name
            if disable:
                disables.add(name)
            return func

        return __wrapper__

    @property
    def functions(self):
        """Mapping of recorded names to functions (an empty dict if nothing was recorded)."""
        return self.cache.get("function::record", {})

    @property
    def disabled(self):
        """Set of recorded names marked as disabled (an empty set if none)."""
        return self.cache.get("function::disables", set())

    def func_description(self, name: str):
        """Return the docstring of the function recorded as ``name``, or ``"Unknown"`` if none is recorded.

        NOTE(review): a recorded function without a docstring yields ``None``, not ``"Unknown"``.
        """
        return func.__doc__ if (func := self.cache.get("function::record", {}).get(name)) else "Unknown"

    async def upload_to_cos(self, content: Union[bytes, str], name: str, custom_domain: bool = False):
        """Upload ``content`` to Tencent Cloud COS as object ``name`` and return its URI.

        If no Tencent Cloud platform config is set, nothing is uploaded and ``name`` is
        returned unchanged. With ``custom_domain=True`` the returned URI uses the configured
        custom domain.
        """
        if not self.config.platform.tencentcloud:
            return name
        config = CosConfig(
            secret_id=self.config.platform.tencentcloud.secret_id,
            secret_key=self.config.platform.tencentcloud.secret_key,
            region=self.config.platform.tencentcloud.region,
            scheme="https",
        )
        await put_object(
            config, self.config.platform.tencentcloud.bucket, content, name, headers={"StorageClass": "STANDARD"}
        )
        return config.uri(
            self.config.platform.tencentcloud.bucket,
            name,
            domain=self.config.platform.tencentcloud.custom_domain if custom_domain else None,
        )
class RaianBotDispatcher(BaseDispatcher):
    """Graia dispatcher that injects RaianBot objects into handlers and records statistics.

    Handler parameters annotated with ``RaianBotService``, ``RaianConfig``, any launart
    ``Service`` subclass, a ``BasePluginConfig`` subclass or a ``BotConfig`` subclass are
    resolved from the wrapped service.
    """

    def __init__(self, service: RaianBotService):
        self.service = service

    async def beforeExecution(self, interface: DispatcherInterface):
        """Find the ``BotConfig`` matching the event's account and stash it as ``bot_config``.

        Nothing is stored when no ``context`` parameter can be resolved.
        """
        try:
            ctx = await interface.lookup_param("context", Context, Force(None))
        except RequirementCrashed:
            return
        if not ctx:
            return
        candidates = (bot for bot in self.service.config.bots if bot.ensure(ctx.account))  # type: ignore
        interface.local_storage["bot_config"] = next(candidates, None)

    async def catch(self, interface: DispatcherInterface):
        """Resolve a handler parameter by its annotation; returns ``None`` when unhandled."""
        wanted = interface.annotation
        # Exact matches first: RaianBotService is itself a Service, so it must win here.
        if wanted is RaianBotService:
            return self.service
        if wanted is RaianConfig:
            return self.service.config
        if not isinstance(wanted, type):
            return None
        if issubclass(wanted, Service):
            return Launart.current().get_component(wanted)
        if issubclass(wanted, BasePluginConfig):
            return self.service.config.plugin.get(wanted)
        if issubclass(wanted, BotConfig):
            return interface.local_storage.get("bot_config")
        return None

    async def afterExecution(
        self,
        interface: DispatcherInterface,
        exception: Union[Exception, None],
        tb: ...,
    ):
        """Persist a ``Statistic`` returned by a successful top-level handler.

        After committing, ``PropagationCancelled`` is raised to stop further propagation.
        Nested dispatches (depth > 0) and failed executions are ignored.
        """
        if interface.depth > 0 or exception:
            return
        pending = interface.exec_result
        await pending
        outcome = pending.result()
        if not isinstance(outcome, Statistic):
            return
        await commit(self.service.db, outcome)
        raise PropagationCancelled
# Patch: replace PlaywrightService.launch with the implementation defined below.
from graiax.playwright import PlaywrightService
from graiax.playwright.i18n import N_
from graiax.playwright.installer import install_playwright
from graiax.playwright.utils import log
from playwright.async_api import Error as PWError
from playwright.async_api import async_playwright
async def launch(self: PlaywrightService, _):
    """Replacement ``PlaywrightService.launch``, installed on the class at the end of this block.

    preparing: optionally install the browser, start Playwright, then launch the configured
    browser type (as a persistent context or as browser + new context). If the launch
    fails, the Playwright manager is exited before the error is re-raised.
    cleanup: only exits the Playwright manager; the browser context is deliberately not
    closed because doing so hangs.
    """
    import contextlib

    if self.auto_download_browser:
        await install_playwright(
            self.playwright_download_host,
            self.browser_type,
            self.install_with_deps,
        )
    self.playwright_mgr = playwright_mgr = async_playwright()  # type: ignore
    async with self.stage("preparing"):
        self.playwright = await playwright_mgr.__aenter__()
        browser_type = {
            "chromium": self.playwright.chromium,
            "firefox": self.playwright.firefox,
            "webkit": self.playwright.webkit,
        }[self.browser_type]
        try:
            if self.use_persistent_context:
                log("info", N_("Playwright is currently starting in persistent context mode."))
                self._context = await browser_type.launch_persistent_context(**self.launch_config)
            else:
                self._browser = await browser_type.launch(**self.launch_config)
                self._context = await self._browser.new_context(**self.global_context_config)
        except PWError:
            log(
                "error",
                N_(
                    "Unable to launch Playwright for {browser_type}, "
                    "please check the log output for the reason of failure. "
                    "It is possible that some system dependencies are missing. "
                    "You can set [magenta]`install_with_deps`[/] to [magenta]`True`[/] "
                    "to install dependencies when download browser."
                ).format(browser_type=self.browser_type),
            )
            # The manager was already entered above; exit it so the failed launch does not
            # leave the Playwright driver running. Best-effort: the launch error re-raised
            # below is the one that matters.
            with contextlib.suppress(Exception):
                await playwright_mgr.__aexit__()
            raise
        else:
            log("success", N_("Playwright for {browser_type} is started.").format(browser_type=self.browser_type))
    async with self.stage("cleanup"):
        # Closing the browser context here hangs, so only the Playwright manager is exited.
        await playwright_mgr.__aexit__()


PlaywrightService.launch = launch