diff --git a/src/docs/game-channel-api.docs.ts b/src/docs/game-channel-api.docs.ts index 07ff89e4..f5c6a80d 100644 --- a/src/docs/game-channel-api.docs.ts +++ b/src/docs/game-channel-api.docs.ts @@ -45,6 +45,8 @@ const GameChannelAPIDocs: APIDocs = { createdAt: '2024-10-25T18:18:28.000Z', updatedAt: '2024-12-04T07:15:13.000Z' }, + totalMessages: 36, + memberCount: 8, props: [ { key: 'channelType', value: 'guild' }, { key: 'guildId', value: '5912' } diff --git a/src/services/api-key.service.ts b/src/services/api-key.service.ts index 7aa6ef48..68cdc98e 100644 --- a/src/services/api-key.service.ts +++ b/src/services/api-key.service.ts @@ -113,7 +113,7 @@ export default class APIKeyService extends Service { const socket: Socket = req.ctx.wss const conns = socket.findConnections((conn) => conn.getAPIKeyId() === apiKey.id) for (const conn of conns) { - conn.ws.close(3000) + socket.closeConnection(conn.ws) } await em.flush() diff --git a/src/socket/authenticateSocket.ts b/src/socket/authenticateSocket.ts index 67b3523e..62a82c97 100644 --- a/src/socket/authenticateSocket.ts +++ b/src/socket/authenticateSocket.ts @@ -1,14 +1,12 @@ -import { WebSocket } from 'ws' import getAPIKeyFromToken from '../lib/auth/getAPIKeyFromToken' import { promisify } from 'util' import jwt from 'jsonwebtoken' import { RequestContext } from '@mikro-orm/core' import APIKey from '../entities/api-key' -export default async function authenticateSocket(authHeader: string, ws: WebSocket): Promise { +export default async function authenticateSocket(authHeader: string): Promise { const apiKey = await getAPIKeyFromToken(authHeader) if (!apiKey || apiKey.revokedAt) { - ws.close(3000) return } @@ -20,7 +18,6 @@ export default async function authenticateSocket(authHeader: string, ws: WebSock const secret = apiKey.game.apiSecret.getPlainSecret() await promisify(jwt.verify)(token, secret) } catch (err) { - ws.close(3000) return } diff --git a/src/socket/index.ts b/src/socket/index.ts index 9bf36c92..ff431285 100644 --- a/src/socket/index.ts +++ b/src/socket/index.ts @@ -7,6 +7,13 @@ import SocketConnection from './socketConnection' import SocketRouter from './router/socketRouter' import { sendMessage } from './messages/socketMessage' +type CloseConnectionOptions = { + code?: number + reason?: string + terminate?: boolean + preclosed?: boolean +} + export default class Socket { private readonly wss: WebSocketServer private connections: SocketConnection[] = [] @@ -19,7 +26,7 @@ export default class Socket { ws.on('message', (data) => this.handleMessage(ws, data)) ws.on('pong', () => this.handlePong(ws)) - ws.on('close', () => this.handleCloseConnection(ws)) + ws.on('close', () => this.closeConnection(ws, { preclosed: true })) ws.on('error', captureException) }) @@ -37,7 +44,7 @@ export default class Socket { this.connections.forEach((conn) => { /* v8 ignore start */ if (!conn.alive) { - conn.ws.terminate() + this.closeConnection(conn.ws, { terminate: true }) return } @@ -54,10 +61,12 @@ export default class Socket { async handleConnection(ws: WebSocket, req: IncomingMessage): Promise { await RequestContext.create(this.em, async () => { - const key = await authenticateSocket(req.headers?.authorization ?? '', ws) + const key = await authenticateSocket(req.headers?.authorization ?? '') if (key) { this.connections.push(new SocketConnection(ws, key, req)) sendMessage(this.connections.at(-1), 'v1.connected', {}) + } else { + this.closeConnection(ws) } }) } @@ -80,7 +89,16 @@ export default class Socket { } /* v8 ignore end */ - handleCloseConnection(ws: WebSocket): void { + closeConnection(ws: WebSocket, options: CloseConnectionOptions = {}): void { + const terminate = options.terminate ?? false + const preclosed = options.preclosed ?? false + + if (terminate) { + ws.terminate() + } else if (!preclosed) { + ws.close(options.code ?? 3000, options.reason) + } + this.connections = this.connections.filter((conn) => conn.ws !== ws) } @@ -88,7 +106,7 @@ export default class Socket { const connection = this.connections.find((conn) => conn.ws === ws) /* v8 ignore start */ if (!connection) { - ws.close(3000) + this.closeConnection(ws) return } /* v8 ignore end */ diff --git a/src/socket/router/socketRouter.ts b/src/socket/router/socketRouter.ts index 825e5d3b..3e45d34f 100644 --- a/src/socket/router/socketRouter.ts +++ b/src/socket/router/socketRouter.ts @@ -34,6 +34,11 @@ export default class SocketRouter { const rateLimitExceeded = await conn.checkRateLimitExceeded() if (rateLimitExceeded) { + if (conn.rateLimitWarnings > 3) { + this.socket.closeConnection(conn.ws, { code: 1008, reason: 'RATE_LIMIT_EXCEEDED' }) + } else { + sendError(conn, 'unknown', new SocketError('RATE_LIMIT_EXCEEDED', 'Rate limit exceeded')) + } return } diff --git a/src/socket/socketConnection.ts b/src/socket/socketConnection.ts index 9aaf158f..e32dfe7b 100644 --- a/src/socket/socketConnection.ts +++ b/src/socket/socketConnection.ts @@ -8,7 +8,6 @@ import jwt from 'jsonwebtoken' import { v4 } from 'uuid' import Redis from 'ioredis' import redisConfig from '../config/redis.config' -import SocketError, { sendError } from './messages/socketError' import checkRateLimitExceeded from '../lib/errors/checkRateLimitExceeded' export default class SocketConnection { @@ -64,12 +63,8 @@ export default class SocketConnection { if (rateLimitExceeded) { this.rateLimitWarnings++ - if (this.rateLimitWarnings > 3) { - this.ws.close(1008, 'RATE_LIMIT_EXCEEDED') - } else { - sendError(this, 'unknown', new SocketError('RATE_LIMIT_EXCEEDED', 'Rate limit exceeded')) - } - return } + + return rateLimitExceeded } }