-
Notifications
You must be signed in to change notification settings - Fork 0
/
ratelimiter.py
142 lines (120 loc) · 5.36 KB
/
ratelimiter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import asyncio
import contextlib
from typing import Callable, Coroutine, Any, Union, List, Dict, Optional, TypeVar, Type, TYPE_CHECKING, Awaitable
from telegram.error import RetryAfter
from telegram.ext import AIORateLimiter
from telegram.ext._aioratelimiter import null_context
from utils.log import logger
if TYPE_CHECKING:
from gram_core.application import Application
# JSON object as returned by the Bot API. Plain assignment (no variable
# annotation) so static type checkers treat this as a type alias and accept it
# inside other annotations such as ``Union[bool, JSONDict, List[JSONDict]]``.
JSONDict = Dict[str, Any]
# Type of the ``rate_limit_args`` argument accepted by ``RateLimiter.process_request``.
RL_ARGS = TypeVar("RL_ARGS")
# Hook signature run after a successful API call: (endpoint, request data, result) -> awaitable.
T_CalledAPIFunc = Callable[[str, Dict[str, Any], Union[bool, JSONDict, List[JSONDict]]], Awaitable[Any]]
class RateLimiter(AIORateLimiter):
    """Rate limiter that tracks Telegram flood control (``RetryAfter``) per chat group.

    Every request is tagged with a *group* derived from its ``chat_id``:
    negative integer ids and string ids (e.g. ``"@channel"``) form their own
    group; everything else (private chats, requests without ``chat_id``) uses
    ``False``. Each group owns an :class:`asyncio.Event`. While a request of a
    group is sleeping off a ``RetryAfter``, that group's event is cleared and
    new requests for the same group wait before being sent. Requests without a
    group share the default event stored under key ``0`` (``False == 0`` and
    both hash alike, so ``False`` resolves to that entry).

    After each successful request, the hooks returned by the attached
    application's ``get_called_api_funcs()`` are run (see :meth:`set_application`).
    """

    # Shared by all instances (class attribute). Serialises lookup/creation of
    # per-group events; the critical section itself contains no ``await``.
    _lock = asyncio.Lock()

    __slots__ = (
        "_retry_after_event_map",
        "_application",
    )

    # Key of the default event used by requests that are not addressed to a group.
    _DEFAULT_EVENT_KEY = 0
    # Once more events than this are tracked, idle (set) events are pruned.
    _MAX_TRACKED_GROUPS = 512

    def __init__(
        self,
        max_retries: int = 5,
    ) -> None:
        """Create the limiter.

        Args:
            max_retries: Default number of retries after a ``RetryAfter`` error.
                Forwarded to the base class; presumably stored there as
                ``self._max_retries``, which :meth:`process_request` reads.
        """
        super().__init__(
            max_retries=max_retries,
        )
        self._application: Optional["Application"] = None
        default_event = asyncio.Event()
        # A *set* event lets requests through; it is cleared only while a
        # request of that group is waiting out a RetryAfter.
        default_event.set()
        self._retry_after_event_map: Dict[Union[str, int], asyncio.Event] = {
            self._DEFAULT_EVENT_KEY: default_event,
        }

    def clear_group_retry_after_event(self, group: Union[str, int]) -> None:
        """Prune idle per-group events to bound the size of the event map.

        Removes every event that is currently set (no flood control in
        progress), except the one for *group* and the shared default entry.
        Cleared events are kept because requests may be waiting on them.
        The default entry is never removed: it backs every non-group request
        and is the fallback of :meth:`_get_group_retry_after_event`.

        Args:
            group: The group being looked up; its entry is preserved.
        """
        idle_keys = [
            key
            for key, event in self._retry_after_event_map.items()
            if key != group and key != self._DEFAULT_EVENT_KEY and event.is_set()
        ]
        for key in idle_keys:
            del self._retry_after_event_map[key]

    async def _get_group_retry_after_event(self, group: Union[str, int]) -> asyncio.Event:
        """Return the retry-after event for *group*, creating it on first use.

        New events start in the *set* state so the first request is not
        blocked. Groups that are neither ``str`` nor ``int`` (``bool`` counts as
        ``int``) fall back to the shared default event.

        Args:
            group: Group key computed by :meth:`process_request`.

        Returns:
            The :class:`asyncio.Event` requests of this group must wait on.
        """
        async with self._lock:
            event = self._retry_after_event_map.get(group)
            if event is not None:
                return event
            if not isinstance(group, (str, int)):
                return self._retry_after_event_map[self._DEFAULT_EVENT_KEY]
            if len(self._retry_after_event_map) > self._MAX_TRACKED_GROUPS:
                self.clear_group_retry_after_event(group)
            event = asyncio.Event()
            event.set()
            self._retry_after_event_map[group] = event
            return event

    async def _run_request(
        self,
        chat: bool,
        group: Union[str, int, bool],
        callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
        args: Any,
        kwargs: Dict[str, Any],
    ) -> Union[bool, JSONDict, List[JSONDict]]:
        """Execute *callback* inside the base-class throughput limiters.

        Chat requests go through ``self._base_limiter`` (when configured) and
        group requests through the per-group limiter (when a group rate is
        configured). Unlike :meth:`process_request`, this does not consult
        the retry-after events.
        """
        base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
        group_context = self._get_group_limiter(group) if group and self._group_max_rate else null_context()
        async with group_context, base_context:
            return await callback(*args, **kwargs)

    async def process_request(
        self,
        callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, List[JSONDict]]]],
        args: Any,
        kwargs: Dict[str, Any],
        endpoint: str,
        data: Dict[str, Any],
        rate_limit_args: Optional[RL_ARGS],
    ) -> Union[bool, JSONDict, List[JSONDict]]:
        """Send one Bot API request, retrying on ``RetryAfter``.

        Args:
            callback: Coroutine function that performs the actual request.
            args: Positional arguments for *callback*.
            kwargs: Keyword arguments for *callback*.
            endpoint: Bot API method name (e.g. ``"setWebhook"``).
            data: Request parameters; only ``"chat_id"`` is inspected here.
            rate_limit_args: Per-call retry limit; falls back to the instance
                default when falsy.

        Returns:
            The result of *callback*, or ``True`` when ``setWebhook`` answers
            with ``retry_after == 1`` (the webhook is already set correctly).

        Raises:
            RetryAfter: When flood control is still hit after the last retry.
        """
        max_retries = rate_limit_args or self._max_retries
        group: Union[int, str, bool] = False
        chat: bool = False
        chat_id = data.get("chat_id")
        if chat_id is not None:
            chat = True
            # In case the user passes an integer chat id as a string.
            with contextlib.suppress(ValueError, TypeError):
                chat_id = int(chat_id)
            if (isinstance(chat_id, int) and chat_id < 0) or isinstance(chat_id, str):
                # String chat_ids only work for channels and supergroups.
                # We can't really tell channels from groups though ...
                group = chat_id

        retry_after_event = await self._get_group_retry_after_event(group)
        # Hold off while another request of this group is sleeping off flood control.
        await retry_after_event.wait()

        for attempt in range(max_retries + 1):
            try:
                result = await self._run_request(chat=chat, group=group, callback=callback, args=args, kwargs=kwargs)
            except RetryAfter as exc:
                if endpoint == "setWebhook" and exc.retry_after == 1:
                    # The webhook has already been set correctly.
                    return True
                if attempt == max_retries:
                    logger.warning("chat_id[%s] 达到最大重试限制 max_retries[%s]", chat_id, max_retries)
                    raise
                delay = exc.retry_after + 0.1
                logger.warning("chat_id[%s] 触发洪水限制 当前被服务器限制 retry_after[%s]秒", chat_id, exc.retry_after)
                # Make sure we don't allow other requests of this group to be processed.
                retry_after_event.clear()
                await asyncio.sleep(delay)
            else:
                # Runs outside the ``except RetryAfter`` scope so a hook can
                # never cause an already-successful request to be resent.
                await self._on_called_api(endpoint, data, result)
                return result
            finally:
                # Allow other requests to be processed.
                retry_after_event.set()
        return None  # type: ignore[return-value]

    def set_application(self, application: "Application") -> None:
        """Attach the application whose "called API" hooks run after each request."""
        self._application = application

    async def _on_called_api(
        self,
        endpoint: str,
        data: Dict[str, Any],
        result: Union[bool, JSONDict, List[JSONDict]],
    ) -> None:
        """Run the application's "called API" hooks after a successful request.

        Each hook is called as ``hook(endpoint, data, result)`` and all hooks are
        awaited concurrently. Every failing hook is logged; none is propagated,
        so a hook cannot turn an already-successful request into an error.
        Does nothing until an application is attached via :meth:`set_application`.
        """
        application = self._application
        if application is None:
            return

        async def _run_hook(hook: Callable[..., Awaitable[Any]]) -> None:
            # Calling the hook inside this coroutine means a hook that raises
            # synchronously is collected by ``gather`` like any other failure.
            await hook(endpoint, data, result)

        hooks = [_run_hook(hook) for hook in application.get_called_api_funcs()]
        if not hooks:
            return
        outcomes = await asyncio.gather(*hooks, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                logger.error("Error while running CalledAPI hooks: %s", outcome)