diff --git a/packages/pds/package.json b/packages/pds/package.json index e3fef718ecd..04fceffb5a3 100644 --- a/packages/pds/package.json +++ b/packages/pds/package.json @@ -1,6 +1,6 @@ { "name": "@atproto/pds", - "version": "0.3.0-entryway.0", + "version": "0.3.0-entryway.1", "license": "MIT", "description": "Reference implementation of atproto Personal Data Server (PDS)", "keywords": [ diff --git a/packages/pds/src/api/com/atproto/admin/updateAccountHandle.ts b/packages/pds/src/api/com/atproto/admin/updateAccountHandle.ts index 368c2dae586..10441d08ee7 100644 --- a/packages/pds/src/api/com/atproto/admin/updateAccountHandle.ts +++ b/packages/pds/src/api/com/atproto/admin/updateAccountHandle.ts @@ -1,4 +1,8 @@ -import { AuthRequiredError, InvalidRequestError } from '@atproto/xrpc-server' +import { + AuthRequiredError, + InvalidRequestError, + UpstreamFailureError, +} from '@atproto/xrpc-server' import { normalizeAndValidateHandle } from '../../../../handle' import { Server } from '../../../../lexicon' import AppContext from '../../../../context' @@ -7,6 +11,7 @@ import { UserAlreadyExistsError, } from '../../../../services/account' import { httpLogger } from '../../../../logger' +import { isThisPds } from '../../../proxy' export default function (server: Server, ctx: AppContext) { server.com.atproto.admin.updateAccountHandle({ @@ -47,6 +52,30 @@ export default function (server: Server, ctx: AppContext) { }) } + const { pdsDid } = existingAccnt + if (ctx.cfg.service.isEntryway && !isThisPds(ctx, pdsDid)) { + const pds = + pdsDid && + (await ctx.services.account(ctx.db).getPds(pdsDid, { cached: true })) + if (!pds) { + throw new UpstreamFailureError('unknown pds') + } + // the pds emits the handle event on the firehose, but the entryway is responsible for updating the did doc. + // the long flow is: pds(identity.updateHandle) -> entryway(identity.updateHandle) -> pds(admin.updateAccountHandle) + const agent = ctx.pdsAgents.get(pds.host) + await agent.com.atproto.admin.updateAccountHandle( + { + did, + handle: input.body.handle, + }, + { + encoding: 'application/json', + headers: ctx.authVerifier.createAdminRoleHeaders(), + }, + ) + return // do not sequence handle event on the entryway + } + try { await ctx.db.transaction(async (dbTxn) => { await ctx.services.account(dbTxn).sequenceHandle(seqHandleTok) diff --git a/packages/pds/src/api/com/atproto/identity/updateHandle.ts b/packages/pds/src/api/com/atproto/identity/updateHandle.ts index 44a2aaded72..61c7b79dd14 100644 --- a/packages/pds/src/api/com/atproto/identity/updateHandle.ts +++ b/packages/pds/src/api/com/atproto/identity/updateHandle.ts @@ -1,4 +1,5 @@ -import { InvalidRequestError } from '@atproto/xrpc-server' +import { InvalidRequestError, UpstreamFailureError } from '@atproto/xrpc-server' +import { DAY, MINUTE } from '@atproto/common' import { normalizeAndValidateHandle } from '../../../../handle' import { Server } from '../../../../lexicon' import AppContext from '../../../../context' @@ -7,7 +8,7 @@ import { UserAlreadyExistsError, } from '../../../../services/account' import { httpLogger } from '../../../../logger' -import { DAY, MINUTE } from '@atproto/common' +import { isThisPds } from '../../../proxy' export default function (server: Server, ctx: AppContext) { server.com.atproto.identity.updateHandle({ @@ -26,6 +27,7 @@ export default function (server: Server, ctx: AppContext) { ], handler: async ({ auth, input }) => { const requester = auth.credentials.did + const pdsDid = auth.credentials.pdsDid const handle = await normalizeAndValidateHandle({ ctx, handle: input.body.handle, @@ -63,6 +65,29 @@ export default function (server: Server, ctx: AppContext) { }) } + if (ctx.cfg.service.isEntryway && !isThisPds(ctx, pdsDid)) { + const pds = + pdsDid && + (await ctx.services.account(ctx.db).getPds(pdsDid, { cached: true })) + if (!pds) { + throw new UpstreamFailureError('unknown pds') + } + // the pds emits the handle event on the firehose, but the entryway is responsible for updating the did doc. + // the long flow is: pds(identity.updateHandle) -> entryway(identity.updateHandle) -> pds(admin.updateAccountHandle) + const agent = ctx.pdsAgents.get(pds.host) + await agent.com.atproto.admin.updateAccountHandle( + { + did: requester, + handle: input.body.handle, + }, + { + encoding: 'application/json', + headers: ctx.authVerifier.createAdminRoleHeaders(), + }, + ) + return // do not sequence handle event on the entryway + } + try { await ctx.db.transaction(async (dbTxn) => { await ctx.services.account(dbTxn).sequenceHandle(seqHandleTok) diff --git a/packages/pds/src/api/com/atproto/server/createAccount.ts b/packages/pds/src/api/com/atproto/server/createAccount.ts index 407fc683ec2..b11020ad876 100644 --- a/packages/pds/src/api/com/atproto/server/createAccount.ts +++ b/packages/pds/src/api/com/atproto/server/createAccount.ts @@ -1,7 +1,6 @@ import { MINUTE, cborDecode, cborEncode, check } from '@atproto/common' import { AtprotoData, ensureAtpDocument } from '@atproto/identity' import { InvalidRequestError } from '@atproto/xrpc-server' -import AtpAgent from '@atproto/api' import * as plc from '@did-plc/lib' import disposable from 'disposable-email' import { normalizeAndValidateHandle } from '../../../../handle' @@ -12,8 +11,9 @@ import { countAll } from '../../../../db/util' import { UserAlreadyExistsError } from '../../../../services/account' import AppContext from '../../../../context' import Database from '../../../../db' -import { getPdsEndpoint, isThisPds } from '../../../proxy' +import { isThisPds } from '../../../proxy' import { didDocForSession } from './util' +import { getPdsEndpoint } from '../../../../pds-agents' export default function (server: Server, ctx: AppContext) { server.com.atproto.server.createAccount({ @@ -132,7 +132,7 @@ export default function (server: Server, ctx: AppContext) { // Setup repo root await repoTxn.createRepo(did, [], now) } else { - const agent = new AtpAgent({ service: getPdsEndpoint(pds.host) }) + const agent = ctx.pdsAgents.get(pds.host) await agent.com.atproto.server.createAccount({ ...input.body, did, @@ -215,7 +215,7 @@ const getDidAndPlcOp = async ( }> => { const pdsEndpoint = pds ? getPdsEndpoint(pds.host) : ctx.cfg.service.publicUrl const pdsSigningKey = pds - ? await reserveSigningKey(pds.host) + ? await reserveSigningKey(ctx, pds.host) : ctx.repoSigningKey.did() // if the user brings their own PLC op then we validate it then submit it to PLC on their behalf @@ -334,8 +334,8 @@ const assignPds = async (ctx: AppContext) => { return pdses.at(idx) } -const reserveSigningKey = async (host: string) => { - const agent = new AtpAgent({ service: getPdsEndpoint(host) }) +const reserveSigningKey = async (ctx: AppContext, host: string) => { + const agent = ctx.pdsAgents.get(host) const result = await agent.com.atproto.server.reserveSigningKey() return result.data.signingKey } diff --git a/packages/pds/src/api/com/atproto/server/util.ts b/packages/pds/src/api/com/atproto/server/util.ts index fc3bfae8e05..b110f30f3ff 100644 --- a/packages/pds/src/api/com/atproto/server/util.ts +++ b/packages/pds/src/api/com/atproto/server/util.ts @@ -1,3 +1,4 @@ +import { getPdsEndpoint } from '@atproto/common' import * as crypto from '@atproto/crypto' import { DidDocument } from '@atproto/identity' import { ServerConfig } from '../../../../config' @@ -26,7 +27,6 @@ export const getRandomToken = () => { return token.slice(0, 5) + '-' + token.slice(5, 10) } -// @TODO once supporting multiple pdses, validate pds in did doc based on allow-list. export const didDocForSession = async ( ctx: AppContext, did: string, @@ -34,8 +34,15 @@ export const didDocForSession = async ( ): Promise<DidDocument | undefined> => { if (!ctx.cfg.identity.enableDidDocWithSession) return try { - const didDoc = await ctx.idResolver.did.resolve(did, forceRefresh) - return didDoc ?? undefined + const [didDoc, pdses] = await Promise.all([ + ctx.idResolver.did.resolve(did, forceRefresh), + ctx.services.account(ctx.db).getPdses({ cached: true }), + ]) + if (!didDoc) return + const pdsEndpoint = getPdsEndpoint(didDoc) + const pdsHost = pdsEndpoint && new URL(pdsEndpoint).host + if (!pdses.some((pds) => pds.host === pdsHost)) return + return didDoc } catch (err) { dbLogger.warn({ err, did }, 'failed to resolve did doc') } diff --git a/packages/pds/src/api/proxy.ts b/packages/pds/src/api/proxy.ts index 726b2118a04..803f7478ddf 100644 --- a/packages/pds/src/api/proxy.ts +++ b/packages/pds/src/api/proxy.ts @@ -17,14 +17,12 @@ export const proxy = async <T>( return null // skip proxying } const accountService = ctx.services.account(ctx.db) - const pds = pdsDid && (await accountService.getPds(pdsDid)) + const pds = pdsDid && (await accountService.getPds(pdsDid, { cached: true })) if (!pds) { throw new UpstreamFailureError('unknown pds') } - // @TODO reuse agents - const agent = new AtpAgent({ service: getPdsEndpoint(pds.host) }) try { - return await fn(agent) + return await fn(ctx.pdsAgents.get(pds.host)) } catch (err) { // @TODO may need to pass through special lexicon errors if ( @@ -42,14 +40,6 @@ export const proxy = async <T>( } } -export const getPdsEndpoint = (host: string) => { - const service = new URL(`https://${host}`) - if (service.hostname === 'localhost') { - service.protocol = 'http:' - } - return service.origin -} - export const isThisPds = ( ctx: AppContext, pdsDid: string | null | undefined, diff --git a/packages/pds/src/auth-verifier.ts b/packages/pds/src/auth-verifier.ts index 19a326f5155..37ce111951d 100644 --- a/packages/pds/src/auth-verifier.ts +++ b/packages/pds/src/auth-verifier.ts @@ -290,6 +290,17 @@ export class AuthVerifier { return { status: Invalid, admin: false, moderator: false, triage: false } } + createAdminRoleHeaders = () => { + return { + authorization: + 'Basic ' + + ui8.toString( + ui8.fromString(`admin:${this._adminPass}`, 'utf8'), + 'base64pad', + ), + } + } + isUserOrAdmin( auth: AccessOutput | RoleOutput | NullOutput, did: string, diff --git a/packages/pds/src/context.ts b/packages/pds/src/context.ts index ff0f77292c2..f24504a7619 100644 --- a/packages/pds/src/context.ts +++ b/packages/pds/src/context.ts @@ -20,6 +20,7 @@ import { Crawlers } from './crawlers' import { DiskBlobStore } from './storage' import { getRedisClient } from './redis' import { RuntimeFlags } from './runtime-flags' +import { PdsAgents } from './pds-agents' export type AppContextOptions = { db: Database @@ -38,6 +39,7 @@ export type AppContextOptions = { crawlers: Crawlers appViewAgent: AtpAgent authVerifier: AuthVerifier + pdsAgents: PdsAgents repoSigningKey: crypto.Keypair plcRotationKey: crypto.Keypair cfg: ServerConfig @@ -60,6 +62,7 @@ export class AppContext { public crawlers: Crawlers public appViewAgent: AtpAgent public authVerifier: AuthVerifier + public pdsAgents: PdsAgents public repoSigningKey: crypto.Keypair public plcRotationKey: crypto.Keypair public cfg: ServerConfig @@ -81,6 +84,7 @@ export class AppContext { this.crawlers = opts.crawlers this.appViewAgent = opts.appViewAgent this.authVerifier = opts.authVerifier + this.pdsAgents = opts.pdsAgents this.repoSigningKey = opts.repoSigningKey this.plcRotationKey = opts.plcRotationKey this.cfg = opts.cfg @@ -191,6 +195,8 @@ export class AppContext { crawlers, }) + const pdsAgents = new PdsAgents() + return new AppContext({ db, blobstore, @@ -210,6 +216,7 @@ export class AppContext { authVerifier, repoSigningKey, plcRotationKey, + pdsAgents, cfg, ...(overrides ?? {}), }) diff --git a/packages/pds/src/pds-agents.ts b/packages/pds/src/pds-agents.ts new file mode 100644 index 00000000000..c24a797d201 --- /dev/null +++ b/packages/pds/src/pds-agents.ts @@ -0,0 +1,22 @@ +import AtpAgent from '@atproto/api' + +export class PdsAgents { + // @NOTE only use with entries in the pds table, not for e.g. arbitrary entries found in did documents. + private cache = new Map<string, AtpAgent>() + get(host: string) { + const agent = + this.cache.get(host) ?? new AtpAgent({ service: getPdsEndpoint(host) }) + if (!this.cache.has(host)) { + this.cache.set(host, agent) + } + return agent + } +} + +export const getPdsEndpoint = (host: string) => { + const service = new URL(`https://${host}`) + if (service.hostname === 'localhost') { + service.protocol = 'http:' + } + return service.origin +} diff --git a/packages/pds/src/services/account/index.ts b/packages/pds/src/services/account/index.ts index 6b108a7987d..9d80f2fee8a 100644 --- a/packages/pds/src/services/account/index.ts +++ b/packages/pds/src/services/account/index.ts @@ -1,4 +1,4 @@ -import { sql } from 'kysely' +import { Selectable, sql } from 'kysely' import { randomStr } from '@atproto/crypto' import { InvalidRequestError } from '@atproto/xrpc-server' import { MINUTE, lessThanAgoMs } from '@atproto/common' @@ -15,12 +15,13 @@ import { AppPassword } from '../../lexicon/types/com/atproto/server/createAppPas import { EmailTokenPurpose } from '../../db/tables/email-token' import { getRandomToken } from '../../api/com/atproto/server/util' import { OptionalJoin } from '../../db/types' +import { Pds } from '../../db/tables/pds' export class AccountService { - constructor(public db: Database) {} + constructor(public db: Database, private pdsCache: PdsCache) {} - static creator() { - return (db: Database) => new AccountService(db) + static creator(pdsCache: PdsCache) { + return (db: Database) => new AccountService(db, pdsCache) } // @TODO decouple account from repo_root, move takedownId. @@ -84,10 +85,10 @@ export class AccountService { qb.where(notSoftDeletedClause(ref('user_account'))), ) .where('email', '=', email.toLowerCase()) - .select(['pds.did as pdsDid']) - .selectAll('user_account') + .selectAll('repo_root') // first so that its possibly-null vals don't shadow other cols .selectAll('did_handle') - .selectAll('repo_root') + .selectAll('user_account') + .select(['pds.did as pdsDid']) .executeTakeFirst() return found || null } @@ -601,13 +602,27 @@ export class AccountService { } } - // @TODO cache w/ in-mem lookup - async getPds(pdsDid: string) { - return await this.db.db + // @NOTE cached due to heavy usage in proxy logic + async getPds(pdsDid: string, opts?: { cached: boolean }) { + if (opts?.cached && this.pdsCache.has(pdsDid)) { + return this.pdsCache.get(pdsDid) + } + const pds = await this.db.db .selectFrom('pds') .where('did', '=', pdsDid) .selectAll() .executeTakeFirst() + if (pds) this.pdsCache.set(pdsDid, pds) + return pds + } + + async getPdses(opts?: { cached: boolean }) { + if (opts?.cached && this.pdsCache.hasAll()) { + return this.pdsCache.getAll() ?? [] + } + const pdses = await this.db.db.selectFrom('pds').selectAll().execute() + this.pdsCache.setAll(pdses) + return pdses } } @@ -648,3 +663,30 @@ export type HandleSequenceToken = { did: string; handle: string } type AccountInfo = UserAccountEntry & DidHandle & OptionalJoin<RepoRoot> & { pdsDid: string | null } + +export class PdsCache { + private all: PdsResult[] | undefined + private individual = new Map<string, PdsResult>() + get(did: string) { + return this.individual.get(did) + } + has(did: string) { + return this.individual.has(did) + } + set(did: string, pds: PdsResult) { + return this.individual.set(did, pds) + } + getAll() { + return this.all + } + hasAll() { + return this.all !== undefined + } + setAll(pdses: PdsResult[]) { + this.all = pdses + this.individual.clear() + pdses.forEach((pds) => this.individual.set(pds.did, pds)) + } +} + +type PdsResult = Selectable<Pds> diff --git a/packages/pds/src/services/index.ts b/packages/pds/src/services/index.ts index ffb27d00b5e..c597964383f 100644 --- a/packages/pds/src/services/index.ts +++ b/packages/pds/src/services/index.ts @@ -2,7 +2,7 @@ import { AtpAgent } from '@atproto/api' import * as crypto from '@atproto/crypto' import { BlobStore } from '@atproto/repo' import Database from '../db' -import { AccountService } from './account' +import { AccountService, PdsCache } from './account' import { AuthService } from './auth' import { RecordService } from './record' import { RepoService } from './repo' @@ -36,8 +36,9 @@ export function createServices(resources: { backgroundQueue, crawlers, } = resources + const pdsCache = new PdsCache() return { - account: AccountService.creator(), + account: AccountService.creator(pdsCache), auth: AuthService.creator(identityDid, authKeys), record: RecordService.creator(), repo: RepoService.creator( @@ -53,7 +54,7 @@ export function createServices(resources: { appViewDid, appViewCdnUrlPattern, ), - moderation: ModerationService.creator(blobstore), + moderation: ModerationService.creator(blobstore, pdsCache), } } diff --git a/packages/pds/src/services/moderation/index.ts b/packages/pds/src/services/moderation/index.ts index 3f96d3a1b90..524fe7b7d59 100644 --- a/packages/pds/src/services/moderation/index.ts +++ b/packages/pds/src/services/moderation/index.ts @@ -10,15 +10,20 @@ import { ModerationViews } from './views' import SqlRepoStorage from '../../sql-repo-storage' import { TAKEDOWN } from '../../lexicon/types/com/atproto/admin/defs' import { addHoursToDate } from '../../util/date' +import { PdsCache } from '../account' export class ModerationService { - constructor(public db: Database, public blobstore: BlobStore) {} - - static creator(blobstore: BlobStore) { - return (db: Database) => new ModerationService(db, blobstore) + constructor( + public db: Database, + public blobstore: BlobStore, + public pdsCache: PdsCache, + ) {} + + static creator(blobstore: BlobStore, pdsCache: PdsCache) { + return (db: Database) => new ModerationService(db, blobstore, pdsCache) } - views = new ModerationViews(this.db) + views = new ModerationViews(this.db, this.pdsCache) services = { record: RecordService.creator(), diff --git a/packages/pds/src/services/moderation/views.ts b/packages/pds/src/services/moderation/views.ts index e0285e6f932..b4f64adbfb5 100644 --- a/packages/pds/src/services/moderation/views.ts +++ b/packages/pds/src/services/moderation/views.ts @@ -17,17 +17,17 @@ import { } from '../../lexicon/types/com/atproto/admin/defs' import { OutputSchema as ReportOutput } from '../../lexicon/types/com/atproto/moderation/createReport' import { ModerationAction } from '../../db/tables/moderation' -import { AccountService } from '../account' +import { AccountService, PdsCache } from '../account' import { RecordService } from '../record' import { ModerationReportRowWithHandle } from '.' import { ids } from '../../lexicon/lexicons' import { OptionalJoin } from '../../db/types' export class ModerationViews { - constructor(private db: Database) {} + constructor(private db: Database, private pdsCache: PdsCache) {} services = { - account: AccountService.creator(), + account: AccountService.creator(this.pdsCache), record: RecordService.creator(), }