Skip to content

Commit

Permalink
socket tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tudddorrr committed Dec 9, 2024
1 parent e40d417 commit 90cb4b9
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 47 deletions.
20 changes: 20 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"hygen": "^6.2.11",
"lint-staged": ">=10",
"supertest": "^7.0.0",
"superwstest": "^2.0.4",
"ts-node": "^10.7.0",
"tsx": "^4.11.0",
"typescript": "^5.4.5",
Expand Down
4 changes: 2 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ export default async function init(): Promise<Koa> {

app.use(cleanupMiddleware)

const server = createServer(app.callback())
app.context.wss = new Socket(server, app.context.em)
if (!isTest) {
const server = createServer(app.callback())
app.context.wss = new Socket(server, app.context.em)
server.listen(80, () => console.info('Listening on port 80'))
}

Expand Down
8 changes: 6 additions & 2 deletions src/middlewares/player-auth-middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ export async function validateAuthSessionToken(ctx: Context, alias: PlayerAlias)
}

try {
const payload = await promisify(jwt.verify)(sessionToken, alias.player.auth.sessionKey)
if (payload.playerId !== ctx.state.currentPlayerId || payload.aliasId !== ctx.state.currentAliasId) {
if (!await validateSessionTokenJWT(sessionToken as string, alias)) {
throw new Error()
}
} catch (err) {
Expand All @@ -54,3 +53,8 @@ export async function validateAuthSessionToken(ctx: Context, alias: PlayerAlias)
})
}
}

export async function validateSessionTokenJWT(sessionToken: string, alias: PlayerAlias): Promise<boolean> {
const payload = await promisify(jwt.verify)(sessionToken, alias.player.auth.sessionKey)
return payload.playerId === alias.player.id && payload.aliasId === alias.id
}
14 changes: 10 additions & 4 deletions src/socket/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export default class Socket {

ws.on('message', (data) => this.handleMessage(ws, data))
ws.on('pong', () => this.handlePong(ws))
ws.on('close', () => this.handleClose(ws))
ws.on('close', () => this.handleCloseConnection(ws))
ws.on('error', captureException)
})

Expand All @@ -28,6 +28,10 @@ export default class Socket {
this.heartbeat()
}

getServer(): WebSocketServer {
return this.wss
}

