From c043f1e1305ac791ccca98f05533cc9da5fcc143 Mon Sep 17 00:00:00 2001 From: Devin Ivy Date: Thu, 12 Oct 2023 13:55:42 -0400 Subject: [PATCH] don't proxy hop-by-hop headers, test chunked encoding --- packages/xrpc-server/src/server.ts | 4 +- packages/xrpc-server/src/util.ts | 19 ++++++++- packages/xrpc-server/tests/proxy.test.ts | 50 ++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/packages/xrpc-server/src/server.ts b/packages/xrpc-server/src/server.ts index cea1f23b2ee..e25582e7135 100644 --- a/packages/xrpc-server/src/server.ts +++ b/packages/xrpc-server/src/server.ts @@ -41,6 +41,7 @@ import { } from './types' import { decodeQueryParams, + getHopByHopHeaders, getQueryParams, validateInput, validateOutput, @@ -269,8 +270,9 @@ export class Server { res.statusCode = passthru.statusCode res.statusMessage = passthru.statusMessage if (!res.headersSent) { + const hopByHop = getHopByHopHeaders(passthru.headers.connection) for (const [name, value] of Object.entries(passthru.headers)) { - if (value !== undefined) { + if (value !== undefined && !hopByHop.has(name)) { res.setHeader(name, value) } } diff --git a/packages/xrpc-server/src/util.ts b/packages/xrpc-server/src/util.ts index 175c282266e..963d1177e5a 100644 --- a/packages/xrpc-server/src/util.ts +++ b/packages/xrpc-server/src/util.ts @@ -295,8 +295,9 @@ export async function proxy( ): Promise { // headers const headers: Record = Object.create(null) + const hopByHop = getHopByHopHeaders(ctx.req.headers.connection) for (const [name, value] of Object.entries(ctx.req.headers)) { - if (value !== undefined) { + if (value !== undefined && !hopByHop.has(name)) { headers[name] = Array.isArray(value) ? value.join(', ') : value } } @@ -351,6 +352,22 @@ export async function proxy( } } +export function getHopByHopHeaders(connectionHeader?: string) { + const hopByHop = new Set([ + 'connection', + 'keep-alive', + 'proxy-authenticate', + 'proxy-authorization', + 'te', + 'trailer', + 'transfer-encoding', + 'upgrade', + ]) + const additional = (connectionHeader ?? '').split(/\s*,\s*/) + additional.forEach((header) => hopByHop.add(header.toLowerCase())) + return hopByHop +} + export class ServerTimer implements ServerTiming { public duration?: number private startMs?: number diff --git a/packages/xrpc-server/tests/proxy.test.ts b/packages/xrpc-server/tests/proxy.test.ts index 5fc635a0cfe..a5c7d9d15ac 100644 --- a/packages/xrpc-server/tests/proxy.test.ts +++ b/packages/xrpc-server/tests/proxy.test.ts @@ -101,6 +101,18 @@ const LEXICONS = [ }, }, }, + { + lexicon: 1, + id: 'io.example.stream', + defs: { + main: { + type: 'query', + output: { + encoding: '*/*', + }, + }, + }, + }, { lexicon: 1, id: 'io.example.headers', @@ -206,6 +218,14 @@ describe('Proxy', () => { } }) + proxy.method('io.example.stream', proxyHandler) + server.method('io.example.stream', () => { + return { + encoding: 'application/octet-stream', + body: Readable.from(Buffer.from('streaming bytes')), + } + }) + proxy.method('io.example.headers', (ctx) => { return xrpcServer.proxy(ctx, client.uri.href, { headers: { @@ -218,6 +238,8 @@ describe('Proxy', () => { encoding: 'application/json', body: ctx.req.headers, headers: { + connection: 'x-dont-pass-me-down', + 'x-dont-pass-me-down': 'x-dont-pass-me-down-val', 'my-custom-header-down': 'my-custom-header-down-val', }, } @@ -275,6 +297,7 @@ describe('Proxy', () => { expect(res.data.arr).toEqual([3]) expect(res.data.def).toEqual(0) expect(res.headers['content-type']).toContain('application/json') + expect(res.headers['content-length']).toBeDefined() }) it('proxies json input', async () => { @@ -295,6 +318,7 @@ describe('Proxy', () => { expect(res.data.arr).toEqual([3]) expect(res.data.def).toEqual(0) expect(res.headers['content-type']).toContain('application/json') + expect(res.headers['content-length']).toBeDefined() }) it('proxies blob input uncompressed', async () => { @@ -308,6 +332,7 @@ describe('Proxy', () => { expect(res.success).toBeTruthy() expect(res.data.cid).toBe(cid.toString()) expect(res.headers['content-type']).toContain('application/json') + expect(res.headers['content-length']).toBeDefined() }) it('proxies blob input compressed', async () => { @@ -327,6 +352,7 @@ describe('Proxy', () => { expect(res.success).toBeTruthy() expect(res.data.cid).toBe(cid.toString()) expect(res.headers['content-type']).toContain('application/json') + expect(res.headers['content-length']).toBeDefined() }) it('proxies blob output uncompressed', async () => { @@ -335,6 +361,7 @@ describe('Proxy', () => { expect(res.data.byteLength).toBe(1024) expect(res.headers['content-type']).toBe('application/octet-stream') expect(res.headers['content-encoding']).toBe('identity') + expect(res.headers['content-length']).toBeDefined() }) it('proxies blob output compressed', async () => { @@ -345,6 +372,16 @@ describe('Proxy', () => { expect(res.data.byteLength).toBe(1024) expect(res.headers['content-type']).toBe('application/octet-stream') expect(res.headers['content-encoding']).toBe('gzip') + expect(res.headers['content-length']).toBeDefined() + }) + + it('proxies chunked output', async () => { + const res = await proxyClient.call('io.example.stream') + expect(res.success).toBeTruthy() + expect(Buffer.from(res.data).toString()).toBe('streaming bytes') + expect(res.headers['content-type']).toBe('application/octet-stream') + expect(res.headers['content-length']).toBeUndefined() + expect(res.headers['transfer-encoding']).toBe('chunked') }) it('proxies custom headers', async () => { @@ -363,6 +400,19 @@ describe('Proxy', () => { 'my-custom-header-down-val', ) expect(res.headers['content-type']).toContain('application/json') + expect(res.headers['content-length']).toBeDefined() + }) + + it('does not proxy hop-by-hop headers', async () => { + const res = await proxyClient.call('io.example.headers', {}, undefined, { + headers: { + 'proxy-authenticate': 'proxy-authenticate-val', + }, + }) + expect(res.success).toBeTruthy() + expect(res.data['proxy-authenticate']).toBeUndefined() + expect(res.headers['x-dont-pass-me-down']).toBeUndefined() + expect(res.headers['connection']).not.toContain('x-dont-pass-me-down') }) it('proxies 4xx errors', async () => {