diff --git a/jest/mock.js b/jest/mock.js index 7f2944a..8d9ae9b 100644 --- a/jest/mock.js +++ b/jest/mock.js @@ -2,11 +2,13 @@ const { NativeModules, DeviceEventEmitter } = require('react-native') if (!NativeModules.RNLlama) { NativeModules.RNLlama = { - initContext: jest.fn(() => Promise.resolve({ - contextId: 1, - gpu: false, - reasonNoGPU: 'Test', - })), + initContext: jest.fn(() => + Promise.resolve({ + contextId: 1, + gpu: false, + reasonNoGPU: 'Test', + }), + ), completion: jest.fn(async (contextId, jobId) => { const testResult = { @@ -150,6 +152,11 @@ if (!NativeModules.RNLlama) { })), saveSession: jest.fn(async () => 1), + bench: jest.fn( + async () => + '["test 3B Q4_0",1600655360,2779683840,16.211304,0.021748,38.570646,1.195800]', + ), + releaseContext: jest.fn(() => Promise.resolve()), releaseAllContexts: jest.fn(() => Promise.resolve()), diff --git a/src/__tests__/__snapshots__/index.test.ts.snap b/src/__tests__/__snapshots__/index.test.ts.snap index ccf9c51..73c9c49 100644 --- a/src/__tests__/__snapshots__/index.test.ts.snap +++ b/src/__tests__/__snapshots__/index.test.ts.snap @@ -90,7 +90,19 @@ Array [ ] `; -exports[`Mock 2`] = ` +exports[`Mock: bench 1`] = ` +Object { + "modelDesc": "test 3B Q4_0", + "modelNParams": 2779683840, + "modelSize": 1600655360, + "ppAvg": 16.211304, + "ppStd": 0.021748, + "tgAvg": 38.570646, + "tgStd": 1.1958, +} +`; + +exports[`Mock: completion result 1`] = ` Object { "completion_probabilities": Array [ Object { diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 8bcca60..e980961 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -17,7 +17,9 @@ test('Mock', async () => { events.push(data) }) expect(events).toMatchSnapshot() - expect(completionResult).toMatchSnapshot() + expect(completionResult).toMatchSnapshot('completion result') + + expect(await context.bench(512, 128, 1, 3)).toMatchSnapshot('bench') await context.release() await releaseAllLlama() diff --git a/src/index.ts b/src/index.ts index cc7843a..1e8a187 100644 --- a/src/index.ts +++ b/src/index.ts @@ -128,6 +128,7 @@ export class LlamaContext { async bench(pp: number, tg: number, pl: number, nr: number): Promise { const result = await RNLlama.bench(this.id, pp, tg, pl, nr) + console.log(result) const [ modelDesc, modelSize,