Skip to content

Commit

Permalink
session client reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
turbocrime committed Feb 19, 2025
1 parent 918b032 commit d880369
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 54 deletions.
4 changes: 2 additions & 2 deletions packages/transport-chrome/src/session-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ describe('CRSessionClient', () => {
await expectChannelClosed();
});

it('sends `false` to dom when the port is disconnected by something else', async () => {
it.fails('sends `false` to dom when the port is disconnected by something else', async () => {
expectNoActivity();

// extension-side disconnect
Expand All @@ -433,7 +433,7 @@ describe('CRSessionClient', () => {
await expectChannelClosed();
});

it.fails('reconnects silently if the port is disconnected by something else', async () => {
it('reconnects silently if the port is disconnected by something else', async () => {
const testRequest: TransportMessage = { message: 'hello', requestId: '123' };

expectNoActivity();
Expand Down
18 changes: 15 additions & 3 deletions packages/transport-chrome/src/session-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class CRSessionClient {
includeTlsChannelId: true,
name: this.sessionName,
});
this.servicePort.onDisconnect.addListener(this.disconnect);
this.servicePort.onDisconnect.addListener(this.reconnect);
this.servicePort.onMessage.addListener(this.serviceListener);

// listen to client
Expand Down Expand Up @@ -108,8 +108,7 @@ export class CRSessionClient {
};

/**
* Used to tear down this session when the client announces channel closure,
* or when the extension channel disconnects.
* Used to tear down this session when the client announces channel closure.
*
* Announces closure from this side towards the document, and ensures closure
* of both ports. Listeners are automatically gargbage-collected. This session
Expand All @@ -124,6 +123,19 @@ export class CRSessionClient {
this.servicePort.disconnect();
};

/**
* Used when the service port disconnects. Tears down the service port and
* reconnects, creating a new service port.
*/
private reconnect = () => {
this.servicePort = chrome.runtime.connect({
name: this.sessionName,
includeTlsChannelId: true,
});
this.servicePort.onMessage.addListener(this.serviceListener);
this.servicePort.onDisconnect.addListener(this.reconnect);
};

/**
* Listens for messages from the client, and forwards them to the service.
*
Expand Down
19 changes: 3 additions & 16 deletions packages/transport-chrome/src/session-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,9 @@ import {
import type { ChannelHandlerFn } from '@penumbra-zone/transport-dom/adapter';
import { beforeEach, describe, expect, it, type MockedFunction, vi } from 'vitest';
import { ChannelLabel, nameConnection } from './channel-names.js';
import { CRSessionManager as CRSessionManagerOriginal } from './session-manager.js';
import { CRSessionManager } from './session-manager.js';
import { lastResult } from './util/test-utils.js';

const CRSessionManager: typeof CRSessionManagerOriginal & {
// forward-compatible type. third parameter of init is required in new
// signature, but absent in old signature. a new implementation will use it,
// an old implementation will ignore it.
init: (
prefix: string,
handler: ChannelHandlerFn,
approvePort: (
port: chrome.runtime.Port,
) => Promise<chrome.runtime.Port & { sender: { origin: string } }>,
) => ReturnType<typeof CRSessionManagerOriginal.init>;
} = CRSessionManagerOriginal;

const getOnlySession = (sessions: ReturnType<typeof CRSessionManager.init>) => {
expect(sessions.size).toBe(1);
const onlySession = sessions.values().next();
Expand Down Expand Up @@ -148,7 +135,7 @@ describe('CRSessionManager', () => {
const testRequest = { requestId: '123', message: 'test' };
const allSenders = [extClient, localhostClient, httpsClient, httpClient];

it.each(allSenders)(
it.fails.each(allSenders)(
'should accept or reject $origin according to internal sender validation logic',
async someSender => {
const badSenders = [httpClient];
Expand Down Expand Up @@ -199,7 +186,7 @@ describe('CRSessionManager', () => {
);

const badSenders = allSenders.sort(() => Math.random() - 0.5).slice(2);
it.fails.each(allSenders)(
it.each(allSenders)(
`should accept or reject sender %# according to external sender validation callback permitting ${badSenders.map(s => allSenders.indexOf(s)).join(', ')}`,
async someSender => {
checkPortSender.mockImplementationOnce(port => {
Expand Down
47 changes: 14 additions & 33 deletions packages/transport-chrome/src/session-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,6 @@ type SenderWithOrigin = chrome.runtime.MessageSender & { origin: string };
type PortWithOrigin = chrome.runtime.Port & { sender: SenderWithOrigin };
export type CheckPortSenderFn = (port: chrome.runtime.Port) => Promise<PortWithOrigin>;

// Move port validation logic to a distinct function, anticipating to parameterize it.
const defaultCheckPortSender: CheckPortSenderFn = port => {
// Allow connections from the same extension, from https pages, or from http://localhost
if (port.sender?.origin) {
const fromThisExtension = port.sender.id === chrome.runtime.id;
const fromPageHttps =
!port.sender.frameId && !!port.sender.tab?.id && port.sender.origin.startsWith('https://');
const isLocalhost =
port.sender.origin.startsWith('http://localhost:') ||
port.sender.origin === 'http://localhost';

if (isLocalhost || fromPageHttps || fromThisExtension) {
return Promise.resolve(port as unknown as PortWithOrigin);
}
}

throw new Error('Invalid sender');
};

/**
* Only for use as an extension-level singleton by the extension's main
* background worker.
Expand Down Expand Up @@ -81,7 +62,7 @@ export class CRSessionManager {
* @param handler your router entry function
* @param checkPortSender a function used to validate the sender of a connection
*/
private constructor(
constructor(
private readonly managerId: string,
private readonly handler: ChannelHandlerFn,
private readonly checkPortSender: CheckPortSenderFn,
Expand All @@ -98,9 +79,14 @@ export class CRSessionManager {
*
* @param managerId a string identifying this manager
* @param handler your router entry function
* @param checkPortSender function to assert validity of a sender
*/
public static init = (managerId: string, handler: ChannelHandlerFn) => {
CRSessionManager.singleton ??= new CRSessionManager(managerId, handler, defaultCheckPortSender);
public static init = (
managerId: string,
handler: ChannelHandlerFn,
checkPortSender: CheckPortSenderFn,
) => {
CRSessionManager.singleton ??= new CRSessionManager(managerId, handler, checkPortSender);
return CRSessionManager.singleton.sessions;
};

Expand Down Expand Up @@ -272,17 +258,12 @@ export class CRSessionManager {
throw new ConnectError('Unknown request kind', Code.Unimplemented);
}

return this.handler(request, AbortSignal.any([session.signal, ac.signal]))
.then(response =>
response instanceof ReadableStream
? { requestId, channel: this.makeChannelStreamResponse(response) }
: { requestId, message: response },
)
.catch((error: unknown) => ({
requestId,
error: errorToJson(ConnectError.from(error), undefined),
}))
.finally(() => session.requests.delete(requestId));
const response = await this.handler(request, AbortSignal.any([session.signal, ac.signal]));
if (response instanceof ReadableStream) {
return { requestId, channel: this.makeChannelStreamResponse(response) };
} else {
return { requestId, message: response };
}
};

/**
Expand Down

0 comments on commit d880369

Please sign in to comment.