Skip to content

Commit

Permalink
centralise closing sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
tudddorrr committed Dec 14, 2024
1 parent 41cfdef commit 64d3737
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/docs/game-channel-api.docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ const GameChannelAPIDocs: APIDocs<GameChannelAPIService> = {
createdAt: '2024-10-25T18:18:28.000Z',
updatedAt: '2024-12-04T07:15:13.000Z'
},
totalMessages: 36,
memberCount: 8,
props: [
{ key: 'channelType', value: 'guild' },
{ key: 'guildId', value: '5912' }
Expand Down
2 changes: 1 addition & 1 deletion src/services/api-key.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ export default class APIKeyService extends Service {
const socket: Socket = req.ctx.wss
const conns = socket.findConnections((conn) => conn.getAPIKeyId() === apiKey.id)
for (const conn of conns) {
conn.ws.close(3000)
socket.closeConnection(conn.ws)
}

await em.flush()
Expand Down
5 changes: 1 addition & 4 deletions src/socket/authenticateSocket.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import { WebSocket } from 'ws'
import getAPIKeyFromToken from '../lib/auth/getAPIKeyFromToken'
import { promisify } from 'util'
import jwt from 'jsonwebtoken'
import { RequestContext } from '@mikro-orm/core'
import APIKey from '../entities/api-key'

export default async function authenticateSocket(authHeader: string, ws: WebSocket): Promise<APIKey> {
export default async function authenticateSocket(authHeader: string): Promise<APIKey> {
const apiKey = await getAPIKeyFromToken(authHeader)
if (!apiKey || apiKey.revokedAt) {
ws.close(3000)
return
}

Expand All @@ -20,7 +18,6 @@ export default async function authenticateSocket(authHeader: string, ws: WebSock
const secret = apiKey.game.apiSecret.getPlainSecret()
await promisify(jwt.verify)(token, secret)
} catch (err) {
ws.close(3000)
return
}

Expand Down
28 changes: 23 additions & 5 deletions src/socket/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import SocketConnection from './socketConnection'
import SocketRouter from './router/socketRouter'
import { sendMessage } from './messages/socketMessage'

type CloseConnectionOptions = {
code?: number
reason?: string
terminate?: boolean
preclosed?: boolean
}

export default class Socket {
private readonly wss: WebSocketServer
private connections: SocketConnection[] = []
Expand All @@ -19,7 +26,7 @@ export default class Socket {

ws.on('message', (data) => this.handleMessage(ws, data))
ws.on('pong', () => this.handlePong(ws))
ws.on('close', () => this.handleCloseConnection(ws))
ws.on('close', () => this.closeConnection(ws, { preclosed: true }))
ws.on('error', captureException)
})

Expand All @@ -37,7 +44,7 @@ export default class Socket {
this.connections.forEach((conn) => {
/* v8 ignore start */
if (!conn.alive) {
conn.ws.terminate()
this.closeConnection(conn.ws, { terminate: true })
return
}

Expand All @@ -54,10 +61,12 @@ export default class Socket {

async handleConnection(ws: WebSocket, req: IncomingMessage): Promise<void> {
await RequestContext.create(this.em, async () => {
const key = await authenticateSocket(req.headers?.authorization ?? '', ws)
const key = await authenticateSocket(req.headers?.authorization ?? '')
if (key) {
this.connections.push(new SocketConnection(ws, key, req))
sendMessage(this.connections.at(-1), 'v1.connected', {})
} else {
this.closeConnection(ws)
}
})
}
Expand All @@ -80,15 +89,24 @@ export default class Socket {
}
/* v8 ignore end */

handleCloseConnection(ws: WebSocket): void {
closeConnection(ws: WebSocket, options: CloseConnectionOptions = {}): void {
const terminate = options.terminate ?? false
const preclosed = options.preclosed ?? false

if (terminate) {
ws.terminate()
} else if (!preclosed) {
ws.close(options.code ?? 3000, options.reason)
}

this.connections = this.connections.filter((conn) => conn.ws !== ws)
}

findConnectionBySocket(ws: WebSocket): SocketConnection | undefined {
const connection = this.connections.find((conn) => conn.ws === ws)
/* v8 ignore start */
if (!connection) {
ws.close(3000)
this.closeConnection(ws)
return
}
/* v8 ignore end */
Expand Down
5 changes: 5 additions & 0 deletions src/socket/router/socketRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ export default class SocketRouter {

const rateLimitExceeded = await conn.checkRateLimitExceeded()
if (rateLimitExceeded) {
if (conn.rateLimitWarnings > 3) {
this.socket.closeConnection(conn.ws, { code: 1008, reason: 'RATE_LIMIT_EXCEEDED' })
} else {
sendError(conn, 'unknown', new SocketError('RATE_LIMIT_EXCEEDED', 'Rate limit exceeded'))
}
return
}

Expand Down
9 changes: 2 additions & 7 deletions src/socket/socketConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import jwt from 'jsonwebtoken'
import { v4 } from 'uuid'
import Redis from 'ioredis'
import redisConfig from '../config/redis.config'
import SocketError, { sendError } from './messages/socketError'
import checkRateLimitExceeded from '../lib/errors/checkRateLimitExceeded'

export default class SocketConnection {
Expand Down Expand Up @@ -64,12 +63,8 @@ export default class SocketConnection {

if (rateLimitExceeded) {
this.rateLimitWarnings++
if (this.rateLimitWarnings > 3) {
this.ws.close(1008, 'RATE_LIMIT_EXCEEDED')
} else {
sendError(this, 'unknown', new SocketError('RATE_LIMIT_EXCEEDED', 'Rate limit exceeded'))
}
return
}

return rateLimitExceeded
}
}

0 comments on commit 64d3737

Please sign in to comment.