diff --git a/novus/api/gateway/gateway.py b/novus/api/gateway/gateway.py index 24599956..fc8d48d5 100644 --- a/novus/api/gateway/gateway.py +++ b/novus/api/gateway/gateway.py @@ -250,6 +250,7 @@ def __init__( self.sequence: int | None = None self.resume_url = self.ws_url self.session_id = None + self.heartbeat_interval = 60.0 # Temporarily cached data for combining requests self.chunk_counter: dict[str, int] = {} # nonce: req_counter @@ -593,9 +594,10 @@ async def _connect( log.debug("[%s] Connected to gateway - %s", self.shard_id, dump(data)) # Start heartbeat - heartbeat_interval = data["heartbeat_interval"] + self.heartbeat_interval = data["heartbeat_interval"] self.heartbeat_task = asyncio.create_task( - self.heartbeat(heartbeat_interval, jitter=not resume) + self.heartbeat(self.heartbeat_interval, jitter=not resume), + name="Heartbeat for shard %s" % self.shard_id, ) # Send identify or resume @@ -673,7 +675,8 @@ async def heartbeat( self, heartbeat_interval: int | float, *, - jitter: bool = False) -> None: + jitter: bool = False, + skip_first_wait: bool = False) -> None: """ Send heartbeats to Discord. This implements a forever loop, so the method should be created as a task. @@ -699,7 +702,9 @@ async def heartbeat( ) while True: try: - await asyncio.sleep(wait / 1_000) + if not skip_first_wait: + await asyncio.sleep(wait / 1_000) + skip_first_wait = False except asyncio.CancelledError: log.debug("[%s] Heartbeat has been cancelled", self.shard_id) return @@ -860,10 +865,19 @@ async def message_handler(self) -> None: async for opcode, event_name, sequence, message in self.messages(): match opcode: - # Ignore heartbeats + # Ignore heartbeat acks case GatewayOpcode.heartbeat_ack: self.heartbeat_received.set() + # Sometimes Discord may ask for heartbeats explicitly + case GatewayOpcode.heartbeat: + if self.heartbeat_task: + self.heartbeat_task.cancel() + self.heartbeat_task = asyncio.create_task( + self.heartbeat(self.heartbeat_interval, skip_first_wait=True), + name="Heartbeat for shard %s" % self.shard_id, + ) + # Deal with dispatch case GatewayOpcode.dispatch: event_name = cast(str, event_name) @@ -902,4 +916,4 @@ async def message_handler(self) -> None: # Everything else case _: - print("Failed to deal with gateway message %s" % dump(message)) + log.warning("Failed to deal with gateway message %s" % dump(message))