Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieusieben committed Dec 14, 2023
1 parent e3d6949 commit 2735185
Show file tree
Hide file tree
Showing 14 changed files with 316 additions and 129 deletions.
1 change: 1 addition & 0 deletions packages/oauth-server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"http-errors": "^2.0.0",
"ipaddr.js": "^2.1.0",
"kysely": "^0.22.0",
"lru-cache": "^10.1.0",
"zod": "^3.21.4"
},
"devDependencies": {
Expand Down
82 changes: 67 additions & 15 deletions packages/oauth-server/src/client/client-registry.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import { LRUCache } from 'lru-cache'

import { extractService } from '../util/did'
import { DidWeb, didWebToUrl, fetchDidDocument } from '../util/did-web'
import { Fetch, fetchFactory } from '../util/fetch'
import {
Expand All @@ -6,48 +9,97 @@ import {
} from '../util/fetch-transform'
import { isLoopbackHostname } from '../util/net'

import { extractService } from '../util/did'
import { fetchClientMetadata } from './client-metadata'
import { fetchClientMetadata } from './fetch-client-metadata'
import { ClientMetadata, clientMetadataSchema } from './types'
import { DeepReadonly, deepFreeze } from '../util/object'

export class ClientRegistry {
protected readonly fetch: Fetch
protected readonly cache = new LRUCache<DidWeb, DeepReadonly<ClientMetadata>>(
{
max: 1000,
ttl: 60 * 60 * 1000, // 1 hour
updateAgeOnGet: false,
updateAgeOnHas: false,
allowStaleOnFetchRejection: true,
ignoreFetchAbort: true,
fetchMethod: async (clientId) =>
this.loadClientMetadata(clientId).then((v) => deepFreeze(v)),
},
)

constructor(fetch?: Fetch) {
this.fetch = fetchFactory(fetch, [
/**
* Since we will be fetching from the network based on user provided
* input, we need to make sure that the request is not vulnerable to SSRF
* attacks.
*/
ssrfSafeRequestTransform(),
/**
* Disallow fetching from domains we know are not atproto client
* implementation.
*/
forbiddenDomainNameRequestTransform(['bsky.social', 'bsky.network']),
])
}

async getClientMetadata(clientId: DidWeb): Promise<ClientMetadata> {
public async getClientMetadata(
clientId: DidWeb,
): Promise<DeepReadonly<ClientMetadata>> {
// Allow loopback URLs to be used as client ID. Since we cannot fetch from
// them, we will use a pre-defined metadata for all clients using loopback
// URLs as client ID. These clients will be able to use the "registration"
// autorization parameter to provide additional custom metadata. Note that
// the UI should clearly indicate that the client is using a loopback URL.
const url = didWebToUrl(clientId)

// Allow localhost
if (isLoopbackHostname(url.hostname)) {
return clientMetadataSchema.parse({
client_name: 'Localhost',
redirect_uris: [url.toString()],
jwks: [],
token_endpoint_auth_method: 'none',
})
return this.buildLoopbackClientMetadata(url)
}

return this.getCachedClientMetadata(clientId)
}

private async buildLoopbackClientMetadata(url: URL): Promise<ClientMetadata> {
return clientMetadataSchema.parse({
client_name: 'Localhost',
redirect_uris: [url.toString()],
jwks: [],
token_endpoint_auth_method: 'none',
})
}

private async getCachedClientMetadata(
clientId: DidWeb,
): Promise<DeepReadonly<ClientMetadata>> {
const cached = await this.cache.fetch(clientId)
if (cached != null) return cached

// Should never happen when "signal" is not used, let's provide a fallback
// anyway.
return this.loadClientMetadata(clientId)
}

private async loadClientMetadata(clientId: DidWeb): Promise<ClientMetadata> {
return fetchDidDocument(clientId, this.fetch)
.then(
// If service not found, allow fallback
(didDocument) =>
extractService(didDocument, 'OAuthClientMetadata')?.serviceEndpoint,
)
.catch((err) => {
// In case of 404, fallback to the well-known endpoint
// In case of 404, allow fallback
if (err.status === 404) return undefined
throw err
})
.then(async (metadataEndpoint) =>
fetchClientMetadata(
metadataEndpoint
? new URL(metadataEndpoint)
: new URL('/.well-known/oauth-client-metadata', url),
metadataEndpoint ||
// Fallback to well-known endpoint
new URL(
'/.well-known/oauth-client-metadata',
didWebToUrl(clientId),
),
this.fetch,
),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import createError from 'http-errors'
import { Fetch } from '../util/fetch'
import {
fetchResponseZodHandler,
Expand All @@ -11,7 +10,7 @@ import {
import { ClientMetadata, clientMetadataValidator } from './types'

export async function fetchClientMetadata(
metadataEndpoint: URL,
metadataEndpoint: string | URL,
fetchFn: Fetch = fetch,
): Promise<ClientMetadata> {
return fetchFn(new Request(metadataEndpoint, { redirect: 'error' }))
Expand All @@ -20,12 +19,3 @@ export async function fetchClientMetadata(
.then(fetchResponseJsonHandler())
.then(fetchResponseZodHandler(clientMetadataValidator))
}

export async function extractClientMetadataEndpoint({
body,
}: {
body: unknown
}): Promise<URL> {
// TODO: implement this
throw createError(404, 'No client metadata endpoint found in DID Document')
}
106 changes: 57 additions & 49 deletions packages/oauth-server/src/create-router.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import { RequestHandler, Router } from 'express'
import { ClientRegistry } from './client/client-registry'
import { DeviceSessionManager } from './device/device-session-manager'
import {
AuthorizationRequestStatus,
OAuthAuthorizationServer,
} from './oauth-authorization-server'
import { DeviceSessionManager } from './device/device-session-manager'
import {
queryAuthorizationRequestSchema,
pushedAuthorizationRequestSchema,
clientIdSchema,
} from './types'
import {
buildErrorPayload,
buildErrorStatus,
} from './output/build-error-payload'
import { buildRedirectUri } from './output/build-redirect-uri'
import { ClientRegistry } from './client/client-registry'
import { sendAuthorizationResponse } from './output/send-authorize-response'
import {
pushedAuthorizationRequestSchema,
queryAuthorizationRequestSchema,
} from './types'

export function createMiddleware(
authorizationServer: OAuthAuthorizationServer,
Expand All @@ -23,7 +22,19 @@ export function createMiddleware(
): RequestHandler {
const router = Router()

// For compat (TODO: comment)
/**
* Although OpenID compatibility is not required for this specification, we do
* support OIDC discovery as we believe this may:
* 1) Make the implementation of Atproto clients easier (since lots of
* libraries already support OIDC discovery)
* 2) Allow self hosted PDS' to implement their login flow using OIDC
* providers. By supporting OIDC, Bluesky's OAuth server can be used as an
* OIDC provider for these users.
* 3) During future developments, we may want to allow users to authenticate
* through OIDC providers instead of using their account username/password.
* If this happens, we could allow Bluesky users to host their data on
* Bluesky's servers, but authenticate through their own OIDC provider.
*/
router.get('/.well-known/openid-configuration', (req, res) => {
res.json(
authorizationServer.getAuthorizationServerMetadata({
Expand All @@ -40,13 +51,24 @@ export function createMiddleware(
res.json({ keys: authorizationServer.getJwks() })
})

// PAR
/**
* @see {@link https://datatracker.ietf.org/doc/html/rfc9126}
*/
router.post('/oauth/par', async (req, res) => {
try {
const data = pushedAuthorizationRequestSchema.parse(req.body)
const data = pushedAuthorizationRequestSchema.parse(req.body, {
path: ['body'],
})

const clientMedata = await clientRegistry.getClientMetadata(
data.client_id,
)

const result = await authorizationServer.createPushedAuthorizationRequest(
clientMedata,
data,
)

res.json(result)
} catch (err) {
// TODO: log error ?
Expand All @@ -56,51 +78,37 @@ export function createMiddleware(

router.get('/oauth/authorize', async (req, res, next) => {
try {
const clientId = clientIdSchema.parse(req.query?.client_id, {
path: ['query', 'client_id'],
const data = queryAuthorizationRequestSchema.parse(req.query, {
path: ['query'],
})

const clientMedata = await clientRegistry.getClientMetadata(clientId)

const queryResult = queryAuthorizationRequestSchema.safeParse(req.query)
if (!queryResult.success) {
const { error } = queryResult
const uri = buildRedirectUri(clientMedata, req.query, error)
if (uri) return res.redirect(uri.href)
else throw error
}

const { data } = queryResult

const clientMedata = await clientRegistry.getClientMetadata(
clientId,
'registration' in data ? data.registration : undefined,
data.client_id,
)

try {
const { deviceId } = await deviceSessionManager.load(req, res)

const [status, details] =
await authorizationServer.setupAuthorizationRequest(deviceId, data)

switch (status) {
case AuthorizationRequestStatus.done:
return res.redirect(details.href)
case AuthorizationRequestStatus.loginRequired:
return res.redirect(`/oauth/login?state=${details}`)
case AuthorizationRequestStatus.consentRequired:
return res.redirect(`/oauth/consent?state=${details}`)
}

throw new Error('Unreachable')
} catch (err) {
// TODO: log error ?

const uri = buildRedirectUri(clientMedata, req.query, err)
if (uri) return res.redirect(uri.href)
else throw err
const { deviceId } = await deviceSessionManager.load(req, res)

const [status, details] =
await authorizationServer.setupAuthorizationRequest(
clientMedata,
deviceId,
data,
)

switch (status) {
case AuthorizationRequestStatus.done:
sendAuthorizationResponse(req, res, details)
return
case AuthorizationRequestStatus.loginRequired:
return res.redirect(`/oauth/login?state=${details}`)
case AuthorizationRequestStatus.consentRequired:
return res.redirect(`/oauth/consent?state=${details}`)
}

throw new Error('Unreachable')
} catch (err) {
// TODO: log error ?

next(err)
}
})
Expand Down
19 changes: 16 additions & 3 deletions packages/oauth-server/src/oauth-authorization-server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { ClientMetadata } from './client/types'
import { DeviceId } from './device/device-id'
import { AuthorizationResponseDetails } from './output/send-authorize-response'
import { buildAuthorizationServerMetadata } from './server-metadata'
import { OnlineAuthorizationRequest, PushedAuthorizationRequest } from './types'
import { DeepReadonly } from './util/object'

export enum AuthorizationRequestStatus {
// The authorization request is not yet ready to be processed.
Expand All @@ -24,7 +27,13 @@ export class OAuthAuthorizationServer {
throw new Error('Method not implemented.')
}

async createPushedAuthorizationRequest(data: PushedAuthorizationRequest) {
/**
* @see {@link https://datatracker.ietf.org/doc/html/rfc9126}
*/
async createPushedAuthorizationRequest(
clientMetadata: DeepReadonly<ClientMetadata>,
data: PushedAuthorizationRequest,
) {
data = {
response_type: 'code',
code_challenge: 'K2-ltc83acc4h0c9w6ESC_rEMTJ3bww-uCHaoeK1t8U',
Expand All @@ -49,7 +58,7 @@ export class OAuthAuthorizationServer {
request: JSON.stringify({
// JAR
}),
client_id: 's6BhdRkqt3',
client_id: 'did:web:bsky.app:my-app',
}

return {
Expand All @@ -60,10 +69,14 @@ export class OAuthAuthorizationServer {
}

async setupAuthorizationRequest(
clientMetadata: DeepReadonly<ClientMetadata>,
deviceId: DeviceId,
data: OnlineAuthorizationRequest,
): Promise<
| [status: AuthorizationRequestStatus.done, redirectUri: URL]
| [
status: AuthorizationRequestStatus.done,
details: AuthorizationResponseDetails,
]
| [
status: AuthorizationRequestStatus.consentRequired,
authorizationId: string,
Expand Down
14 changes: 13 additions & 1 deletion packages/oauth-server/src/output/build-error-payload.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { isHttpError } from 'http-errors'
import { ZodError } from 'zod'

import { AuthorizationServerError } from '../errors/index'

export function buildErrorStatus(error: unknown) {
export function buildErrorStatus(error: unknown): number {
if (error instanceof AuthorizationServerError) {
return error.status
}
Expand All @@ -15,6 +16,10 @@ export function buildErrorStatus(error: unknown) {
return 400
}

if (isHttpError(error)) {
return error.status
}

return 500
}

Expand All @@ -40,6 +45,13 @@ export function buildErrorPayload(error: unknown) {
}
}

if (isHttpError(error)) {
return {
error: error.status < 500 ? 'invalid_request' : 'server_error',
error_description: error.expose ? error.message : 'Server error',
}
}

return {
error: 'server_error',
error_description: 'Server error',
Expand Down
Loading

0 comments on commit 2735185

Please sign in to comment.