diff --git a/packages/express/README.md b/packages/express/README.md index a7dff06..e13b5bc 100644 --- a/packages/express/README.md +++ b/packages/express/README.md @@ -150,6 +150,7 @@ app.listen(port, () => { saltByteLength: 8, secretByteLength: 18, token: { + fieldName: 'csrf_token', responseHeader: 'X-CSRF-Token' } } diff --git a/packages/express/src/index.ts b/packages/express/src/index.ts index 6a40686..8fcef22 100644 --- a/packages/express/src/index.ts +++ b/packages/express/src/index.ts @@ -1,8 +1,9 @@ import * as cookielib from 'cookie'; import type { Request as ExpressRequest, Response as ExpressResponse, RequestHandler as ExpressRequestHandler } from 'express'; -import { CsrfError, createCsrfProtect as _createCsrfProtect, Config, TokenOptions } from '@shared/protect'; -import type { ConfigOptions } from '@shared/protect'; +import { Config, TokenOptions } from '@shared/config'; +import type { ConfigOptions } from '@shared/config'; +import { CsrfError, createCsrfProtect as _createCsrfProtect } from '@shared/protect'; export { CsrfError }; diff --git a/packages/nextjs/README.md b/packages/nextjs/README.md index ae5f40e..bea6867 100644 --- a/packages/nextjs/README.md +++ b/packages/nextjs/README.md @@ -179,6 +179,7 @@ export const middleware = async (request: NextRequest) => { saltByteLength: 8, secretByteLength: 18, token: { + fieldName: 'csrf_token', responseHeader: 'X-CSRF-Token', value: undefined } diff --git a/packages/nextjs/src/index.test.ts b/packages/nextjs/src/index.test.ts index 8356802..af6add0 100644 --- a/packages/nextjs/src/index.test.ts +++ b/packages/nextjs/src/index.test.ts @@ -133,7 +133,7 @@ describe('csrfProtect integration tests', () => { expect(newTokenStr).not.toBe(''); }); - it('should handle server action non-form submissions', async () => { + it('should handle server action non-form submissions with string arg0', async () => { const secret = util.createSecret(8); const token = await util.createToken(secret, 8); @@ -153,6 +153,26 @@ describe('csrfProtect integration tests', () => { expect(newTokenStr).not.toBe(''); }); + it('should handle server action non-form submissions with object arg0', async () => { + const secret = util.createSecret(8); + const token = await util.createToken(secret, 8); + + const request = new NextRequest('http://example.com', { + method: 'POST', + headers: { 'Content-Type': 'text/plain' }, + body: JSON.stringify([{ csrf_token: util.utoa(token) }, 'arg']), + }); + request.cookies.set('_csrfSecret', util.utoa(secret)); + + const response = NextResponse.next(); + await csrfProtectDefault(request, response); + + // assertions + const newTokenStr = response.headers.get('X-CSRF-Token'); + expect(newTokenStr).toBeDefined(); + expect(newTokenStr).not.toBe(''); + }); + it('should fail with token from different secret', async () => { const evilSecret = util.createSecret(8); const goodSecret = util.createSecret(8); diff --git a/packages/nextjs/src/index.ts b/packages/nextjs/src/index.ts index d07c159..3b8f7c9 100644 --- a/packages/nextjs/src/index.ts +++ b/packages/nextjs/src/index.ts @@ -2,8 +2,9 @@ import type { NextRequest } from 'next/server'; // eslint-disable-next-line import/no-extraneous-dependencies import { NextResponse } from 'next/server'; -import { CsrfError, createCsrfProtect as _createCsrfProtect, Config, TokenOptions } from '@shared/protect'; -import type { ConfigOptions } from '@shared/protect'; +import { Config, TokenOptions } from '@shared/config'; +import type { ConfigOptions } from '@shared/config'; +import { CsrfError, createCsrfProtect as _createCsrfProtect } from '@shared/protect'; export { CsrfError }; diff --git a/packages/node-http/README.md b/packages/node-http/README.md index d088544..7e39d8e 100644 --- a/packages/node-http/README.md +++ b/packages/node-http/README.md @@ -108,6 +108,7 @@ Check out the example Node-HTTP server in this repository: [Node-HTTP example](e saltByteLength: 8, secretByteLength: 18, token: { + fieldName: 'csrf_token', responseHeader: 'X-CSRF-Token' } } diff --git a/packages/node-http/src/index.ts b/packages/node-http/src/index.ts index 8180359..9c1f529 100644 --- a/packages/node-http/src/index.ts +++ b/packages/node-http/src/index.ts @@ -2,8 +2,9 @@ import type { IncomingMessage, ServerResponse } from 'http'; import * as cookielib from 'cookie'; -import { CsrfError, createCsrfProtect as _createCsrfProtect, Config, TokenOptions } from '@shared/protect'; -import type { ConfigOptions } from '@shared/protect'; +import { Config, TokenOptions } from '@shared/config'; +import type { ConfigOptions } from '@shared/config'; +import { CsrfError, createCsrfProtect as _createCsrfProtect } from '@shared/protect'; export { CsrfError }; diff --git a/packages/sveltekit/README.md b/packages/sveltekit/README.md index 8778d7c..8ac1c89 100644 --- a/packages/sveltekit/README.md +++ b/packages/sveltekit/README.md @@ -146,6 +146,7 @@ export const handle: Handle = async ({ event, resolve }) => { saltByteLength: 8, secretByteLength: 18, token: { + fieldName: 'csrf_token', value: undefined } } diff --git a/packages/sveltekit/src/index.ts b/packages/sveltekit/src/index.ts index 20648a2..5a00642 100644 --- a/packages/sveltekit/src/index.ts +++ b/packages/sveltekit/src/index.ts @@ -1,7 +1,8 @@ import type { Handle, RequestEvent, Cookies } from '@sveltejs/kit'; -import { CsrfError, createCsrfProtect as _createCsrfProtect, Config, TokenOptions } from '@shared/protect'; -import type { ConfigOptions } from '@shared/protect'; +import { Config, TokenOptions } from '@shared/config'; +import type { ConfigOptions } from '@shared/config'; +import { CsrfError, createCsrfProtect as _createCsrfProtect } from '@shared/protect'; export { CsrfError }; diff --git a/shared/src/config.test.ts b/shared/src/config.test.ts new file mode 100644 index 0000000..179b243 --- /dev/null +++ b/shared/src/config.test.ts @@ -0,0 +1,90 @@ +import { Config, CookieOptions, TokenOptions } from './config'; +import type { ConfigOptions } from './config'; + +describe('CookieOptions tests', () => { + it('returns default values when options are absent', () => { + const cookieOpts = new CookieOptions(); + expect(cookieOpts.domain).toEqual(''); + expect(cookieOpts.httpOnly).toEqual(true); + expect(cookieOpts.maxAge).toEqual(undefined); + expect(cookieOpts.name).toEqual('_csrfSecret'); + expect(cookieOpts.partitioned).toEqual(undefined); + expect(cookieOpts.path).toEqual('/'); + expect(cookieOpts.sameSite).toEqual('strict'); + expect(cookieOpts.secure).toEqual(true); + }); + + it('handles overrides', () => { + const cookieOpts = new CookieOptions({ domain: 'xxx' }); + expect(cookieOpts.domain).toEqual('xxx'); + }); +}); + +describe('TokenOptions tests', () => { + it('returns default values when options are absent', () => { + const tokenOpts = new TokenOptions(); + expect(tokenOpts.fieldName).toEqual('csrf_token'); + expect(tokenOpts.value).toEqual(undefined); + }); + + it('handles overrides', () => { + const fn = async () => ''; + const tokenOpts = new TokenOptions({ + fieldName: 'csrfToken', + value: fn, + }); + expect(tokenOpts.fieldName).toEqual('csrfToken'); + expect(tokenOpts.value).toBe(fn); + }); +}); + +describe('Config tests', () => { + const initConfigFn = (opts: Partial) => () => new Config(opts); + + it('returns default config when options are absent', () => { + const config = new Config(); + expect(config.excludePathPrefixes).toEqual([]); + expect(config.ignoreMethods).toEqual(['GET', 'HEAD', 'OPTIONS']); + expect(config.saltByteLength).toEqual(8); + expect(config.secretByteLength).toEqual(18); + expect(config.cookie instanceof CookieOptions).toBe(true); + expect(config.token instanceof TokenOptions).toBe(true); + }); + + it('handles top-level overrides', () => { + const config = new Config({ saltByteLength: 10 }); + expect(config.saltByteLength).toEqual(10); + }); + + it('handles nested cookie overrides', () => { + const config = new Config({ cookie: { domain: 'xxx' } }); + expect(config.cookie.domain).toEqual('xxx'); + }); + + it('handles nested token overrides', () => { + const fn = async () => ''; + const config = new Config({ token: { fieldName: 'csrfToken', value: fn } }); + expect(config.token.fieldName).toEqual('csrfToken'); + expect(config.token.value).toBe(fn); + }); + + it('saltByteLength must be greater than 0', () => { + expect(initConfigFn({ saltByteLength: 0 })).toThrow(Error); + expect(initConfigFn({ saltByteLength: 1 })).not.toThrow(Error); + }); + + it('saltByteLength must be less than 256', () => { + expect(initConfigFn({ saltByteLength: 256 })).toThrow(Error); + expect(initConfigFn({ saltByteLength: 255 })).not.toThrow(Error); + }); + + it('secretByteLength must be greater than 0', () => { + expect(initConfigFn({ secretByteLength: 0 })).toThrow(Error); + expect(initConfigFn({ secretByteLength: 1 })).not.toThrow(Error); + }); + + it('secretByteLength must be less than 256', () => { + expect(initConfigFn({ secretByteLength: 256 })).toThrow(Error); + expect(initConfigFn({ secretByteLength: 255 })).not.toThrow(Error); + }); +}); diff --git a/shared/src/config.ts b/shared/src/config.ts new file mode 100644 index 0000000..436b5d5 --- /dev/null +++ b/shared/src/config.ts @@ -0,0 +1,90 @@ +/** + * Represents cookie options in config + */ +export class CookieOptions { + domain: string = ''; + + httpOnly: boolean = true; + + maxAge: number | undefined = undefined; + + name: string = '_csrfSecret'; + + partitioned: boolean | undefined = undefined; + + path: string = '/'; + + sameSite: boolean | 'none' | 'strict' | 'lax' = 'strict'; + + secure: boolean = true; + + constructor(opts?: Partial) { + Object.assign(this, opts); + } +} + +/** + * Represents a function to retrieve token value from a request + */ +export type TokenValueFunction = { + (request: Request): Promise +}; + +/** + * Represents token options in config + */ +export class TokenOptions { + readonly fieldName: string = 'csrf_token'; + + value: TokenValueFunction | undefined = undefined; + + _fieldNameRegex: RegExp; + + constructor(opts?: Partial) { + Object.assign(this, opts); + + // create fieldname regex + this._fieldNameRegex = new RegExp(`^(\\d+_)*${this.fieldName}$`); + } +} + +/** + * Represents CsrfProtect configuration object + */ +export class Config { + excludePathPrefixes: string[] = []; + + ignoreMethods: string[] = ['GET', 'HEAD', 'OPTIONS']; + + saltByteLength: number = 8; + + secretByteLength: number = 18; + + cookie: CookieOptions = new CookieOptions(); + + token: TokenOptions = new TokenOptions(); + + constructor(opts?: Partial) { + const newOpts = opts || {}; + if (newOpts.cookie) newOpts.cookie = new CookieOptions(newOpts.cookie); + if (newOpts.token) newOpts.token = new TokenOptions(newOpts.token); + Object.assign(this, newOpts); + + // basic validation + if (this.saltByteLength < 1 || this.saltByteLength > 255) { + throw Error('saltBytLength must be greater than 0 and less than 256'); + } + + if (this.secretByteLength < 1 || this.secretByteLength > 255) { + throw Error('secretBytLength must be greater than 0 and less than 256'); + } + } +} + +/** + * Represents CsrfProtect configuration options object + */ +export interface ConfigOptions extends Omit { + cookie: Partial; + token: Partial; +} diff --git a/shared/src/protect.test.ts b/shared/src/protect.test.ts index bfa9838..a649464 100644 --- a/shared/src/protect.test.ts +++ b/shared/src/protect.test.ts @@ -1,91 +1,10 @@ import { vi } from 'vitest'; -import { Config, CookieOptions, CsrfError, TokenOptions, createCsrfProtect } from './protect'; -import type { ConfigOptions, Cookie, CsrfProtectArgs } from './protect'; +import { Config, CookieOptions } from './config'; +import { CsrfError, createCsrfProtect } from './protect'; +import type { Cookie, CsrfProtectArgs } from './protect'; import * as util from './util'; -describe('CookieOptions tests', () => { - it('returns default values when options are absent', () => { - const cookieOpts = new CookieOptions(); - expect(cookieOpts.domain).toEqual(''); - expect(cookieOpts.httpOnly).toEqual(true); - expect(cookieOpts.maxAge).toEqual(undefined); - expect(cookieOpts.name).toEqual('_csrfSecret'); - expect(cookieOpts.partitioned).toEqual(undefined); - expect(cookieOpts.path).toEqual('/'); - expect(cookieOpts.sameSite).toEqual('strict'); - expect(cookieOpts.secure).toEqual(true); - }); - - it('handles overrides', () => { - const cookieOpts = new CookieOptions({ domain: 'xxx' }); - expect(cookieOpts.domain).toEqual('xxx'); - }); -}); - -describe('TokenOptions tests', () => { - it('returns default values when options are absent', () => { - const tokenOpts = new TokenOptions(); - expect(tokenOpts.value).toEqual(undefined); - }); - - it('handles overrides', () => { - const fn = async () => ''; - const tokenOpts = new TokenOptions({ value: fn }); - expect(tokenOpts.value).toBe(fn); - }); -}); - -describe('Config tests', () => { - const initConfigFn = (opts: Partial) => () => new Config(opts); - - it('returns default config when options are absent', () => { - const config = new Config(); - expect(config.excludePathPrefixes).toEqual([]); - expect(config.ignoreMethods).toEqual(['GET', 'HEAD', 'OPTIONS']); - expect(config.saltByteLength).toEqual(8); - expect(config.secretByteLength).toEqual(18); - expect(config.cookie instanceof CookieOptions).toBe(true); - expect(config.token instanceof TokenOptions).toBe(true); - }); - - it('handles top-level overrides', () => { - const config = new Config({ saltByteLength: 10 }); - expect(config.saltByteLength).toEqual(10); - }); - - it('handles nested cookie overrides', () => { - const config = new Config({ cookie: { domain: 'xxx' } }); - expect(config.cookie.domain).toEqual('xxx'); - }); - - it('handles nested token overrides', () => { - const fn = async () => ''; - const config = new Config({ token: { value: fn } }); - expect(config.token.value).toBe(fn); - }); - - it('saltByteLength must be greater than 0', () => { - expect(initConfigFn({ saltByteLength: 0 })).toThrow(Error); - expect(initConfigFn({ saltByteLength: 1 })).not.toThrow(Error); - }); - - it('saltByteLength must be less than 256', () => { - expect(initConfigFn({ saltByteLength: 256 })).toThrow(Error); - expect(initConfigFn({ saltByteLength: 255 })).not.toThrow(Error); - }); - - it('secretByteLength must be greater than 0', () => { - expect(initConfigFn({ secretByteLength: 0 })).toThrow(Error); - expect(initConfigFn({ secretByteLength: 1 })).not.toThrow(Error); - }); - - it('secretByteLength must be less than 256', () => { - expect(initConfigFn({ secretByteLength: 256 })).toThrow(Error); - expect(initConfigFn({ secretByteLength: 255 })).not.toThrow(Error); - }); -}); - describe('csrfProtect tests', () => { let createSecretMock = vi.spyOn(util, 'createSecret'); let getTokenStringMock = vi.spyOn(util, 'getTokenString'); diff --git a/shared/src/protect.ts b/shared/src/protect.ts index fa539cd..8519e67 100644 --- a/shared/src/protect.ts +++ b/shared/src/protect.ts @@ -1,88 +1,12 @@ +import { Config } from './config'; +import type { ConfigOptions, CookieOptions } from './config'; import { createSecret, createToken, getTokenString, verifyToken, atou, utoa } from './util'; -import type { TokenValueFunction } from './util'; /** * Represents a generic CSRF protection error */ export class CsrfError extends Error {} -/** - * Represents cookie options in config - */ -export class CookieOptions { - domain: string = ''; - - httpOnly: boolean = true; - - maxAge: number | undefined = undefined; - - name: string = '_csrfSecret'; - - partitioned: boolean | undefined = undefined; - - path: string = '/'; - - sameSite: boolean | 'none' | 'strict' | 'lax' = 'strict'; - - secure: boolean = true; - - constructor(opts?: Partial) { - Object.assign(this, opts); - } -} - -/** - * Represents token options in config - */ -export class TokenOptions { - value: TokenValueFunction | undefined = undefined; - - constructor(opts?: Partial) { - Object.assign(this, opts); - } -} - -/** - * Represents CsrfProtect configuration object - */ -export class Config { - excludePathPrefixes: string[] = []; - - ignoreMethods: string[] = ['GET', 'HEAD', 'OPTIONS']; - - saltByteLength: number = 8; - - secretByteLength: number = 18; - - cookie: CookieOptions = new CookieOptions(); - - token: TokenOptions = new TokenOptions(); - - constructor(opts?: Partial) { - const newOpts = opts || {}; - if (newOpts.cookie) newOpts.cookie = new CookieOptions(newOpts.cookie); - if (newOpts.token) newOpts.token = new TokenOptions(newOpts.token); - Object.assign(this, newOpts); - - // basic validation - if (this.saltByteLength < 1 || this.saltByteLength > 255) { - throw Error('saltBytLength must be greater than 0 and less than 256'); - } - - if (this.secretByteLength < 1 || this.secretByteLength > 255) { - throw Error('secretBytLength must be greater than 0 and less than 256'); - } - } -} - -/** - * Represents CsrfProtect configuration options object - */ -export interface ConfigOptions extends Omit { - cookie: Partial; - token: Partial; -} - /** * Represents a cookie */ @@ -137,7 +61,7 @@ export function createCsrfProtect(opts?: Partial): CsrfProtect { // verify token if (!config.ignoreMethods.includes(request.method)) { - const tokenStr = await getTokenString(request, config.token.value); + const tokenStr = await getTokenString(request, config.token); if (!await verifyToken(atou(tokenStr), secret)) { throw new CsrfError('csrf validation error'); } diff --git a/shared/src/util.test.ts b/shared/src/util.test.ts index e7f0c34..9947f58 100644 --- a/shared/src/util.test.ts +++ b/shared/src/util.test.ts @@ -1,3 +1,4 @@ +import { TokenOptions } from './config'; import * as util from './util'; describe('createSecret', () => { @@ -108,6 +109,21 @@ describe('getTokenString', () => { expect(tokenStr).toEqual('my-token'); }); + it('gets token from custom field name', async() => { + const formData = new FormData(); + formData.set('file', new Blob(['xxx']), 'filename'); + formData.set('csrfToken', 'my-token'); + + const request = new Request('http://example.com/', { + method: 'POST', + body: formData, + }); + + const tokenOpts = new TokenOptions({ fieldName: 'csrfToken' }); + const tokenStr = await util.getTokenString(request, tokenOpts); + expect(tokenStr).toEqual('my-token'); + }); + it('gets token from raw body with other content-type', async () => { const request = new Request('http://example.com/', { method: 'POST', @@ -144,8 +160,10 @@ describe('getTokenString', () => { method: 'POST', body: JSON.stringify({ 'custom-token-name': 'my-token' }), }); - const valueFn = async (request: Request) => (await request.json())['custom-token-name']; - const tokenStr = await util.getTokenString(requestOuter, valueFn); + const tokenOpts = new TokenOptions({ + value: async (request: Request) => (await request.json())['custom-token-name'], + }); + const tokenStr = await util.getTokenString(requestOuter, tokenOpts); expect(tokenStr).toEqual('my-token'); }); }); diff --git a/shared/src/util.ts b/shared/src/util.ts index 4665386..2d57890 100644 --- a/shared/src/util.ts +++ b/shared/src/util.ts @@ -1,6 +1,6 @@ -export type TokenValueFunction = { - (request: Request): Promise -}; +import { TokenOptions } from './config'; + +const defaultTokenOpts = new TokenOptions(); /** * Create new secret (cryptographically secure) @@ -49,11 +49,9 @@ export function atou(inputB64: string): Uint8Array { * Get CSRF token from form * @param {FormData} formData - The form data object */ -const formDataKeyRegex = /^(\d+_)*csrf_token$/; - -function getTokenValueFromFormData(formData: FormData): File | string | undefined { +function getTokenValueFromFormData(formData: FormData, tokenOpts: TokenOptions = defaultTokenOpts): File | string | undefined { for (const [key, value] of formData.entries()) { - if (formDataKeyRegex.test(key)) return value; + if (tokenOpts._fieldNameRegex.test(key)) return value; } return undefined; } @@ -63,28 +61,30 @@ function getTokenValueFromFormData(formData: FormData): File | string | undefine * @param {Request} request - The request object * @param {ValueFunc|null} valueFn - Function to retrieve token value from request */ -export async function getTokenString(request: Request, valueFn?: TokenValueFunction): Promise { - if (valueFn !== undefined) return valueFn(request); +export async function getTokenString(request: Request, tokenOpts: TokenOptions = defaultTokenOpts): Promise { + if (tokenOpts.value !== undefined) return tokenOpts.value(request); // check the `x-csrf-token` request header const token = request.headers.get('x-csrf-token'); if (token !== null) return token; + const fieldName = tokenOpts.fieldName; + // check request body const contentType = request.headers.get('content-type') || 'text/plain'; // url-encoded or multipart/form-data if (contentType === 'application/x-www-form-urlencoded' || contentType.startsWith('multipart/form-data')) { const formData = await request.formData(); - const formDataVal = getTokenValueFromFormData(formData); + const formDataVal = getTokenValueFromFormData(formData, tokenOpts); if (typeof formDataVal === 'string') return formDataVal; return ''; } // json-encoded if (contentType === 'application/json' || contentType === 'application/ld+json') { - const json = await request.json() as { csrf_token: unknown; }; - const jsonVal = json.csrf_token; + const json = await request.json() as any; + const jsonVal = json[fieldName]; if (typeof jsonVal === 'string') return jsonVal; return ''; } @@ -109,7 +109,7 @@ export async function getTokenString(request: Request, valueFn?: TokenValueFunct if (typeofArgs0 === 'object') { // if first argument is an object, look for token there - return args0.csrf_token || ''; + return args0[fieldName] || ''; } return args0;