Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

E2e encryption #1098

Merged
merged 5 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 151 additions & 7 deletions src/lib/chat/matrix-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,164 @@ const getSdkClient = (sdkClient = {}) => ({
login: async () => ({}),
initCrypto: async () => null,
startClient: jest.fn(async () => undefined),
stopClient: jest.fn(),
on: jest.fn((topic, callback) => {
if (topic === 'sync') callback('PREPARED');
}),
getRooms: jest.fn(),
getAccountData: jest.fn(),
getUser: jest.fn(),
setGlobalErrorOnUnknownDevices: () => undefined,
...sdkClient,
});

const subject = (props = {}) => {
const subject = (props = {}, sessionStorage = {}) => {
const allProps: any = {
createClient: (_opts: any) => getSdkClient(),
...props,
};

return new MatrixClient(allProps);
const mockSessionStorage: any = {
get: () => ({ deviceId: '', accessToken: '', userId: '' }),
set: (_session) => undefined,
clear: () => undefined,
...sessionStorage,
};

return new MatrixClient(allProps, mockSessionStorage);
};

function resolveWith<T>(valueToResolve: T) {
let theResolve;
const promise = new Promise((resolve) => {
theResolve = async () => {
resolve(valueToResolve);
await new Promise((resolve) => setImmediate(resolve));
await promise;
};
});

return { resolve: theResolve, mock: () => promise };
}

describe('matrix client', () => {
describe('disconnect', () => {
it('stops client on disconnect', async () => {
const sdkClient = getSdkClient();
const createClient = jest.fn(() => sdkClient);
const matrixSession = {
deviceId: 'abc123',
accessToken: 'token-4321',
userId: '@bob:zos-matrix',
};

const client = subject({ createClient }, { get: () => matrixSession });

// initializes underlying matrix client
await client.connect(null, 'token');

client.disconnect();

expect(sdkClient.stopClient).toHaveBeenCalledOnce();
});

it('clears session storage on disconnect', async () => {
const sdkClient = getSdkClient();
const createClient = jest.fn(() => sdkClient);
const matrixSession = {
deviceId: 'abc123',
accessToken: 'token-4321',
userId: '@bob:zos-matrix',
};

const clearSession = jest.fn();

const client = subject({ createClient }, { clear: clearSession, get: () => matrixSession });

// initializes underlying matrix client
await client.connect(null, 'token');

expect(clearSession).not.toHaveBeenCalled();

client.disconnect();

expect(clearSession).toHaveBeenCalledOnce();
});
});

describe('createclient', () => {
it('creates SDK client on connect', () => {
it('creates SDK client with existing session on connect', async () => {
const sdkClient = getSdkClient();
const createClient = jest.fn(() => sdkClient);
const matrixSession = {
deviceId: 'abc123',
accessToken: 'token-4321',
userId: '@bob:zos-matrix',
};

const client = subject({ createClient });
const client = subject({ createClient }, { get: () => matrixSession });

client.connect(null, 'token');

expect(createClient).toHaveBeenCalledWith(expect.objectContaining({ baseUrl: config.matrix.homeServerUrl }));
await new Promise((resolve) => setImmediate(resolve));

expect(createClient).toHaveBeenCalledWith({
baseUrl: config.matrix.homeServerUrl,
...matrixSession,
});
});

it('logs in and creates SDK client with new session if none exists', async () => {
const matrixSession = {
deviceId: 'abc123',
accessToken: 'token-4321',
userId: '@bob:zos-matrix',
};

const { resolve, mock } = resolveWith({
device_id: matrixSession.deviceId,
user_id: matrixSession.userId,
access_token: matrixSession.accessToken,
});

const createClient = jest.fn(() => getSdkClient({ login: mock }));

const client = subject({ createClient }, { get: () => null });

client.connect(null, 'token');

await resolve();

expect(createClient).toHaveBeenNthCalledWith(2, {
baseUrl: config.matrix.homeServerUrl,
...matrixSession,
});
});

it('saves session if none exists', async () => {
const matrixSession = {
deviceId: 'abc123',
accessToken: 'token-4321',
userId: '@bob:zos-matrix',
};

const setSession = jest.fn();

const { resolve, mock } = resolveWith({
device_id: matrixSession.deviceId,
user_id: matrixSession.userId,
access_token: matrixSession.accessToken,
});

const createClient = jest.fn(() => getSdkClient({ login: mock }));

const client = subject({ createClient }, { get: () => null, set: setSession });

client.connect(null, 'token');

await resolve();

expect(setSession).toHaveBeenCalledWith(matrixSession);
});

it('starts client on connect', async () => {
Expand Down Expand Up @@ -347,9 +476,24 @@ describe('matrix client', () => {

expect(createRoom).toHaveBeenCalledWith(
expect.objectContaining({
initial_state: [
initial_state: expect.arrayContaining([
{ type: 'm.room.guest_access', state_key: '', content: { guest_access: GuestAccess.Forbidden } },
],
]),
})
);
});

it('creates encrypted room', async () => {
const createRoom = jest.fn().mockResolvedValue({ room_id: 'new-room-id' });
const client = await subject({ createRoom });

await client.createConversation([{ userId: 'id', matrixId: '@somebody.else' }], null, null, null);

expect(createRoom).toHaveBeenCalledWith(
expect.objectContaining({
initial_state: expect.arrayContaining([
{ type: 'm.room.encryption', state_key: '', content: { algorithm: 'm.megolm.v1.aes-sha2' } },
]),
})
);
});
Expand Down
64 changes: 55 additions & 9 deletions src/lib/chat/matrix-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { MemberNetworks } from '../../store/users/types';
import { ConnectionStatus, MembershipStateType } from './matrix/types';
import { getFilteredMembersForAutoComplete, setAsDM } from './matrix/utils';
import { uploadImage } from '../../store/channels-list/api';
import { SessionStorage } from './session-storage';

export class MatrixClient implements IChatClient {
private matrix: SDKMatrixClient = null;
Expand All @@ -41,7 +42,7 @@ export class MatrixClient implements IChatClient {
private connectionResolver: () => void;
private connectionAwaiter: Promise<void>;

constructor(private sdk = { createClient }) {
constructor(private sdk = { createClient }, private sessionStorage = new SessionStorage()) {
this.addConnectionAwaiter();
}

Expand All @@ -62,7 +63,11 @@ export class MatrixClient implements IChatClient {
return this.userId;
}

disconnect: () => void;
disconnect() {
this.matrix.stopClient();
this.sessionStorage.clear();
}

reconnect: () => void;

async getAccountData(eventType: string) {
Expand Down Expand Up @@ -154,6 +159,7 @@ export class MatrixClient implements IChatClient {

const initial_state: any[] = [
{ type: EventType.RoomGuestAccess, state_key: '', content: { guest_access: GuestAccess.Forbidden } },
{ type: EventType.RoomEncryption, state_key: '', content: { algorithm: 'm.megolm.v1.aes-sha2' } },
];

if (coverUrl) {
Expand Down Expand Up @@ -258,7 +264,7 @@ export class MatrixClient implements IChatClient {
}

if (event.type === EventType.RoomMessage) {
this.events.receiveNewMessage(event.room_id, mapMatrixMessage(event, this.matrix) as any);
this.publishMessageEvent(event);
}

if (event.type === EventType.RoomCreate) {
Expand All @@ -276,6 +282,13 @@ export class MatrixClient implements IChatClient {
}
});

this.matrix.on(MatrixEventEvent.Decrypted, async (decryptedEvent: MatrixEvent) => {
const event = decryptedEvent.getEffectiveEvent();
if (event.type === EventType.RoomMessage) {
this.publishMessageEvent(event);
}
});

this.matrix.on(ClientEvent.AccountData, this.publishConversationListChange);
this.matrix.on(ClientEvent.Event, this.publishUserPresenceChange);
this.matrix.on(RoomEvent.Name, this.publishRoomNameChange);
Expand Down Expand Up @@ -303,21 +316,50 @@ export class MatrixClient implements IChatClient {
return (data) => console.log('Received Event', name, data);
}

private async initializeClient(_userId: string, accessToken: string) {
private async getCredentials(accessToken: string) {
const credentials = this.sessionStorage.get();

if (credentials) {
return credentials;
}

return await this.login(accessToken);
}

private async login(token: string) {
const tempClient = this.sdk.createClient({ baseUrl: config.matrix.homeServerUrl });

const { user_id, device_id, access_token } = await tempClient.login('org.matrix.login.jwt', { token });

this.sessionStorage.set({
userId: user_id,
deviceId: device_id,
accessToken: access_token,
});

return { accessToken: access_token, userId: user_id, deviceId: device_id };
}

private async initializeClient(_userId: string, ssoToken: string) {
if (!this.matrix) {
this.matrix = this.sdk.createClient({
const opts: any = {
baseUrl: config.matrix.homeServerUrl,
});
...(await this.getCredentials(ssoToken)),
};

const loginResult = await this.matrix.login('org.matrix.login.jwt', { token: accessToken });
this.matrix = this.sdk.createClient(opts);

this.matrix.deviceId = loginResult.device_id;
await this.matrix.initCrypto();

// suppsedly the setter is deprecated, but the direct property set doesn't seem to work.
// this is hopefully only a short-term setting anyway, so just leaving for now.
// this.matrix.getCrypto().globalBlacklistUnverifiedDevices = false;
this.matrix.setGlobalErrorOnUnknownDevices(false);

await this.matrix.startClient();
await this.waitForSync();

return loginResult.user_id;
return opts.userId;
}
}

Expand All @@ -343,6 +385,10 @@ export class MatrixClient implements IChatClient {
this.events.onUserJoinedChannel(this.mapChannel(this.matrix.getRoom(event.room_id)));
}

private publishMessageEvent(event) {
this.events.receiveNewMessage(event.room_id, mapMatrixMessage(event, this.matrix) as any);
}

private publishConversationListChange = (event: MatrixEvent) => {
if (event.getType() === EventType.Direct) {
const content = event.getContent();
Expand Down
Loading
Loading