diff --git a/src/lib/chat/matrix-client.test.ts b/src/lib/chat/matrix-client.test.ts index c1a038f0f..1967fc6c0 100644 --- a/src/lib/chat/matrix-client.test.ts +++ b/src/lib/chat/matrix-client.test.ts @@ -51,35 +51,164 @@ const getSdkClient = (sdkClient = {}) => ({ login: async () => ({}), initCrypto: async () => null, startClient: jest.fn(async () => undefined), + stopClient: jest.fn(), on: jest.fn((topic, callback) => { if (topic === 'sync') callback('PREPARED'); }), getRooms: jest.fn(), getAccountData: jest.fn(), getUser: jest.fn(), + setGlobalErrorOnUnknownDevices: () => undefined, ...sdkClient, }); -const subject = (props = {}) => { +const subject = (props = {}, sessionStorage = {}) => { const allProps: any = { createClient: (_opts: any) => getSdkClient(), ...props, }; - return new MatrixClient(allProps); + const mockSessionStorage: any = { + get: () => ({ deviceId: '', accessToken: '', userId: '' }), + set: (_session) => undefined, + clear: () => undefined, + ...sessionStorage, + }; + + return new MatrixClient(allProps, mockSessionStorage); }; +function resolveWith(valueToResolve: T) { + let theResolve; + const promise = new Promise((resolve) => { + theResolve = async () => { + resolve(valueToResolve); + await new Promise((resolve) => setImmediate(resolve)); + await promise; + }; + }); + + return { resolve: theResolve, mock: () => promise }; +} + describe('matrix client', () => { + describe('disconnect', () => { + it('stops client on disconnect', async () => { + const sdkClient = getSdkClient(); + const createClient = jest.fn(() => sdkClient); + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const client = subject({ createClient }, { get: () => matrixSession }); + + // initializes underlying matrix client + await client.connect(null, 'token'); + + client.disconnect(); + + expect(sdkClient.stopClient).toHaveBeenCalledOnce(); + }); + + it('clears session storage on disconnect', async () => { + const sdkClient = getSdkClient(); + const createClient = jest.fn(() => sdkClient); + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const clearSession = jest.fn(); + + const client = subject({ createClient }, { clear: clearSession, get: () => matrixSession }); + + // initializes underlying matrix client + await client.connect(null, 'token'); + + expect(clearSession).not.toHaveBeenCalled(); + + client.disconnect(); + + expect(clearSession).toHaveBeenCalledOnce(); + }); + }); + describe('createclient', () => { - it('creates SDK client on connect', () => { + it('creates SDK client with existing session on connect', async () => { const sdkClient = getSdkClient(); const createClient = jest.fn(() => sdkClient); + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; - const client = subject({ createClient }); + const client = subject({ createClient }, { get: () => matrixSession }); client.connect(null, 'token'); - expect(createClient).toHaveBeenCalledWith(expect.objectContaining({ baseUrl: config.matrix.homeServerUrl })); + await new Promise((resolve) => setImmediate(resolve)); + + expect(createClient).toHaveBeenCalledWith({ + baseUrl: config.matrix.homeServerUrl, + ...matrixSession, + }); + }); + + it('logs in and creates SDK client with new session if none exists', async () => { + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const { resolve, mock } = resolveWith({ + device_id: matrixSession.deviceId, + user_id: matrixSession.userId, + access_token: matrixSession.accessToken, + }); + + const createClient = jest.fn(() => getSdkClient({ login: mock })); + + const client = subject({ createClient }, { get: () => null }); + + client.connect(null, 'token'); + + await resolve(); + + expect(createClient).toHaveBeenNthCalledWith(2, { + baseUrl: config.matrix.homeServerUrl, + ...matrixSession, + }); + }); + + it('saves session if none exists', async () => { + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const setSession = jest.fn(); + + const { resolve, mock } = resolveWith({ + device_id: matrixSession.deviceId, + user_id: matrixSession.userId, + access_token: matrixSession.accessToken, + }); + + const createClient = jest.fn(() => getSdkClient({ login: mock })); + + const client = subject({ createClient }, { get: () => null, set: setSession }); + + client.connect(null, 'token'); + + await resolve(); + + expect(setSession).toHaveBeenCalledWith(matrixSession); }); it('starts client on connect', async () => { @@ -347,9 +476,24 @@ describe('matrix client', () => { expect(createRoom).toHaveBeenCalledWith( expect.objectContaining({ - initial_state: [ + initial_state: expect.arrayContaining([ { type: 'm.room.guest_access', state_key: '', content: { guest_access: GuestAccess.Forbidden } }, - ], + ]), + }) + ); + }); + + it('creates encrypted room', async () => { + const createRoom = jest.fn().mockResolvedValue({ room_id: 'new-room-id' }); + const client = await subject({ createRoom }); + + await client.createConversation([{ userId: 'id', matrixId: '@somebody.else' }], null, null, null); + + expect(createRoom).toHaveBeenCalledWith( + expect.objectContaining({ + initial_state: expect.arrayContaining([ + { type: 'm.room.encryption', state_key: '', content: { algorithm: 'm.megolm.v1.aes-sha2' } }, + ]), }) ); }); diff --git a/src/lib/chat/matrix-client.ts b/src/lib/chat/matrix-client.ts index f5ab1bea7..b8bdf9d49 100644 --- a/src/lib/chat/matrix-client.ts +++ b/src/lib/chat/matrix-client.ts @@ -29,6 +29,7 @@ import { MemberNetworks } from '../../store/users/types'; import { ConnectionStatus, MembershipStateType } from './matrix/types'; import { getFilteredMembersForAutoComplete, setAsDM } from './matrix/utils'; import { uploadImage } from '../../store/channels-list/api'; +import { SessionStorage } from './session-storage'; export class MatrixClient implements IChatClient { private matrix: SDKMatrixClient = null; @@ -41,7 +42,7 @@ export class MatrixClient implements IChatClient { private connectionResolver: () => void; private connectionAwaiter: Promise; - constructor(private sdk = { createClient }) { + constructor(private sdk = { createClient }, private sessionStorage = new SessionStorage()) { this.addConnectionAwaiter(); } @@ -62,7 +63,11 @@ export class MatrixClient implements IChatClient { return this.userId; } - disconnect: () => void; + disconnect() { + this.matrix.stopClient(); + this.sessionStorage.clear(); + } + reconnect: () => void; async getAccountData(eventType: string) { @@ -154,6 +159,7 @@ export class MatrixClient implements IChatClient { const initial_state: any[] = [ { type: EventType.RoomGuestAccess, state_key: '', content: { guest_access: GuestAccess.Forbidden } }, + { type: EventType.RoomEncryption, state_key: '', content: { algorithm: 'm.megolm.v1.aes-sha2' } }, ]; if (coverUrl) { @@ -258,7 +264,7 @@ export class MatrixClient implements IChatClient { } if (event.type === EventType.RoomMessage) { - this.events.receiveNewMessage(event.room_id, mapMatrixMessage(event, this.matrix) as any); + this.publishMessageEvent(event); } if (event.type === EventType.RoomCreate) { @@ -276,6 +282,13 @@ export class MatrixClient implements IChatClient { } }); + this.matrix.on(MatrixEventEvent.Decrypted, async (decryptedEvent: MatrixEvent) => { + const event = decryptedEvent.getEffectiveEvent(); + if (event.type === EventType.RoomMessage) { + this.publishMessageEvent(event); + } + }); + this.matrix.on(ClientEvent.AccountData, this.publishConversationListChange); this.matrix.on(ClientEvent.Event, this.publishUserPresenceChange); this.matrix.on(RoomEvent.Name, this.publishRoomNameChange); @@ -303,21 +316,50 @@ export class MatrixClient implements IChatClient { return (data) => console.log('Received Event', name, data); } - private async initializeClient(_userId: string, accessToken: string) { + private async getCredentials(accessToken: string) { + const credentials = this.sessionStorage.get(); + + if (credentials) { + return credentials; + } + + return await this.login(accessToken); + } + + private async login(token: string) { + const tempClient = this.sdk.createClient({ baseUrl: config.matrix.homeServerUrl }); + + const { user_id, device_id, access_token } = await tempClient.login('org.matrix.login.jwt', { token }); + + this.sessionStorage.set({ + userId: user_id, + deviceId: device_id, + accessToken: access_token, + }); + + return { accessToken: access_token, userId: user_id, deviceId: device_id }; + } + + private async initializeClient(_userId: string, ssoToken: string) { if (!this.matrix) { - this.matrix = this.sdk.createClient({ + const opts: any = { baseUrl: config.matrix.homeServerUrl, - }); + ...(await this.getCredentials(ssoToken)), + }; - const loginResult = await this.matrix.login('org.matrix.login.jwt', { token: accessToken }); + this.matrix = this.sdk.createClient(opts); - this.matrix.deviceId = loginResult.device_id; await this.matrix.initCrypto(); + // suppsedly the setter is deprecated, but the direct property set doesn't seem to work. + // this is hopefully only a short-term setting anyway, so just leaving for now. + // this.matrix.getCrypto().globalBlacklistUnverifiedDevices = false; + this.matrix.setGlobalErrorOnUnknownDevices(false); + await this.matrix.startClient(); await this.waitForSync(); - return loginResult.user_id; + return opts.userId; } } @@ -343,6 +385,10 @@ export class MatrixClient implements IChatClient { this.events.onUserJoinedChannel(this.mapChannel(this.matrix.getRoom(event.room_id))); } + private publishMessageEvent(event) { + this.events.receiveNewMessage(event.room_id, mapMatrixMessage(event, this.matrix) as any); + } + private publishConversationListChange = (event: MatrixEvent) => { if (event.getType() === EventType.Direct) { const content = event.getContent(); diff --git a/src/lib/chat/session-storage.test.ts b/src/lib/chat/session-storage.test.ts new file mode 100644 index 000000000..e499e7c0b --- /dev/null +++ b/src/lib/chat/session-storage.test.ts @@ -0,0 +1,77 @@ +import { SessionStorage } from './session-storage'; + +const setItem = jest.fn(); +const getItem = jest.fn(); +const removeItem = jest.fn(); + +describe('session storage', () => { + const subject = (mockLocalStorage = {}) => { + return new SessionStorage({ + getItem, + setItem, + removeItem, + ...mockLocalStorage, + } as any); + }; + + it('sets localStorage vars', async () => { + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const client = subject(); + + client.set(matrixSession); + + expect(setItem).toHaveBeenCalledWith('mxz_device_id', 'abc123'); + expect(setItem).toHaveBeenCalledWith('mxz_access_token_abc123', 'token-4321'); + expect(setItem).toHaveBeenCalledWith('mxz_user_id', '@bob:zos-matrix'); + }); + + it('removes localStorage vars on clear', async () => { + const getItem = jest.fn((key) => (key === 'mxz_device_id' ? 'abc123' : '')); + const client = subject({ getItem }); + + client.clear(); + + expect(removeItem).toHaveBeenCalledWith('mxz_device_id'); + expect(removeItem).toHaveBeenCalledWith('mxz_access_token_abc123'); + expect(removeItem).toHaveBeenCalledWith('mxz_user_id'); + }); + + it('gets from localStorage vars', async () => { + const matrixSession = { + deviceId: 'abc123', + accessToken: 'token-4321', + userId: '@bob:zos-matrix', + }; + + const getItem = jest.fn((key) => { + return { + mxz_device_id: 'abc123', + mxz_access_token_abc123: 'token-4321', + mxz_user_id: '@bob:zos-matrix', + }[key]; + }); + + const client = subject({ getItem }); + + expect(client.get()).toEqual(matrixSession); + }); + + it('returns null if deviceId is not set', async () => { + const getItem = jest.fn((key) => { + return { + mxz_device_id: '', + mxz_access_token_abc123: 'token-4321', + mxz_user_id: '@bob:zos-matrix', + }[key]; + }); + + const client = subject({ getItem }); + + expect(client.get()).toBeNull(); + }); +}); diff --git a/src/lib/chat/session-storage.ts b/src/lib/chat/session-storage.ts new file mode 100644 index 000000000..fe9827cc2 --- /dev/null +++ b/src/lib/chat/session-storage.ts @@ -0,0 +1,35 @@ +export interface ChatSession { + deviceId: string; + accessToken: string; + userId: string; +} + +export class SessionStorage { + constructor(private storage = localStorage) {} + + clear() { + const deviceId = this.storage.getItem('mxz_device_id'); + + this.storage.removeItem('mxz_device_id'); + this.storage.removeItem(`mxz_access_token_${deviceId}`); + this.storage.removeItem('mxz_user_id'); + } + + set(session: ChatSession) { + this.storage.setItem('mxz_device_id', session.deviceId); + this.storage.setItem(`mxz_access_token_${session.deviceId}`, session.accessToken); + this.storage.setItem('mxz_user_id', session.userId); + } + + get(): ChatSession { + const deviceId = this.storage.getItem('mxz_device_id'); + + if (!deviceId) return null; + + return { + deviceId, + accessToken: this.storage.getItem(`mxz_access_token_${deviceId}`), + userId: this.storage.getItem('mxz_user_id'), + }; + } +}