Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieusieben committed Dec 15, 2023
1 parent a9c13aa commit 1d7f7f6
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 140 deletions.
75 changes: 50 additions & 25 deletions packages/oauth-server/src/authorization/authorization-server.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import { Client } from '../client/client'
import { ClientCredentials } from '../client/types'
import { DeviceId } from '../device/device-id'
import { AuthorizationStore, RequestId } from './authorization-store'
import { buildAuthorizationServerMetadata } from './build-authorization-server-metadata'
import {
AuthorizationRequest,
AuthorizationRequestJar,
AuthorizationRequestQuery,
AuthorizationResponseDetails,
OnlineAuthorizationRequest,
PushedAuthorizationRequest,
RequestUri,
RequestUriPrefix,
authorizationRequestSchema,
} from './types'

Expand All @@ -22,7 +26,7 @@ export enum AuthorizationRequestStatus {
}

export class AuthorizationServer {
constructor(readonly origin: URL) {}
constructor(readonly origin: URL, readonly store: AuthorizationStore) {}

get authorizationServerMetadata() {
return buildAuthorizationServerMetadata(this.origin)
Expand All @@ -45,19 +49,15 @@ export class AuthorizationServer {
}
}

async getAuthorizationRequest(
async parseAuthorizationRequest(
client: Client,
data: PushedAuthorizationRequest | OnlineAuthorizationRequest,
data: AuthorizationRequestJar | AuthorizationRequest,
): Promise<AuthorizationRequest> {
if ('request_uri' in data) {
throw new Error('TODO: implement this')
}

// JAR
if ('request' in data) {
const { payload } = await client.jwtVerify(data.request, {
maxTokenAge: 60,
})

return authorizationRequestSchema.parse(payload)
}

Expand All @@ -73,51 +73,62 @@ export class AuthorizationServer {
) {
await this.verifyClientCredentials(client, data)

const request = await this.getAuthorizationRequest(client, data)

// TODO: Store request in store and return a request_uri
console.log(request)
const request = await this.parseAuthorizationRequest(client, data)
const id = await this.store.createAuthorizationRequest(client, request)

return {
request_uri:
'urn:ietf:params:oauth:request_uri:6esc_11ACC5bwc014ltc14eY22c',
request_uri: encodeRequestUri(id),
expires_in: 60,
}
}

async setupAuthorizationRequest(
client: Client,
deviceId: DeviceId,
data: OnlineAuthorizationRequest,
data: AuthorizationRequestQuery,
): Promise<
| [
status: AuthorizationRequestStatus.done,
details: AuthorizationResponseDetails,
]
| [
status: AuthorizationRequestStatus.consentRequired,
status: AuthorizationRequestStatus.loginRequired,
authorizationId: string,
]
| [
status: AuthorizationRequestStatus.loginRequired,
status: AuthorizationRequestStatus.consentRequired,
authorizationId: string,
]
> {
const request = await this.getAuthorizationRequest(client, data)
let request: AuthorizationRequest
let requestId: RequestId

// TODO: implement this
if ('request_uri' in data) {
requestId = decodeRequestUri(data.request_uri)
request = await this.store.getAuthorizationRequest(
client,
requestId,
deviceId,
)
} else {
request = await this.parseAuthorizationRequest(client, data)
requestId = await this.store.createAuthorizationRequest(
client,
request,
deviceId,
)
}

console.log(data)
// First, get list of current sessions
const sessions = await this.store.listActiveSessions(deviceId)

if (data !== null) {
return [AuthorizationRequestStatus.loginRequired, '']
}
// TODO: implement this

return [
AuthorizationRequestStatus.done,
{
redirect_uri: request.redirect_uri,
response_mode: 'query',
response_mode: request.response_mode || 'query',
response_data: {
code: 'K2-ltc83acc4h0c9w6ESC_rEMTJ3bww-uCHaoeK1t8U',
state: request.state,
Expand All @@ -126,3 +137,17 @@ export class AuthorizationServer {
]
}
}

function encodeRequestUri(requestId: RequestId): RequestUri {
return `${RequestUriPrefix}${encodeURIComponent(requestId)}`
}

function decodeRequestUri(requestUri: RequestUri) {
// Foolproofing
if (!requestUri.startsWith(RequestUriPrefix)) {
throw new TypeError('Invalid request_uri')
}
return decodeURIComponent(
requestUri.slice(RequestUriPrefix.length),
) as RequestId
}
50 changes: 50 additions & 0 deletions packages/oauth-server/src/authorization/authorization-store.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { Client } from '../client/client'
import { DeviceId } from '../device/device-id'
import { AuthorizationRequest } from './types'

export type RequestId = string & { __RequestId: true }

export interface AuthorizationStore {
// TODO Validate request agains client metadata !!!!
createAuthorizationRequest(
client: Client,
request: AuthorizationRequest,
deviceId?: DeviceId,
): Promise<RequestId>

// TODO: Add validation logic !!!
// try {
// // TODO : better error
// if (result.deviceId && result.deviceId !== deviceId) {
// throw new TypeError('Invalid request_uri')
// }
// // TODO : better error
// if (result.clientId !== client.id) {
// throw new TypeError('Invalid request_uri')
// }
// } catch (err) {
// await this.store.deleteAuthorizationRequest(requestId)
// throw err
// }
getAuthorizationRequest(
client: Client,
requestId: RequestId,
deviceId: DeviceId,
): Promise<AuthorizationRequest>

updateAuthorizationRequest(
requestId: RequestId,
data: {
request?: AuthorizationRequest
deviceId?: DeviceId
},
): Promise<void>

deleteAuthorizationRequest(requestId: RequestId): Promise<void>

listActiveSessions(deviceId: DeviceId): Promise<{
sessionId: string
deviceId: DeviceId
accountId: string // + Data ?
}>
}
39 changes: 30 additions & 9 deletions packages/oauth-server/src/authorization/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export const authorizationRequestSchema = z.object({

export type AuthorizationRequest = z.infer<typeof authorizationRequestSchema>

export const authorizationRequestObjectSchema = z.object({
export const authorizationRequestJarSchema = z.object({
/**
* AuthorizationRequest inside a JWT:
* - "iat" is required and **MUST** be less than one minute
Expand All @@ -102,17 +102,38 @@ export const authorizationRequestObjectSchema = z.object({
request: z.string().regex(/^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$/),
})

export type AuthorizationRequestJar = z.infer<
typeof authorizationRequestJarSchema
>

export const RequestUriPrefix = 'urn:ietf:params:oauth:request_uri:'

export const requestUriSchema = z
.string()
.url()
.refinement(
(data): data is `${typeof RequestUriPrefix}${string}` =>
data.startsWith(RequestUriPrefix) &&
data.length > RequestUriPrefix.length,
{
code: z.ZodIssueCode.custom,
message: 'Invalid request_uri',
},
)

export type RequestUri = z.infer<typeof requestUriSchema>

export const authorizationRequestUriSchema = z.object({
request_uri: z.string(),
request_uri: requestUriSchema,
})

export type AuthorizationRequestObject = z.infer<
typeof authorizationRequestObjectSchema
export type AuthorizationRequestUri = z.infer<
typeof authorizationRequestUriSchema
>

export const pushedAuthorizationRequestSchema = z.intersection(
clientAssertionSchema,
z.union([authorizationRequestSchema, authorizationRequestObjectSchema]),
z.union([authorizationRequestSchema, authorizationRequestJarSchema]),
)

export type PushedAuthorizationRequest = z.infer<
Expand All @@ -124,19 +145,19 @@ export type PushedAuthorizationRequestResponse = {
expires_in: number
}

export const queryAuthorizationRequestSchema = z.intersection(
export const authorizationRequestQuerySchema = z.intersection(
z.object({
client_id: clientIdSchema,
}),
z.union([
authorizationRequestSchema,
authorizationRequestObjectSchema,
authorizationRequestJarSchema,
authorizationRequestUriSchema,
]),
)

export type OnlineAuthorizationRequest = z.infer<
typeof queryAuthorizationRequestSchema
export type AuthorizationRequestQuery = z.infer<
typeof authorizationRequestQuerySchema
>

export type AuthorizationResponseDetails = {
Expand Down
91 changes: 15 additions & 76 deletions packages/oauth-server/src/client/client-registry.ts
Original file line number Diff line number Diff line change
@@ -1,84 +1,23 @@
import { LRUCache } from 'lru-cache'

import { DidWeb } from '../util/did-web'
import { Fetch } from '../util/fetch'
import { Client } from './client'
import { ClientStore } from './client-store'
import {
forbiddenDomainNameRequestTransform,
ssrfSafeRequestTransform,
} from '../util/fetch-request'

import { fetchMaxSizeProcessor } from '../util/fetch-response'
import { Client, ClientConfig } from './client'
import { combine } from '../util/transformer'

export type ClientRegistryConfig = NonNullable<
Parameters<typeof ClientRegistry.fromConfig>[0]
>
ClientStoreMemory,
ClientStoreMemoryConfig,
} from './client-store-memory'
import { ClientId } from './types'

export class ClientRegistry {
static fromConfig({
fetch: fetchFn = global.fetch as Fetch,
cacheTtl = 60 * 60 * 1000, // 1 hour
cacheMaxSize = 50 * 1024 * 1024, // 50MB
ssrfProtection = true,
maxResponseSize = 512 * 1024, // 512kB
forbiddenDomainNames = ['bsky.social', 'bsky.network'],
} = {}) {
const fetch: Fetch = combine(
/**
* Since we will be fetching from the network based on user provided
* input, we need to make sure that the request is not vulnerable to SSRF
* attacks.
*/
ssrfSafeRequestTransform(() => !ssrfProtection),
/**
* Disallow fetching from domains we know are not atproto client
* implementation.
*/
forbiddenDomainNameRequestTransform(forbiddenDomainNames),

// Wrap the fetch function to add some extra features
fetchFn,

/**
* Since we will be fetching user owned data, we need to make sure that
* an attacker cannot force us to download a large amounts of data.
*/
fetchMaxSizeProcessor(maxResponseSize),
)

const cache = new LRUCache<DidWeb, Client, ClientRegistry>({
ttl: cacheTtl,
maxSize: cacheMaxSize,
sizeCalculation: (client) => client.memoryUsage,
updateAgeOnGet: false,
updateAgeOnHas: false,
allowStaleOnFetchRejection: true,
ignoreFetchAbort: true,
fetchMethod: async (clientId, _, opts) => opts.context.load(clientId),
})

return new ClientRegistry(fetch, cache)
static memory(config?: ClientStoreMemoryConfig) {
const store = new ClientStoreMemory(config)
return new ClientRegistry(store)
}

constructor(
readonly fetch: Fetch,
readonly cache: LRUCache<DidWeb, Client, ClientRegistry>,
) {}

public async get(clientId: DidWeb): Promise<Client> {
// Since loopback clients will not cause any traffic, we won't cache them.
// This will allow to reserve the cache for clients that require HTTP
// requests to be fetched.
if (!Client.isLoopbackClient(clientId)) {
const cached = await this.cache.fetch(clientId, { context: this })
if (cached != null) return cached
}

return this.load(clientId)
}
constructor(protected readonly store: ClientStore) {}

public async load(clientId: DidWeb): Promise<Client> {
return Client.fromId(clientId, { fetch: this.fetch })
public async fetch(clientId: ClientId): Promise<Client> {
// We don't want to store loopback clients
return Client.isLoopbackClientId(clientId)
? Client.forLoopback(clientId)
: this.store.fetch(clientId)
}
}
Loading

0 comments on commit 1d7f7f6

Please sign in to comment.