diff --git a/src/features/device-logs/index.ts b/src/features/device-logs/index.ts index 839b621c0d..b458291cd2 100644 --- a/src/features/device-logs/index.ts +++ b/src/features/device-logs/index.ts @@ -22,6 +22,18 @@ const deviceLogsRateLimiter = createRateLimitMiddleware( }, ); +// Rate limit device log get requests +const streamableDeviceLogsRateLimiter = createRateLimitMiddleware( + createRateLimiter('get-device-logs', { + points: 10, // allow 10 device log streams / get requests + blockDuration: 60, // seconds + duration: 60, // reset counter after 60 seconds (from the first batch of the window) + }), + { + ignoreIP: true, + }, +); + export const setup = ( app: Application, onLogWriteStreamInitialized: SetupOptions['onLogWriteStreamInitialized'], @@ -30,6 +42,7 @@ export const setup = ( app.get( '/device/v2/:uuid/logs', middleware.fullyAuthenticatedUser, + streamableDeviceLogsRateLimiter(['params.uuid', 'query.stream']), read(onLogReadStreamInitialized), ); app.post( diff --git a/src/infra/rate-limiting/index.ts b/src/infra/rate-limiting/index.ts index 2f7879e47e..e1fe90f9bd 100644 --- a/src/infra/rate-limiting/index.ts +++ b/src/infra/rate-limiting/index.ts @@ -97,7 +97,7 @@ export type RateLimitKeyFn = ( req: Request, res: Response, ) => Resolvable; -export type RateLimitKey = string | RateLimitKeyFn; +export type RateLimitKey = string | RateLimitKeyFn | string[]; export type RateLimitMiddleware = ( ...args: Parameters @@ -126,14 +126,22 @@ const $createRateLimitMiddleware = ( ignoreIP = false, allowReset = true, }: { ignoreIP?: boolean; allowReset?: boolean } = {}, - field?: RateLimitKey, + fields?: RateLimitKey, ): RateLimitMiddleware => { let fieldFn: RateLimitKeyFn; - if (field != null) { - if (typeof field === 'function') { - fieldFn = field; + if (fields != null) { + if (typeof fields === 'function') { + fieldFn = fields; + } else if (Array.isArray(fields)) { + fieldFn = (req) => + fields + .map((field) => { + const path = _.toPath(field); + return _.get(req, path); + }) + .join('$'); } else { - const path = _.toPath(field); + const path = _.toPath(fields); fieldFn = (req) => _.get(req, path); } } else { diff --git a/test/06_device-log.ts b/test/06_device-log.ts index af95954253..87674ba133 100644 --- a/test/06_device-log.ts +++ b/test/06_device-log.ts @@ -222,4 +222,54 @@ describe('device log', () => { 'streamed log line 1', ]); }); + + it('should rate limit stream-read device', async () => { + const dummyLogs = [createLog({ message: 'not rate limited' })]; + await supertest(ctx.device.apiKey) + .post(`/device/v2/${ctx.device.uuid}/logs`) + .send(dummyLogs) + .expect(201); + + async function testRatelimitedDeviceLogsStream() { + let evalValue; + const req = supertest(ctx.user) + .get(`/device/v2/${ctx.device.uuid}/logs`) + .query({ + stream: 1, + count: 1, + }) + .parse(function (res, callback) { + res.on('data', async function (chunk) { + const parsedChunk = JSON.parse(Buffer.from(chunk).toString()); + evalValue = parsedChunk; + // if data stream provides proper data terminate the stream (abort) + if ( + typeof parsedChunk === 'object' && + parsedChunk?.message === 'not rate limited' + ) { + req.abort(); + } + }); + res.on('close', () => callback(null, null)); + }); + + try { + await req; + } catch (error) { + if (error.code !== 'ABORTED') { + throw error; + } + } + return evalValue; + } + + const notLimited = await testRatelimitedDeviceLogsStream(); + expect(notLimited?.['message']).to.deep.equal(dummyLogs[0].message); + + while ((await testRatelimitedDeviceLogsStream()) !== 'Too Many Requests') { + // no empty block + } + const rateLimited = await testRatelimitedDeviceLogsStream(); + expect(rateLimited).to.be.string('Too Many Requests'); + }); });