diff --git a/src/pubsub-async-iterator.ts b/src/pubsub-async-iterator.ts index 78b9cba..7eddfc0 100644 --- a/src/pubsub-async-iterator.ts +++ b/src/pubsub-async-iterator.ts @@ -1,4 +1,5 @@ import { PubSubEngine } from 'graphql-subscriptions'; +import { FilterFn } from './with-filter'; /** * A class for digesting PubSubEngine events via the new AsyncIterator interface. @@ -31,9 +32,10 @@ import { PubSubEngine } from 'graphql-subscriptions'; */ export class PubSubAsyncIterator implements AsyncIterableIterator { - constructor(pubsub: PubSubEngine, eventNames: string | string[], options?: unknown) { + constructor(pubsub: PubSubEngine, eventNames: string | string[], options?: unknown, filterFn?: FilterFn) { this.pubsub = pubsub; this.options = options; + this.filterFn = filterFn; this.pullQueue = []; this.pushQueue = []; this.listening = true; @@ -66,9 +68,14 @@ export class PubSubAsyncIterator implements AsyncIterableIterator { private listening: boolean; private pubsub: PubSubEngine; private options: unknown; + private filterFn: FilterFn | undefined; private async pushValue(event) { await this.subscribeAll(); + if (this.filterFn) { + const filterResult = await this.filterFn(event, 0); + if (!filterResult) return; + } if (this.pullQueue.length !== 0) { this.pullQueue.shift()({ value: event, done: false }); } else { diff --git a/src/redis-pubsub.ts b/src/redis-pubsub.ts index de375b4..b6ebfad 100644 --- a/src/redis-pubsub.ts +++ b/src/redis-pubsub.ts @@ -1,6 +1,7 @@ import {Cluster, Redis, RedisOptions} from 'ioredis'; import {PubSubEngine} from 'graphql-subscriptions'; import {PubSubAsyncIterator} from './pubsub-async-iterator'; +import { FilterFn } from './with-filter'; type RedisClient = Redis | Cluster; type OnMessage = (message: T) => void; @@ -139,8 +140,8 @@ export class RedisPubSub implements PubSubEngine { delete this.subscriptionMap[subId]; } - public asyncIterator(triggers: string | string[], options?: unknown): AsyncIterator { - return new PubSubAsyncIterator(this, triggers, options); + public asyncIterator(triggers: string | string[], options?: unknown, filterFn?: FilterFn): AsyncIterator { + return new PubSubAsyncIterator(this, triggers, options, filterFn); } public getSubscriber(): RedisClient { diff --git a/src/test/tests.ts b/src/test/tests.ts index c27fc5c..82e1ba8 100644 --- a/src/test/tests.ts +++ b/src/test/tests.ts @@ -451,6 +451,26 @@ describe('PubSubAsyncIterator', () => { pubSub.publish(eventName, { test: true }); }); + it('should only publish filtered events', done => { + const pubSub = new RedisPubSub(mockOptions); + const eventName = 'test'; + const iterator = pubSub.asyncIterator(eventName, undefined, (eventData) => eventData.filtered === false); + + iterator.next().then(result => { + // tslint:disable-next-line:no-unused-expression + expect(result).to.exist; + // tslint:disable-next-line:no-unused-expression + expect(result.value).to.exist; + // tslint:disable-next-line:no-unused-expression + expect(result.done).to.exist; + expect(result.value.filtered).to.equal(false); + done(); + }); + + pubSub.publish(eventName, { filtered: true }); + pubSub.publish(eventName, { filtered: false }); + }); + it('should not trigger event on asyncIterator when publishing other event', async () => { const pubSub = new RedisPubSub(mockOptions); const eventName = 'test2'; diff --git a/src/with-filter.ts b/src/with-filter.ts index 0d87622..661481b 100644 --- a/src/with-filter.ts +++ b/src/with-filter.ts @@ -1,5 +1,10 @@ export type FilterFn = (rootValue?: any, args?: any, context?: any, info?: any) => boolean; +/** + * Wraps an async-iterator and filters incoming events based on the provided filter function. + * Note: Due to promise chaining this function can use a large amount of memory when a high percentage of messages are filtered + * If using the PubSubAsyncIterator, use the filterFn property directly + */ export const withFilter = (asyncIteratorFn: () => AsyncIterableIterator, filterFn: FilterFn) => { return (rootValue: any, args: any, context: any, info: any): AsyncIterator => { const asyncIterator = asyncIteratorFn();