diff --git a/packages/oauth-server/package.json b/packages/oauth-server/package.json index 1f058961fd2..3913e3656be 100644 --- a/packages/oauth-server/package.json +++ b/packages/oauth-server/package.json @@ -31,6 +31,7 @@ "http-errors": "^2.0.0", "ipaddr.js": "^2.1.0", "kysely": "^0.22.0", + "lru-cache": "^10.1.0", "zod": "^3.21.4" }, "devDependencies": { diff --git a/packages/oauth-server/src/client/client-registry.ts b/packages/oauth-server/src/client/client-registry.ts index 7036f74d11a..3bdf63774fc 100644 --- a/packages/oauth-server/src/client/client-registry.ts +++ b/packages/oauth-server/src/client/client-registry.ts @@ -1,3 +1,6 @@ +import { LRUCache } from 'lru-cache' + +import { extractService } from '../util/did' import { DidWeb, didWebToUrl, fetchDidDocument } from '../util/did-web' import { Fetch, fetchFactory } from '../util/fetch' import { @@ -6,48 +9,97 @@ import { } from '../util/fetch-transform' import { isLoopbackHostname } from '../util/net' -import { extractService } from '../util/did' -import { fetchClientMetadata } from './client-metadata' +import { fetchClientMetadata } from './fetch-client-metadata' import { ClientMetadata, clientMetadataSchema } from './types' +import { DeepReadonly, deepFreeze } from '../util/object' export class ClientRegistry { protected readonly fetch: Fetch + protected readonly cache = new LRUCache>( + { + max: 1000, + ttl: 60 * 60 * 1000, // 1 hour + updateAgeOnGet: false, + updateAgeOnHas: false, + allowStaleOnFetchRejection: true, + ignoreFetchAbort: true, + fetchMethod: async (clientId) => + this.loadClientMetadata(clientId).then((v) => deepFreeze(v)), + }, + ) constructor(fetch?: Fetch) { this.fetch = fetchFactory(fetch, [ + /** + * Since we will be fetching from the network based on user provided + * input, we need to make sure that the request is not vulnerable to SSRF + * attacks. + */ ssrfSafeRequestTransform(), + /** + * Disallow fetching from domains we know are not atproto client + * implementation. + */ forbiddenDomainNameRequestTransform(['bsky.social', 'bsky.network']), ]) } - async getClientMetadata(clientId: DidWeb): Promise { + public async getClientMetadata( + clientId: DidWeb, + ): Promise> { + // Allow loopback URLs to be used as client ID. Since we cannot fetch from + // them, we will use a pre-defined metadata for all clients using loopback + // URLs as client ID. These clients will be able to use the "registration" + // autorization parameter to provide additional custom metadata. Note that + // the UI should clearly indicate that the client is using a loopback URL. const url = didWebToUrl(clientId) - - // Allow localhost if (isLoopbackHostname(url.hostname)) { - return clientMetadataSchema.parse({ - client_name: 'Localhost', - redirect_uris: [url.toString()], - jwks: [], - token_endpoint_auth_method: 'none', - }) + return this.buildLoopbackClientMetadata(url) } + return this.getCachedClientMetadata(clientId) + } + + private async buildLoopbackClientMetadata(url: URL): Promise { + return clientMetadataSchema.parse({ + client_name: 'Localhost', + redirect_uris: [url.toString()], + jwks: [], + token_endpoint_auth_method: 'none', + }) + } + + private async getCachedClientMetadata( + clientId: DidWeb, + ): Promise> { + const cached = await this.cache.fetch(clientId) + if (cached != null) return cached + + // Should never happen when "signal" is not used, let's provide a fallback + // anyway. + return this.loadClientMetadata(clientId) + } + + private async loadClientMetadata(clientId: DidWeb): Promise { return fetchDidDocument(clientId, this.fetch) .then( + // If service not found, allow fallback (didDocument) => extractService(didDocument, 'OAuthClientMetadata')?.serviceEndpoint, ) .catch((err) => { - // In case of 404, fallback to the well-known endpoint + // In case of 404, allow fallback if (err.status === 404) return undefined throw err }) .then(async (metadataEndpoint) => fetchClientMetadata( - metadataEndpoint - ? new URL(metadataEndpoint) - : new URL('/.well-known/oauth-client-metadata', url), + metadataEndpoint || + // Fallback to well-known endpoint + new URL( + '/.well-known/oauth-client-metadata', + didWebToUrl(clientId), + ), this.fetch, ), ) diff --git a/packages/oauth-server/src/client/client-metadata.ts b/packages/oauth-server/src/client/fetch-client-metadata.ts similarity index 70% rename from packages/oauth-server/src/client/client-metadata.ts rename to packages/oauth-server/src/client/fetch-client-metadata.ts index 95ee5432b80..f8a322fe923 100644 --- a/packages/oauth-server/src/client/client-metadata.ts +++ b/packages/oauth-server/src/client/fetch-client-metadata.ts @@ -1,4 +1,3 @@ -import createError from 'http-errors' import { Fetch } from '../util/fetch' import { fetchResponseZodHandler, @@ -11,7 +10,7 @@ import { import { ClientMetadata, clientMetadataValidator } from './types' export async function fetchClientMetadata( - metadataEndpoint: URL, + metadataEndpoint: string | URL, fetchFn: Fetch = fetch, ): Promise { return fetchFn(new Request(metadataEndpoint, { redirect: 'error' })) @@ -20,12 +19,3 @@ export async function fetchClientMetadata( .then(fetchResponseJsonHandler()) .then(fetchResponseZodHandler(clientMetadataValidator)) } - -export async function extractClientMetadataEndpoint({ - body, -}: { - body: unknown -}): Promise { - // TODO: implement this - throw createError(404, 'No client metadata endpoint found in DID Document') -} diff --git a/packages/oauth-server/src/create-router.ts b/packages/oauth-server/src/create-router.ts index bcb6945b93d..82a3c2653bb 100644 --- a/packages/oauth-server/src/create-router.ts +++ b/packages/oauth-server/src/create-router.ts @@ -1,20 +1,19 @@ import { RequestHandler, Router } from 'express' +import { ClientRegistry } from './client/client-registry' +import { DeviceSessionManager } from './device/device-session-manager' import { AuthorizationRequestStatus, OAuthAuthorizationServer, } from './oauth-authorization-server' -import { DeviceSessionManager } from './device/device-session-manager' -import { - queryAuthorizationRequestSchema, - pushedAuthorizationRequestSchema, - clientIdSchema, -} from './types' import { buildErrorPayload, buildErrorStatus, } from './output/build-error-payload' -import { buildRedirectUri } from './output/build-redirect-uri' -import { ClientRegistry } from './client/client-registry' +import { sendAuthorizationResponse } from './output/send-authorize-response' +import { + pushedAuthorizationRequestSchema, + queryAuthorizationRequestSchema, +} from './types' export function createMiddleware( authorizationServer: OAuthAuthorizationServer, @@ -23,7 +22,19 @@ export function createMiddleware( ): RequestHandler { const router = Router() - // For compat (TODO: comment) + /** + * Although OpenID compatibility is not required for this specification, we do + * support OIDC discovery as we believe this may: + * 1) Make the implementation of Atproto clients easier (since lots of + * libraries already support OIDC discovery) + * 2) Allow self hosted PDS' to implement their login flow using OIDC + * providers. By supporting OIDC, Bluesky's OAuth server can be used as an + * OIDC provider for these users. + * 3) During future developments, we may want to allow users to authenticate + * through OIDC providers instead of using their account username/password. + * If this happens, we could allow Bluesky users to host their data on + * Bluesky's servers, but authenticate through their own OIDC provider. + */ router.get('/.well-known/openid-configuration', (req, res) => { res.json( authorizationServer.getAuthorizationServerMetadata({ @@ -40,13 +51,24 @@ export function createMiddleware( res.json({ keys: authorizationServer.getJwks() }) }) - // PAR + /** + * @see {@link https://datatracker.ietf.org/doc/html/rfc9126} + */ router.post('/oauth/par', async (req, res) => { try { - const data = pushedAuthorizationRequestSchema.parse(req.body) + const data = pushedAuthorizationRequestSchema.parse(req.body, { + path: ['body'], + }) + + const clientMedata = await clientRegistry.getClientMetadata( + data.client_id, + ) + const result = await authorizationServer.createPushedAuthorizationRequest( + clientMedata, data, ) + res.json(result) } catch (err) { // TODO: log error ? @@ -56,51 +78,37 @@ export function createMiddleware( router.get('/oauth/authorize', async (req, res, next) => { try { - const clientId = clientIdSchema.parse(req.query?.client_id, { - path: ['query', 'client_id'], + const data = queryAuthorizationRequestSchema.parse(req.query, { + path: ['query'], }) - const clientMedata = await clientRegistry.getClientMetadata(clientId) - - const queryResult = queryAuthorizationRequestSchema.safeParse(req.query) - if (!queryResult.success) { - const { error } = queryResult - const uri = buildRedirectUri(clientMedata, req.query, error) - if (uri) return res.redirect(uri.href) - else throw error - } - - const { data } = queryResult - const clientMedata = await clientRegistry.getClientMetadata( - clientId, - 'registration' in data ? data.registration : undefined, + data.client_id, ) - try { - const { deviceId } = await deviceSessionManager.load(req, res) - - const [status, details] = - await authorizationServer.setupAuthorizationRequest(deviceId, data) - - switch (status) { - case AuthorizationRequestStatus.done: - return res.redirect(details.href) - case AuthorizationRequestStatus.loginRequired: - return res.redirect(`/oauth/login?state=${details}`) - case AuthorizationRequestStatus.consentRequired: - return res.redirect(`/oauth/consent?state=${details}`) - } - - throw new Error('Unreachable') - } catch (err) { - // TODO: log error ? - - const uri = buildRedirectUri(clientMedata, req.query, err) - if (uri) return res.redirect(uri.href) - else throw err + const { deviceId } = await deviceSessionManager.load(req, res) + + const [status, details] = + await authorizationServer.setupAuthorizationRequest( + clientMedata, + deviceId, + data, + ) + + switch (status) { + case AuthorizationRequestStatus.done: + sendAuthorizationResponse(req, res, details) + return + case AuthorizationRequestStatus.loginRequired: + return res.redirect(`/oauth/login?state=${details}`) + case AuthorizationRequestStatus.consentRequired: + return res.redirect(`/oauth/consent?state=${details}`) } + + throw new Error('Unreachable') } catch (err) { + // TODO: log error ? + next(err) } }) diff --git a/packages/oauth-server/src/oauth-authorization-server.ts b/packages/oauth-server/src/oauth-authorization-server.ts index cb2bd0a4b5e..9d5048334f1 100644 --- a/packages/oauth-server/src/oauth-authorization-server.ts +++ b/packages/oauth-server/src/oauth-authorization-server.ts @@ -1,6 +1,9 @@ +import { ClientMetadata } from './client/types' import { DeviceId } from './device/device-id' +import { AuthorizationResponseDetails } from './output/send-authorize-response' import { buildAuthorizationServerMetadata } from './server-metadata' import { OnlineAuthorizationRequest, PushedAuthorizationRequest } from './types' +import { DeepReadonly } from './util/object' export enum AuthorizationRequestStatus { // The authorization request is not yet ready to be processed. @@ -24,7 +27,13 @@ export class OAuthAuthorizationServer { throw new Error('Method not implemented.') } - async createPushedAuthorizationRequest(data: PushedAuthorizationRequest) { + /** + * @see {@link https://datatracker.ietf.org/doc/html/rfc9126} + */ + async createPushedAuthorizationRequest( + clientMetadata: DeepReadonly, + data: PushedAuthorizationRequest, + ) { data = { response_type: 'code', code_challenge: 'K2-ltc83acc4h0c9w6ESC_rEMTJ3bww-uCHaoeK1t8U', @@ -49,7 +58,7 @@ export class OAuthAuthorizationServer { request: JSON.stringify({ // JAR }), - client_id: 's6BhdRkqt3', + client_id: 'did:web:bsky.app:my-app', } return { @@ -60,10 +69,14 @@ export class OAuthAuthorizationServer { } async setupAuthorizationRequest( + clientMetadata: DeepReadonly, deviceId: DeviceId, data: OnlineAuthorizationRequest, ): Promise< - | [status: AuthorizationRequestStatus.done, redirectUri: URL] + | [ + status: AuthorizationRequestStatus.done, + details: AuthorizationResponseDetails, + ] | [ status: AuthorizationRequestStatus.consentRequired, authorizationId: string, diff --git a/packages/oauth-server/src/output/build-error-payload.ts b/packages/oauth-server/src/output/build-error-payload.ts index b06d49089fe..a9e3bde47e8 100644 --- a/packages/oauth-server/src/output/build-error-payload.ts +++ b/packages/oauth-server/src/output/build-error-payload.ts @@ -1,8 +1,9 @@ +import { isHttpError } from 'http-errors' import { ZodError } from 'zod' import { AuthorizationServerError } from '../errors/index' -export function buildErrorStatus(error: unknown) { +export function buildErrorStatus(error: unknown): number { if (error instanceof AuthorizationServerError) { return error.status } @@ -15,6 +16,10 @@ export function buildErrorStatus(error: unknown) { return 400 } + if (isHttpError(error)) { + return error.status + } + return 500 } @@ -40,6 +45,13 @@ export function buildErrorPayload(error: unknown) { } } + if (isHttpError(error)) { + return { + error: error.status < 500 ? 'invalid_request' : 'server_error', + error_description: error.expose ? error.message : 'Server error', + } + } + return { error: 'server_error', error_description: 'Server error', diff --git a/packages/oauth-server/src/output/build-redirect-uri.ts b/packages/oauth-server/src/output/build-redirect-uri.ts deleted file mode 100644 index 3210221f752..00000000000 --- a/packages/oauth-server/src/output/build-redirect-uri.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { ClientMetadata } from '../client/types' -import { buildErrorPayload } from './build-error-payload' - -export function buildRedirectUri( - clientMetadata: ClientMetadata | null, - unsafeParams?: { - redirect_uri?: unknown - state?: unknown - nonce?: unknown - }, - error?: unknown, -): URL | null { - if (typeof redirectUri !== 'string') return null - try { - const url = new URL(redirectUri) - - if (error) { - const payload = buildErrorPayload(error) - for (const key of ['error', 'error_description'] as const) { - url.searchParams.set(key, payload[key]) - } - } - - if (unsafeParams) { - for (const key of ['state', 'nonce'] as const) { - const value = unsafeParams[key] - if (typeof value === 'string') url.searchParams.set(key, value) - } - } - - return url - } catch (err) { - return null - } -} diff --git a/packages/oauth-server/src/output/send-authorize-response.ts b/packages/oauth-server/src/output/send-authorize-response.ts new file mode 100644 index 00000000000..19b456d108a --- /dev/null +++ b/packages/oauth-server/src/output/send-authorize-response.ts @@ -0,0 +1,86 @@ +import { IncomingMessage, ServerResponse } from 'node:http' +import { html } from '../util/html' + +export type AuthorizationResponseDetails = { + redirect_uri: string + response_mode: 'query' | 'fragment' | 'form_post' + response_data: AuthorizationResponseData +} + +export type AuthorizationResponseData = { + code?: string + state?: string + error?: string + error_description?: string +} + +export function sendAuthorizationResponse( + req: IncomingMessage, + res: ServerResponse, + details: AuthorizationResponseDetails, +) { + const redirectUri = new URL(details.redirect_uri) + + if (details.response_mode === 'query') { + populateQueryString(redirectUri.searchParams, details.response_data) + res.writeHead(302, { Location: redirectUri.href }).end() + return + } + + if (details.response_mode === 'fragment') { + const searchParams = new URLSearchParams() + populateQueryString(searchParams, details.response_data) + redirectUri.hash = searchParams.toString() + res.writeHead(302, { Location: redirectUri.href }).end() + return + } + + if (details.response_mode === 'form_post') { + const body = html` + + +
+ ${Object.entries(details.response_data).map( + ([key, value]) => html` + + `, + )} + +
+ + + + `.toString() + + res + .writeHead(200, { + 'Content-Type': 'text/html', + 'Content-Length': Buffer.byteLength(body), + }) + .end(body) + return + } + + throw new Error('Unreachable') +} + +export function buildRedirectUri( + redirectUri: string | URL, + responseData: AuthorizationResponseData, +): URL { + const url = new URL(redirectUri) + populateQueryString(url.searchParams, responseData) + return url +} + +function populateQueryString( + params: URLSearchParams, + responseData: AuthorizationResponseData, +): void { + for (const key of ['state', 'nonce', 'error', 'error_description'] as const) { + const value = responseData[key] + if (typeof value === 'string') params.set(key, value) + } +} diff --git a/packages/oauth-server/src/types.ts b/packages/oauth-server/src/types.ts index da87b09b79d..4c03fdf9241 100644 --- a/packages/oauth-server/src/types.ts +++ b/packages/oauth-server/src/types.ts @@ -110,6 +110,7 @@ export const authorizationRequestSchema = z.object({ export type AuthorizationRequest = z.infer export const authorizationRequestObjectSchema = z.object({ + client_id: clientIdSchema, /** * AuthorizationRequest inside a JWT: * - "nonce" is REQUIRED @@ -118,6 +119,11 @@ export const authorizationRequestObjectSchema = z.object({ request: z.string().regex(/^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$/), }) +export const authorizationRequestPushedSchema = z.object({ + client_id: clientIdSchema, + request_uri: z.string(), +}) + export type AuthorizationRequestObject = z.infer< typeof authorizationRequestObjectSchema > @@ -139,7 +145,7 @@ export type PushedAuthorizationRequestResponse = { export const queryAuthorizationRequestSchema = z.union([ authorizationRequestSchema.strict(), authorizationRequestObjectSchema.strict(), - z.object({ client_id: z.string(), request_uri: z.string() }).strict(), + authorizationRequestPushedSchema.strict(), ]) export type OnlineAuthorizationRequest = z.infer< diff --git a/packages/oauth-server/src/util/did-web.ts b/packages/oauth-server/src/util/did-web.ts index 68ce33239eb..abb19982e33 100644 --- a/packages/oauth-server/src/util/did-web.ts +++ b/packages/oauth-server/src/util/did-web.ts @@ -8,7 +8,7 @@ import { fetchResponseZodHandler, fetchSuccessHandler, } from './fetch-handlers' -import { Did, didDocumentSchema, didSchema } from './did' +import { didDocumentSchema, didSchema } from './did' export const DID_WEB = `did:web:` export type DidWeb = `did:web:${string}` @@ -18,24 +18,20 @@ export const didWebSchema = didSchema.refinement(isWebDid, { message: 'Invalid Web DID', }) -export function isWebDid(did: Did): did is DidWeb { - // Fast check - if (!did.startsWith(DID_WEB)) return false - +export function isWebDid(did: unknown): did is DidWeb { try { - didWebToUrl(did) - return true + return typeof did === 'string' && didWebToUrl(did) != null } catch { return false } } -export function didWebToUrl(did: Did): URL { +export function didWebToUrl(did: string): URL { if (!did.startsWith(DID_WEB)) { throw new TypeError(`Not a Web DID`) } const suffix = did.slice(DID_WEB.length) - if (!suffix || suffix.startsWith(':')) { + if (!suffix || suffix.startsWith(':') || suffix.endsWith(':')) { throw new TypeError(`Invalid Web DID`) } const parts = suffix.split(':').map(decodeURIComponent) diff --git a/packages/oauth-server/src/util/fetch.ts b/packages/oauth-server/src/util/fetch.ts index 1c9ec2b83ab..05bd6ffc8d0 100644 --- a/packages/oauth-server/src/util/fetch.ts +++ b/packages/oauth-server/src/util/fetch.ts @@ -4,10 +4,8 @@ export type Fetch = (input: RequestInfo | URL) => Promise export function fetchFactory( baseFetch: Fetch = fetch, - requestTransformers?: Iterable, + requestTransformers: Iterable, ): Fetch { - if (!requestTransformers) return baseFetch - - const transformer = combineTransformers(requestTransformers) - return async (input) => baseFetch(await transformer(input)) + const requestTransformer = combineTransformers(requestTransformers) + return async (input) => baseFetch(await requestTransformer(input)) } diff --git a/packages/oauth-server/src/util/html.ts b/packages/oauth-server/src/util/html.ts new file mode 100644 index 00000000000..535e7b8981d --- /dev/null +++ b/packages/oauth-server/src/util/html.ts @@ -0,0 +1,32 @@ +const charToHtmlEntity = (i) => '&#' + i.charCodeAt(0) + ';' + +function encodeHtmlEntities(value: string) { + return value.replace(/[\u00A0-\u9999<>&"'= ]/g, charToHtmlEntity) +} + +const asSafeHtml = (value: string) => ({ + __safe: true, + valueOf: () => value, + toString: () => value, +}) + +type SafeHtml = ReturnType + +export function html( + strings: TemplateStringsArray, + ...values: Nested[] +) { + return asSafeHtml( + strings.reduce((acc, str, i) => acc + str + toEncodedHtml(values[i]), ''), + ) +} + +type Nested = V | Array> + +function toEncodedHtml(value: Nested): string { + return Array.isArray(value) + ? value.map(toEncodedHtml).join('') + : typeof value === 'string' + ? encodeHtmlEntities(value) + : value.toString() +} diff --git a/packages/oauth-server/src/util/object.ts b/packages/oauth-server/src/util/object.ts new file mode 100644 index 00000000000..f8c87c241c7 --- /dev/null +++ b/packages/oauth-server/src/util/object.ts @@ -0,0 +1,20 @@ +export type DeepReadonly = T extends [infer A, ...infer B] + ? readonly [DeepReadonly, ...DeepReadonly] + : T extends Array + ? ReadonlyArray> + : T extends object + ? { readonly [P in keyof T]: DeepReadonly } + : T + +export function deepFreeze(input: T): DeepReadonly { + if (input != null && typeof input === 'object') { + Object.freeze(input) + for (const prop of Object.getOwnPropertyNames(input)) { + const value = input[prop] + if (value != null && typeof value === 'object') { + deepFreeze(value) + } + } + } + return input as DeepReadonly +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 827e4f274eb..ea620937e99 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -490,6 +490,9 @@ importers: kysely: specifier: ^0.22.0 version: 0.22.0 + lru-cache: + specifier: ^10.1.0 + version: 10.1.0 zod: specifier: ^3.21.4 version: 3.21.4 @@ -9145,6 +9148,11 @@ packages: resolution: {integrity: sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==} dev: false + /lru-cache@10.1.0: + resolution: {integrity: sha512-/1clY/ui8CzjKFyjdvwPWJUYKiFVXG2I2cY0ssG7h4+hwk+XOIX7ZSG9Q7TW8TW3Kp3BUSqgFWBLgL4PJ+Blag==} + engines: {node: 14 || >=16.14} + dev: false + /lru-cache@4.1.5: resolution: {integrity: sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==} dependencies: