Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
jquense committed Oct 5, 2021
1 parent 6ef63e8 commit eb49b91
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 50 deletions.
6 changes: 3 additions & 3 deletions src/SocketIOSubscriptionServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ export default class SocketIOSubscriptionServer<
const request = Object.create((express as any).request);
Object.assign(request, socket.request);

this.log('debug', 'SubscriptionServer: new socket connection', {
this.log('debug', 'new socket connection', {
clientId,
numClients: this.io.engine?.clientsCount ?? 0,
});

this.opened(
this.initConnection(
{
id: clientId,
protocol: 'socket-io',
protocol: '4c-subscription-server',
on: socket.on.bind(socket),
emit(event: string, data: any) {
socket.emit(event, data);
Expand Down
5 changes: 1 addition & 4 deletions src/SubscriptionServer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Request } from 'express';
import type { GraphQLSchema } from 'graphql';
import type { Server, Socket } from 'socket.io';

import AuthorizedSocketConnection from './AuthorizedSocketConnection';
import type { CreateValidationRules } from './AuthorizedSocketConnection';
Expand Down Expand Up @@ -39,9 +38,7 @@ export default abstract class SubscriptionServer<TContext, TCredentials> {

public abstract attach(httpServer: any): void;

protected opened(socket: WebSocket, request: Request) {
this.log('debug', 'new socket connection');

protected initConnection(socket: WebSocket, request: Request) {
const { createContext } = this.config;

// eslint-disable-next-line no-new
Expand Down
106 changes: 70 additions & 36 deletions src/WebSocketSubscriptionServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,37 @@ import ws from 'ws';
import SubscriptionServer, {
SubscriptionServerConfig,
} from './SubscriptionServer';
import { MessageType } from './types';
import { MessageType, SupportedProtocols } from './types';

export type DisconnectReason =
| 'server disconnect'
| 'client disconnect'
| 'ping timeout';
interface Message {
type: MessageType;
payload: any;
ackId?: number;
}

class GraphQLSocket extends EventEmitter {
protocol: 'graphql-transport-ws' | 'socket-io';
protocol: SupportedProtocols;

private pingHandle: NodeJS.Timeout | null;
isAlive = true;

private pongWait: NodeJS.Timeout | null;

constructor(private socket: ws, { keepAlive = 12 * 1000 } = {}) {
constructor(private socket: ws) {
super();

this.socket = socket;
this.isAlive = true;

this.protocol =
socket.protocol === 'graphql-transport-ws'
? socket.protocol
: 'socket-io';
: '4c-subscription-server';

socket.on('pong', () => {
this.isAlive = true;
});

socket.on('message', (data) => {
let msg: Message | null = null;
Expand All @@ -43,35 +51,16 @@ class GraphQLSocket extends EventEmitter {
});

socket.on('close', (code: number, reason: string) => {
clearTimeout(this.pongWait!);
clearInterval(this.pingHandle!);

this.isAlive = false;
super.emit('disconnect', 'client disconnect');
super.emit('close', code, reason);
});
}

// keep alive through ping-pong messages
this.pongWait = null;

this.pingHandle =
keepAlive > 0 && Number.isFinite(keepAlive)
? setInterval(() => {
// ping pong on open sockets only
if (this.socket.readyState === this.socket.OPEN) {
// terminate the connection after pong wait has passed because the client is idle
this.pongWait = setTimeout(() => {
this.socket.terminate();
}, keepAlive);

// listen for client's pong and stop socket termination
this.socket.once('pong', () => {
clearTimeout(this.pongWait!);
this.pongWait = null;
});

this.socket.ping();
}
}, keepAlive)
: null;
disconnect(reason?: DisconnectReason) {
this.emit('disconnect', reason);
super.emit('disconnect', reason);
this.socket.terminate();
}

private ack(msg: { ackId?: number } | null) {
Expand All @@ -97,18 +86,39 @@ class GraphQLSocket extends EventEmitter {
close(code: number, reason: string) {
this.socket.close(code, reason);
}

ping() {
if (this.socket.readyState === this.socket.OPEN) {
this.isAlive = false;
this.socket.ping();
}
}
}

export interface WebSocketSubscriptionServerConfig<TContext, TCredentials>
extends SubscriptionServerConfig<TContext, TCredentials> {
keepAlive?: number;
}
export default class WebSocketSubscriptionServer<
TContext,
TCredentials,
> extends SubscriptionServer<TContext, TCredentials> {
private ws: ws.Server;

constructor(config: SubscriptionServerConfig<TContext, TCredentials>) {
private gqlClients = new WeakMap<ws, GraphQLSocket>();

readonly keepAlive: number;

private pingHandle: NodeJS.Timeout | null = null;

constructor({
keepAlive = 15_000,
...config
}: WebSocketSubscriptionServerConfig<TContext, TCredentials>) {
super(config);

this.ws = new ws.Server({ noServer: true });
this.keepAlive = keepAlive;

this.ws.on('error', () => {
// catch the first thrown error and re-throw it once all clients have been notified
Expand All @@ -126,14 +136,16 @@ export default class WebSocketSubscriptionServer<
if (firstErr) throw firstErr;
});

this.scheduleLivelinessCheck();
this.ws.on('connection', (socket, request) => {
const gqlSocket = new GraphQLSocket(socket);
this.gqlClients.set(socket, gqlSocket);

this.opened(gqlSocket, request as any);
this.initConnection(gqlSocket, request as any);

// socket io clients do this behind the scenes
// so we keep it out of the server logic
if (gqlSocket.protocol === 'socket-io') {
if (gqlSocket.protocol === '4c-subscription-server') {
// inform the client they are good to go
gqlSocket.emit('connect');
}
Expand All @@ -158,6 +170,8 @@ export default class WebSocketSubscriptionServer<
}

async close() {
clearTimeout(this.pingHandle!);

for (const client of this.ws.clients) {
client.close(1001, 'Going away');
}
Expand All @@ -167,4 +181,24 @@ export default class WebSocketSubscriptionServer<
this.ws.close((err) => (err ? reject(err) : resolve()));
});
}

private scheduleLivelinessCheck() {
clearTimeout(this.pingHandle!);
this.pingHandle = setTimeout(() => {
for (const socket of this.ws.clients) {
const gql = this.gqlClients.get(socket);
if (!gql) {
continue;
}
if (!gql.isAlive) {
gql.disconnect('ping timeout');
return;
}

gql.ping();
}

this.scheduleLivelinessCheck();
}, this.keepAlive);
}
}
6 changes: 5 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
export type SupportedProtocols =
| 'graphql-transport-ws'
| '4c-subscription-server';

export interface WebSocket {
protocol: string;
protocol: SupportedProtocols;
id?: string;

close(code: number, reason: string): Promise<void> | void;
Expand Down
5 changes: 1 addition & 4 deletions test/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// import { RedisClient } from 'redis';
// import type { Socket } from 'socket.io-client';
// import socketio from 'socket.io-client';
import { EventEmitter } from 'events';
import http from 'http';

Expand Down Expand Up @@ -70,8 +67,8 @@ export async function startServer(
httpServer,
subscriber,
async close() {
httpServer.close();
await server.close();
httpServer.close();
},
};
}
Expand Down
1 change: 1 addition & 0 deletions test/socket-io.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* eslint-disable no-underscore-dangle */

import socketio from 'socket.io-client';

import { CreateLogger } from '../src';
Expand Down
82 changes: 80 additions & 2 deletions test/websocket.test.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
/* eslint-disable no-underscore-dangle */

import WebSocketSubscriptionServer from '../src/WebSocketSubscriptionServer';
import schema from './data/schema';
import {
TestClient,
TestCredentialsManager,
delay,
graphql,
startServer,
} from './helpers';

function createServer(subscriber) {
function createServer(subscriber, options = {}) {
return new WebSocketSubscriptionServer({
...options,
path: '/graphql',
schema,
subscriber,
Expand All @@ -22,7 +26,7 @@ function createServer(subscriber) {

type PromiseType<P> = P extends Promise<infer R> ? R : never;

describe('socket-io client', () => {
describe('websocket server', () => {
let server: PromiseType<ReturnType<typeof startServer>>;
let client: TestClient | null = null;

Expand Down Expand Up @@ -121,4 +125,78 @@ describe('socket-io client', () => {

await socket.unsubscribe();
});

it('should not race unsubscribe call', async () => {
const socket = await createClient(
graphql`
subscription TestTodoUpdatedSubscription(
$input: TodoUpdatedSubscriptionInput!
) {
todoUpdated(input: $input) {
todo {
text
}
}
}
`,
{
input: {
id: '1',
},
},
);

await socket.authenticate();

const range = Array.from({ length: 2 }, (_, i) => i);
const promises = [] as any[];
for (const id of range) {
promises.push(socket.subscribe(`s-${id}`));
promises.push(socket.unsubscribe(`s-${id}`));

await delay(0);
}

await Promise.all(promises);

expect(server.subscriber._queues.size).toEqual(0);
expect(server.subscriber._channels.size).toEqual(0);
});

it('should clean up on client close', async () => {
const socket = await createClient(
graphql`
subscription TestTodoUpdatedSubscription(
$input: TodoUpdatedSubscriptionInput!
) {
todoUpdated(input: $input) {
todo {
text
}
}
}
`,
{
input: {
id: '1',
},
},
);

expect(server.subscriber._queues.size).toEqual(0);
expect(server.subscriber._channels.size).toEqual(0);

await socket.authenticate();
await socket.subscribe();

expect(server.subscriber._queues.size).toEqual(1);
expect(server.subscriber._channels.size).toEqual(1);

socket.close();

await delay(50);

expect(server.subscriber._queues.size).toEqual(0);
expect(server.subscriber._channels.size).toEqual(0);
});
});

0 comments on commit eb49b91

Please sign in to comment.