heartbeat(): void {
const interval = setInterval(() => {
this.connections.forEach((conn) => {
Expand All @@ -49,8 +53,10 @@ export default class Socket {
async handleConnection(ws: WebSocket, req: IncomingMessage): Promise<void> {
await RequestContext.create(this.em, async () => {
const key = await authenticateSocket(req.headers?.authorization ?? '', ws)
this.connections.push(new SocketConnection(ws, key))
sendMessage(this.connections.at(-1), 'v1.connected', {})
if (key) {
this.connections.push(new SocketConnection(ws, key))
sendMessage(this.connections.at(-1), 'v1.connected', {})
}
})
}

Expand All @@ -67,7 +73,7 @@ export default class Socket {
connection.alive = true
}

handleClose(ws: WebSocket): void {
handleCloseConnection(ws: WebSocket): void {
this.connections = this.connections.filter((conn) => conn.ws !== ws)
}

Expand Down
10 changes: 8 additions & 2 deletions src/socket/listeners/gameChannelListeners.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ const gameChannelListeners: SocketMessageListener<ZodType>[] = [
populate: ['members']
}))

if (!channel) return
if (!channel) {
throw new Error('Channel not found')
}
if (!channel.members.getIdentifiers().includes(conn.playerAlias.id)) {
throw new Error('Player not in channel')
}

const conns = socket.findConnections((conn) => channel.members.getIdentifiers().includes(conn.playerAlias.id))
sendMessages(conns, 'v1.channels.message', {
channelName: channel.name,
message: data.message
message: data.message,
fromPlayerAlias: conn.playerAlias
})
},
{
Expand Down
27 changes: 22 additions & 5 deletions src/socket/listeners/playerListeners.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,57 @@ import { sendMessage } from '../messages/socketMessage'
import Redis from 'ioredis'
import redisConfig from '../../config/redis.config'
import { RequestContext } from '@mikro-orm/core'
import PlayerAlias from '../../entities/player-alias'
import PlayerAlias, { PlayerAliasService } from '../../entities/player-alias'
import { SocketMessageListener } from '../router/createListener'
import SocketError, { sendError } from '../messages/socketError'
import { APIKeyScope } from '../../entities/api-key'
import { validateSessionTokenJWT } from '../../middlewares/player-auth-middleware'

const playerListeners: SocketMessageListener<ZodType>[] = [
createListener(
'v1.players.identify',
z.object({
playerAliasId: z.number(),
token: z.string()
socketToken: z.string(),
sessionToken: z.string().optional()
}),
async ({ conn, req, data }) => {
const redis = new Redis(redisConfig)
const token = await redis.get(`socketTokens.${data.playerAliasId}`)

if (token === data.token) {
if (token === data.socketToken) {
conn.playerAlias = await (RequestContext.getEntityManager()
.getRepository(PlayerAlias)
.findOne({
id: data.playerAliasId,
player: {
game: conn.game
}
}, {
populate: ['player.auth']
}))

sendMessage(conn, 'v1.players.identify.success', conn.playerAlias)
if (conn.playerAlias.service === PlayerAliasService.TALO) {
try {
if (!await validateSessionTokenJWT(data.sessionToken, conn.playerAlias)) {
throw new Error()
}
sendMessage(conn, 'v1.players.identify.success', conn.playerAlias)
} catch (err) {
sendError(conn, req, new SocketError('INVALID_SESSION', 'Session token is invalid'))
}
} else {
sendMessage(conn, 'v1.players.identify.success', conn.playerAlias)
}
} else {
sendError(conn, req, new SocketError('INVALID_SOCKET_TOKEN', 'Invalid socket token'))
}

await redis.quit()
},
{
requirePlayer: false
requirePlayer: false,
apiKeyScopes: [APIKeyScope.READ_PLAYERS]
}
)
]
Expand Down
9 changes: 6 additions & 3 deletions src/socket/messages/socketError.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ const codes = [
'ROUTING_ERROR',
'LISTENER_ERROR',
'INVALID_SOCKET_TOKEN',
'MISSING_ACCESS_KEY_SCOPE'
'INVALID_SESSION',
'MISSING_ACCESS_KEY_SCOPES'
] as const

export type SocketErrorCode = typeof codes[number]

export default class SocketError {
constructor(public code: SocketErrorCode, public message: string) {}
constructor(public code: SocketErrorCode, public message: string, public cause?: string) {}
}

type SocketErrorReq = SocketMessageRequest | 'unknown'
Expand All @@ -30,9 +31,11 @@ export function sendError(conn: SocketConnection, req: SocketErrorReq, error: So
req: SocketErrorReq
message: string
errorCode: SocketErrorCode
cause?: string
}>(conn, 'v1.error', {
req,
message: error.message,
errorCode: error.code
errorCode: error.code,
cause: error.cause
})
}
16 changes: 11 additions & 5 deletions src/socket/messages/socketMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ export const responses = [
export type SocketMessageResponse = typeof responses[number]

export function sendMessage<T>(conn: SocketConnection, res: SocketMessageResponse, data: T) {
conn.ws.send(JSON.stringify({
res,
data
}))
if (conn.ws.readyState === conn.ws.OPEN) {
conn.ws.send(JSON.stringify({
res,
data
}))
}
}

export function sendMessages<T>(conns: SocketConnection[], type: SocketMessageResponse, data: T) {
conns.forEach((ws) => sendMessage<T>(ws, type, data))
conns.forEach((ws) => {
if (ws.ws.readyState === ws.ws.OPEN) {
sendMessage<T>(ws, type, data)
}
})
}
3 changes: 2 additions & 1 deletion src/socket/router/createListener.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { z, ZodType } from 'zod'
import { SocketMessageRequest } from '../messages/socketMessage'
import SocketConnection from '../socketConnection'
import Socket from '..'
import { APIKeyScope } from '../../entities/api-key'

type SocketMessageListenerHandlerParams<T> = {
conn: SocketConnection
Expand All @@ -13,7 +14,7 @@ type SocketMessageListenerHandlerParams<T> = {
type SocketMessageListenerHandler<T> = (params: SocketMessageListenerHandlerParams<T>) => void | Promise<void>
type SocketMessageListenerOptions = {
requirePlayer?: boolean
apiKeyScopes?: string[]
apiKeyScopes?: APIKeyScope[]
}

export type SocketMessageListener<T extends ZodType> = {
Expand Down
34 changes: 22 additions & 12 deletions src/socket/router/socketRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export default class SocketRouter {
let message: SocketMessage = null

try {
message = await this.getParsedMessage(rawData)
message = await socketMessageValidator.parseAsync(JSON.parse(rawData.toString()))

const handled = await this.routeMessage(conn, message)
if (!handled) {
Expand All @@ -50,10 +50,6 @@ export default class SocketRouter {
}
}

async getParsedMessage(rawData: RawData): Promise<SocketMessage> {
return await socketMessageValidator.parseAsync(JSON.parse(rawData.toString()))
}

async routeMessage(conn: SocketConnection, message: SocketMessage): Promise<boolean> {
let handled = false

Expand All @@ -63,22 +59,22 @@ export default class SocketRouter {
try {
handled = true

if ((listener.options.requirePlayer ?? true) && !conn.playerAlias) {
sendError(conn, message.req, new SocketError('NO_PLAYER_FOUND', 'No player found'))
} else if ((listener.options.apiKeyScopes ?? []).some((scope) => !conn.scopes.includes(scope as APIKeyScope))) {
const missing = listener.options.apiKeyScopes.filter((scope) => !conn.scopes.includes(scope as APIKeyScope))
sendError(conn, message.req, new SocketError('MISSING_ACCESS_KEY_SCOPE', `Missing access key scope(s): ${missing.join(', ')}`))
if (!this.meetsPlayerRequirement(conn, listener)) {
sendError(conn, message.req, new SocketError('NO_PLAYER_FOUND', 'You must identify a player before sending this request'))
} else if (!this.meetsScopeRequirements(conn, listener)) {
const missing = this.getMissingScopes(conn, listener)
sendError(conn, message.req, new SocketError('MISSING_ACCESS_KEY_SCOPES', `Missing access key scope(s): ${missing.join(', ')}`))
} else {
const data = await listener.validator.parseAsync(message.data)
listener.handler({ conn, req: listener.req, data, socket: this.socket })
await listener.handler({ conn, req: listener.req, data, socket: this.socket })
}

break
} catch (err) {
if (err instanceof ZodError) {
sendError(conn, message.req, new SocketError('INVALID_MESSAGE', 'Invalid message data for request'))
} else {
sendError(conn, message.req, new SocketError('LISTENER_ERROR', 'An error occurred while processing the message'))
sendError(conn, message.req, new SocketError('LISTENER_ERROR', 'An error occurred while processing the message', err.message))
}
}
}
Expand All @@ -87,4 +83,18 @@ export default class SocketRouter {

return handled
}

meetsPlayerRequirement(conn: SocketConnection, listener: SocketMessageListener<ZodType>): boolean {
const requirePlayer = listener.options.requirePlayer ?? true
return Boolean(conn.playerAlias) || !requirePlayer
}

meetsScopeRequirements(conn: SocketConnection, listener: SocketMessageListener<ZodType>): boolean {
const requiredScopes = listener.options.apiKeyScopes ?? []
return requiredScopes.every((scope) => conn.scopes.includes(scope as APIKeyScope))
}

getMissingScopes(conn: SocketConnection, listener: SocketMessageListener<ZodType>): APIKeyScope[] {
return (listener.options.apiKeyScopes ?? []).filter((scope) => !conn.scopes.includes(scope))
}
}
7 changes: 7 additions & 0 deletions tests/setupTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import init from '../src'
import ormConfig from '../src/config/mikro-orm.config'
import createClickhouseClient from '../src/lib/clickhouse/createClient'
import { NodeClickHouseClient } from '@clickhouse/client/dist/client'
import { createServer } from 'http'

beforeAll(async () => {
vi.mock('@sendgrid/mail')
Expand All @@ -16,6 +17,9 @@ beforeAll(async () => {
global.app = app.callback()
global.em = app.context.em

global.server = createServer()
global.server.listen(0)

global.clickhouse = createClickhouseClient()
await (global.clickhouse as NodeClickHouseClient).command({
query: `TRUNCATE ALL TABLES from ${process.env.CLICKHOUSE_DB}`
Expand All @@ -25,10 +29,13 @@ beforeAll(async () => {
afterAll(async () => {
await (global.em as EntityManager).getConnection().close(true)

global.server.close()

const clickhouse = global.clickhouse as NodeClickHouseClient
clickhouse.close()

delete global.em
delete global.app
delete global.server
delete global.clickhouse
})
Loading

0 comments on commit 90cb4b9

Please sign in to comment.