From 76a76355dc048a4289a1b91e071750938004a198 Mon Sep 17 00:00:00 2001 From: turbocrime Date: Mon, 17 Feb 2025 23:07:35 -0800 Subject: [PATCH] prevent race conditions --- .../transport-chrome/src/session-client.ts | 8 +- .../transport-chrome/src/session-manager.ts | 368 ++++++------------ packages/transport-chrome/src/session.ts | 159 ++++++++ .../src/suppress-disconnected.ts | 7 + packages/transport-chrome/src/util/senders.ts | 50 +++ 5 files changed, 341 insertions(+), 251 deletions(-) create mode 100644 packages/transport-chrome/src/session.ts create mode 100644 packages/transport-chrome/src/suppress-disconnected.ts create mode 100644 packages/transport-chrome/src/util/senders.ts diff --git a/packages/transport-chrome/src/session-client.ts b/packages/transport-chrome/src/session-client.ts index 2c2a29c5e..7a8b3e0f3 100644 --- a/packages/transport-chrome/src/session-client.ts +++ b/packages/transport-chrome/src/session-client.ts @@ -53,11 +53,9 @@ export class CRSessionClient { // listen to client this.clientPort.addEventListener('message', this.clientListener); - if (globalThis.__DEV__) { - this.clientPort.addEventListener('messageerror', ev => - console.warn('CRSessionClient.clientPort messageerror', ev), - ); - } + this.clientPort.addEventListener('messageerror', ev => + console.warn('CRSessionClient.clientPort messageerror', ev), + ); this.clientPort.start(); } diff --git a/packages/transport-chrome/src/session-manager.ts b/packages/transport-chrome/src/session-manager.ts index d20150699..c5ad940e1 100644 --- a/packages/transport-chrome/src/session-manager.ts +++ b/packages/transport-chrome/src/session-manager.ts @@ -1,32 +1,14 @@ -import type { JsonValue } from '@bufbuild/protobuf'; -import { Code, ConnectError } from '@connectrpc/connect'; -import { errorToJson } from '@connectrpc/connect/protocol-connect'; import type { ChannelHandlerFn } from '@penumbra-zone/transport-dom/adapter'; -import { - isTransportAbort, - isTransportEvent, - isTransportMessage, - TransportEvent, - type TransportError, - type TransportMessage, - type TransportStream, -} from '@penumbra-zone/transport-dom/messages'; -import { ChannelLabel, nameConnection, parseConnectionName } from './channel-names.js'; -import { isTransportInitChannel, TransportInitChannel } from './message.js'; -import { PortStreamSink, PortStreamSource } from './stream.js'; +import { ChannelLabel, parseConnectionName } from './channel-names.js'; +import { CRSession } from './session.js'; +import { assertMatchingSenders, isPortWithSenderOrigin } from './util/senders.js'; -interface CRSession { - abort: (reason?: unknown) => void; - signal: AbortSignal; - sessionId: string; +export interface ManagedPort { port: chrome.runtime.Port; - origin: string; - requests: Map; + portAc: AbortController; } -type SenderWithOrigin = chrome.runtime.MessageSender & { origin: string }; -type PortWithOrigin = chrome.runtime.Port & { sender: SenderWithOrigin }; -export type CheckPortSenderFn = (port: chrome.runtime.Port) => Promise; +export type ValidateSessionPortFn = (port: chrome.runtime.Port) => Promise; /** * Only for use as an extension-level singleton by the extension's main @@ -53,70 +35,54 @@ export type CheckPortSenderFn = (port: chrome.runtime.Port) => Promise(); - /** - * Create a new session manager to accept connections from `CRSessionClient`. - * - * @param managerId a string containing no spaces, matching the prefix used in your content script - * @param handler your router entry function - * @param checkPortSender a function used to validate the sender of a connection - */ - constructor( - private readonly managerId: string, - private readonly handler: ChannelHandlerFn, - private readonly checkPortSender: CheckPortSenderFn, - ) { + private static assertInitialized() { + if (!CRSessionManager.singleton) { + throw new Error('Not initialized'); + } + return CRSessionManager.singleton; + } + + private static assertUninitialized() { if (CRSessionManager.singleton) { - throw new Error('Already constructed'); + throw new Error('Already initialized'); } + } + + private sessions = new Map(); + private ports = new Map(); + + private constructor( + public readonly managerId: string, + public readonly handler: ChannelHandlerFn, + public readonly validateSessionPort: ValidateSessionPortFn, + ) { + CRSessionManager.assertUninitialized(); CRSessionManager.singleton = this; - chrome.runtime.onConnect.addListener(this.transportConnection); + chrome.runtime.onConnect.addListener(this.initSession); } /** - * Initialize the singleton session manager. + * Initialize the singleton session manager, or return the existing singleton. * * @param managerId a string identifying this manager * @param handler your router entry function - * @param checkPortSender function to assert validity of a sender + * @param validateSessionPort callback to assert validity of a connection */ public static init = ( managerId: string, handler: ChannelHandlerFn, - checkPortSender: CheckPortSenderFn, - ) => { - CRSessionManager.singleton ??= new CRSessionManager(managerId, handler, checkPortSender); - return CRSessionManager.singleton.sessions; - }; - - /** - * Abort all sessions from a given origin presently active in the singleton. - * - * @param targetOrigin the origin to kill - */ - public static killOrigin = (targetOrigin: string) => { - if (CRSessionManager.singleton) { - CRSessionManager.singleton.sessions.forEach(session => { - if (session.origin === targetOrigin) { - session.requests.forEach(request => { - if (!request.signal.aborted) { - request.abort( - new Error('Kill origin request', { - cause: targetOrigin, - }), - ); - } - }); - if (!session.signal.aborted) { - session.abort(new Error('Kill origin session', { cause: targetOrigin })); - session.port.disconnect(); - } - } - }); - } else { - throw new Error('No session manager'); + validateSessionPort: ValidateSessionPortFn, + ): ReadonlyMap => { + CRSessionManager.singleton ??= new CRSessionManager(managerId, handler, validateSessionPort); + if ( + CRSessionManager.singleton.managerId !== managerId || + CRSessionManager.singleton.handler !== handler || + CRSessionManager.singleton.validateSessionPort !== validateSessionPort + ) { + throw new Error("Init parameters don't match singleton parameters"); } + return CRSessionManager.singleton.sessions; }; /** @@ -124,195 +90,105 @@ export class CRSessionManager { * with access to the chrome runtime. * * Here we make an effort to identify these connections. If the name indicates - * the connection is for this manager, a handler is connected to the port. + * the connection is for this manager, handlers are connected to the port. */ - private transportConnection = (port: chrome.runtime.Port) => { - // require an identified origin - if (!port.sender?.origin) { - return; - } - - // fast and simple name test - if (!port.name.startsWith(this.managerId)) { - return; - } - - // parse the name - const { label: channelLabel, uuid: clientId } = - parseConnectionName(this.managerId, port.name) ?? {}; - if (channelLabel !== ChannelLabel.TRANSPORT || !clientId) { - return; - } - - // client is re-using a present session?? - if (this.sessions.has(clientId)) { - port.disconnect(); - throw new Error(`Session collision: ${clientId}`); - } - - // checking port sender is async - void this.checkPortSender(port).then( - okPort => { - console.debug('Accepted connection', port.name); - this.acceptSession(okPort, clientId); - }, - (e: unknown) => console.warn('Attempted connection was rejected', port.name, e), - ); - }; - - private acceptSession = (port: PortWithOrigin, sessionId: string) => { - console.debug('acceptSession', port.name, sessionId); - const senderOrigin = port.sender.origin; - - const ac = new AbortController(); - const session: CRSession = { - abort: (r?: unknown) => ac.abort(r), - signal: ac.signal, - sessionId, - origin: senderOrigin, - port, - requests: new Map(), - }; - - const sessionAbortListener = () => { - console.debug('sessionAbortListener', sessionId); - session.requests.forEach(request => request.abort(session.signal.reason)); - if (this.sessions.delete(sessionId)) { - port.disconnect(); - } - }; - - const sessionDisconnectListener = () => { - console.debug('sessionDisconnectListener', sessionId); - if (this.sessions.delete(sessionId)) { - session.abort(new Error('Session port disconnected')); - } - }; - - const sessionMessageListener = (tev: unknown) => { - console.debug('sessionMessageListener', tev); - if (isTransportEvent(tev)) { - void this.acceptRequest(session, tev); - } else { - console.warn('Unknown item in transport', tev); - } - }; - - this.sessions.set(sessionId, session); - - session.signal.addEventListener('abort', sessionAbortListener); - port.onDisconnect.addListener(sessionDisconnectListener); - port.onMessage.addListener(sessionMessageListener); - }; - - private acceptRequest = async (session: CRSession, tev: TransportEvent) => { - console.debug('acceptRequest', session.port.name, tev); - const { requestId } = tev; + private initSession = (sessionPort: chrome.runtime.Port) => { + if ( + // quick check for a name indicating this manager + sessionPort.name.startsWith(this.managerId) && + // require an origin + isPortWithSenderOrigin(sessionPort) + ) { + // parse the name thoroughly + const { label: channelLabel, uuid: sessionId } = + parseConnectionName(this.managerId, sessionPort.name) ?? {}; + if (channelLabel === ChannelLabel.TRANSPORT && sessionId) { + // client is re-using a present session?? + if (this.sessions.has(sessionId)) { + // don't disconnect the port, just leave it hanging + throw new Error(`Session collision: ${sessionId}`); + } - try { - if (isTransportAbort(tev, requestId)) { - session.requests - .get(requestId) - ?.abort(ConnectError.from('Client requested abort', Code.Canceled)); - } else if (session.requests.has(requestId)) { - throw new ConnectError('Request collision', Code.Internal); - } else { - const ac = new AbortController(); - session.requests.set(requestId, ac); - const response = await this.sessionRequestHandler(session, ac, tev); - session.port.postMessage(response); + // create a new session + const session = new CRSession(this, this.trackPort(sessionPort)); + this.sessions.set(sessionId, session); + session.signal.addEventListener('abort', () => this.sessions.delete(sessionId)); } - } catch (cause) { - console.debug('acceptRequest error', cause); - session.port.postMessage({ - requestId, - error: errorToJson(ConnectError.from(cause), undefined), - }); - } finally { - session.requests.delete(requestId); } }; /** - * This method enters the router, and returns a response. + * Kill all connections with a given origin. * - * It expects a single message, so only supports unary requests and - * server-streaming requests. This should *always successfully return* a - * `TransportEvent`, containing json representing a response or json - * representing an error. + * @param targetOrigin the origin to kill */ - private sessionRequestHandler = async ( - session: CRSession, - ac: AbortController, - tev: TransportEvent, - ): Promise> => { - console.debug('sessionRequestHandler', session.port.name, tev); - const { requestId } = tev; - - let request: JsonValue | ReadableStream; - if (isTransportMessage(tev, requestId)) { - request = tev.message; - } else if (isTransportInitChannel(tev) && globalThis.__DEV__) { - request = await this.acceptChannelStreamRequest(session.port.sender?.tab?.id, tev.channel); - } else { - throw new ConnectError('Unknown request kind', Code.Unimplemented); + public static killOrigin(targetOrigin: string) { + for (const [port, ac] of CRSessionManager.assertInitialized().ports.entries()) { + if (port.sender?.origin === targetOrigin) { + ac.abort(); + } } + } - const response = await this.handler(request, AbortSignal.any([session.signal, ac.signal])); - if (response instanceof ReadableStream) { - return { requestId, channel: this.makeChannelStreamResponse(response) }; + private trackPort(port: chrome.runtime.Port, portAc = new AbortController()): ManagedPort { + if (this.ports.has(port)) { + throw new Error('Port already tracked'); } else { - return { requestId, message: response }; - } - }; + port.onDisconnect.addListener(() => { + if (this.ports.delete(port)) { + portAc.abort(); + } + }); - /** - * Streams are not jsonifiable, so this function sinks a response stream - * into a dedicated chrome runtime channel, for reconstruction by the - * client. - * - * A jsonifiable message identifying a unique connection name is returned - * and should be transported to the client. The client should open a - * connection bearing this name to source the stream. - */ - private makeChannelStreamResponse = ( - stream: TransportStream['stream'], - ): TransportInitChannel['channel'] => { - const channel = nameConnection(this.managerId, ChannelLabel.STREAM); - console.debug('responseChannelStream', channel); - const sinkListener = (sinkPort: chrome.runtime.Port) => { - if (sinkPort.name === channel) { - chrome.runtime.onConnect.removeListener(sinkListener); - void this.checkPortSender(sinkPort) - .then( - () => - stream - .pipeTo(new WritableStream(new PortStreamSink(sinkPort))) - .catch((e: unknown) => console.debug('response channel stream error', e)), - (e: unknown) => console.warn('Attempted stream was rejected', sinkPort.name, e), - ) - .finally(() => sinkPort.disconnect()); - } - }; + portAc.signal.addEventListener('abort', () => { + if (this.ports.delete(port)) { + port.disconnect(); + } + }); - AbortSignal.any([AbortSignal.timeout(60_000)]).addEventListener('abort', () => - chrome.runtime.onConnect.removeListener(sinkListener), - ); + this.ports.set(port, portAc); + return { port, portAc }; + } + } - chrome.runtime.onConnect.addListener(sinkListener); + public async acceptSubChannel( + name: string, + expectedSender: chrome.runtime.MessageSender, + subAc?: AbortController, + ): Promise { + const unvalidatedPort = expectedSender.tab?.id + ? chrome.tabs.connect(expectedSender.tab.id, { name }) + : chrome.runtime.connect({ name }); - return channel; - }; + assertMatchingSenders(expectedSender, unvalidatedPort.sender); + const validPort = this.validateSessionPort(unvalidatedPort); - private acceptChannelStreamRequest = async ( - tabId: number | undefined, - channel: TransportInitChannel['channel'], - ): Promise => { - console.debug('requestChannelStream', channel); - const streamPort = tabId - ? chrome.tabs.connect(tabId, { name: channel }) - : chrome.runtime.connect({ name: channel }); + return this.trackPort(await validPort, subAc); + } - return new ReadableStream(new PortStreamSource(await this.checkPortSender(streamPort))); - }; + public async offerSubChannel( + name: string, + expectedSender: chrome.runtime.MessageSender, + subAc?: AbortController, + ): Promise { + const { promise: validPort, resolve, reject } = Promise.withResolvers(); + + const wrappedListener = (unvalidatedPort: chrome.runtime.Port) => { + if (unvalidatedPort.name === name) { + chrome.runtime.onConnect.removeListener(wrappedListener); + void (async () => { + try { + assertMatchingSenders(expectedSender, unvalidatedPort.sender); + resolve(await this.validateSessionPort(unvalidatedPort)); + } catch (e: unknown) { + console.warn('Subchannel init failed', unvalidatedPort.name, e); + reject(e); + } + })(); + } + }; + chrome.runtime.onConnect.addListener(wrappedListener); + + return this.trackPort(await validPort, subAc); + } } diff --git a/packages/transport-chrome/src/session.ts b/packages/transport-chrome/src/session.ts new file mode 100644 index 000000000..2de9b5d01 --- /dev/null +++ b/packages/transport-chrome/src/session.ts @@ -0,0 +1,159 @@ +import { Code, ConnectError } from '@connectrpc/connect'; +import { errorToJson } from '@connectrpc/connect/protocol-connect'; +import { + isTransportAbort, + isTransportEvent, + isTransportMessage, + type TransportError, + type TransportMessage, +} from '@penumbra-zone/transport-dom/messages'; +import { ChannelLabel, nameConnection } from './channel-names.js'; +import { isTransportInitChannel, type TransportInitChannel } from './message.js'; +import type { CRSessionManager, ManagedPort } from './session-manager.js'; +import { PortStreamSink, PortStreamSource } from './stream.js'; +import { rethrowOrSuppressDisconnectedPortError } from './suppress-disconnected.js'; +import { assertSenderWithOrigin } from './util/senders.js'; + +/** + * Listeners and abort control for a single session. + * + * @param manager - the parent session manager + * @param unvalidatedPort - for synchronous listener attach + * @param approved - port validation promise + */ +export class CRSession { + public readonly abort: (r?: unknown) => void; + public readonly signal: AbortSignal; + + public readonly sender: chrome.runtime.MessageSender & { origin: string }; + public get origin() { + return this.sender.origin; + } + + public readonly pending = new Map(); + + constructor( + /** reference to the parent session manager */ + private readonly manager: CRSessionManager, + { + /** + * this unvalidated port is used to synchronously attach listeners, to + * avoid a race against incoming messages. + * + * the unvalidated port must not leave the constructor scope. + */ + port: unvalidatedPort, + portAc: sessionAc, + }: ManagedPort, + /** blocks listener execution until the port is validated. */ + private readonly approved = manager.validateSessionPort(unvalidatedPort), + ) { + this.sender = assertSenderWithOrigin(unvalidatedPort.sender); + + this.signal = sessionAc.signal; + this.abort = (r?: unknown) => sessionAc.abort(r); + + this.signal.addEventListener('abort', () => + this.pending.forEach(pendingAc => pendingAc.abort()), + ); + + void this.approved.catch(() => this.abort()); + + unvalidatedPort.onMessage.addListener(this.sessionListener); + } + + private postResponse = (m: TransportMessage | TransportInitChannel) => + this.approved.then(approvedPort => { + try { + approvedPort.postMessage(m); + } catch (e) { + rethrowOrSuppressDisconnectedPortError(e); + } + }); + + private postFailure = (e: TransportError) => + this.approved.then(approvedPort => { + try { + approvedPort.postMessage(e); + } catch (e) { + rethrowOrSuppressDisconnectedPortError(e); + } + }); + + /** + * This listener is attached immediately, but blocks on sender validation. + * + * Basic filtering and transport control are handled here. Valid requests are + * passed to `sessionRequestHandler`. Failures are caught here and serialized + * for response. + */ + private sessionListener = (tev: unknown) => + void this.approved.then(async () => { + if (!isTransportEvent(tev)) { + console.warn('Unknown item in transport', tev); + // exit condition + } else { + const requestId = tev.requestId; + if (isTransportAbort(tev)) { + // abort control message + this.pending.get(requestId)?.abort(); + this.pending.delete(requestId); + // exit condition + } else if (this.pending.has(requestId)) { + // request collisions can't be handled + console.error('Request collision', tev); + // exit condition + } else { + // it's a new request + try { + const pendingAc = new AbortController(); + this.pending.set(requestId, pendingAc); + if (isTransportMessage(tev) || (globalThis.__DEV__ && isTransportInitChannel(tev))) { + // successful responses are posted by `sessionHandler` + await this.sessionHandler(tev, pendingAc); + } else { + throw new ConnectError('Unknown request kind', Code.Unimplemented); + } + } catch (cause) { + // attempt to provide an error response + await this.postFailure({ + requestId, + error: errorToJson(ConnectError.from(cause), undefined), + }); + } finally { + this.pending.delete(requestId); + } + } + } + }); + + /** + * Accepts a request, queries the method router for a response, and posts it + * back to the session channel. Any errors thrown from here should be caught + * and serialized into responses by `sessionListener`. + */ + private async sessionHandler( + tev: TransportMessage | TransportInitChannel, + requestAc: AbortController, + ): Promise { + const requestId = tev.requestId; + + const request = isTransportMessage(tev) + ? tev.message + : await this.manager + .acceptSubChannel(tev.channel, this.sender) + .then(({ port: approvedPort }) => new ReadableStream(new PortStreamSource(approvedPort))); + + const response = await this.manager.handler(request, requestAc.signal); + if (response instanceof ReadableStream) { + const channel = nameConnection(this.manager.managerId, ChannelLabel.STREAM); + const responseSink = this.manager + .offerSubChannel(channel, this.sender) + .then(({ port: approvedPort }) => new WritableStream(new PortStreamSink(approvedPort))); + await this.postResponse({ requestId, channel }); + await response.pipeTo(await responseSink, { signal: requestAc.signal }); + } else { + await this.postResponse({ requestId, message: response }); + } + } +} diff --git a/packages/transport-chrome/src/suppress-disconnected.ts b/packages/transport-chrome/src/suppress-disconnected.ts new file mode 100644 index 000000000..36a1df32e --- /dev/null +++ b/packages/transport-chrome/src/suppress-disconnected.ts @@ -0,0 +1,7 @@ +export const rethrowOrSuppressDisconnectedPortError = (e: unknown) => { + if (!(e instanceof Error && e.message === 'Attempting to use a disconnected port object')) { + throw e; + } else if (globalThis.__DEV__) { + console.debug('Suppressed disconnected port error', e); + } +}; diff --git a/packages/transport-chrome/src/util/senders.ts b/packages/transport-chrome/src/util/senders.ts new file mode 100644 index 000000000..fded3cbdf --- /dev/null +++ b/packages/transport-chrome/src/util/senders.ts @@ -0,0 +1,50 @@ +const compareSenders = ( + a: chrome.runtime.MessageSender, + b: chrome.runtime.MessageSender, +): boolean => + a.tab?.id === b.tab?.id && + a.documentId === b.documentId && + a.frameId === b.frameId && + a.id === b.id && + a.nativeApplication === b.nativeApplication && + a.origin === b.origin && + a.tlsChannelId === b.tlsChannelId && + a.url === b.url; + +export const assertMatchingSenders = ( + a?: chrome.runtime.MessageSender, + b?: chrome.runtime.MessageSender, +) => { + if (!a || !b) { + throw new Error('Missing sender'); + } else if (!compareSenders(a, b)) { + throw new Error('Sender mismatch'); + } +}; + +export const isSenderWithOrigin = ( + sender?: chrome.runtime.MessageSender, +): sender is chrome.runtime.MessageSender & { origin: string } => Boolean(sender?.origin); + +export const isPortWithSenderOrigin = ( + port?: chrome.runtime.Port, +): port is chrome.runtime.Port & { sender: chrome.runtime.MessageSender & { origin: string } } => + isSenderWithOrigin(port?.sender); + +export const assertPortWithSenderOrigin = ( + port?: chrome.runtime.Port, +): chrome.runtime.Port & { sender: chrome.runtime.MessageSender & { origin: string } } => { + if (!isPortWithSenderOrigin(port)) { + throw new Error('Port sender has no origin'); + } + return port; +}; + +export const assertSenderWithOrigin = ( + sender?: chrome.runtime.MessageSender, +): chrome.runtime.MessageSender & { origin: string } => { + if (!isSenderWithOrigin(sender)) { + throw new Error('Sender has no origin'); + } + return sender; +};