Skip to content

Commit

Permalink
feat: implement transcribe realtime method (#10)
Browse files Browse the repository at this point in the history
* feat(ios): implement transcribeRealtime method

* feat(ios): remove isRealtimeSetup condition

* fix(ios): check isTranscribing on start realtime

* feat(android): implement transcribeRealtime method

* feat(jest): update mock & test

* chore: cleanup

* feat(ios): implement end event internally

* feat(android): implement end event internally

* feat: rename to realtimeAudioSec

* fix(android): check capturing instead of transcribing

* feat(example): improve render

* feat(ts): add comments as doc

* docs(readme): add realtime transcribe setup & usage

* feat(native): check in realtime mode in transcribeFile

* chore(ios): clear unnecessary logs & improve some logs
  • Loading branch information
jhen0409 authored Mar 27, 2023
1 parent 7d39694 commit f63c49f
Show file tree
Hide file tree
Showing 15 changed files with 1,006 additions and 181 deletions.
47 changes: 43 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@ npm install whisper.rn

Then re-run `npx pod-install` again for iOS.

## Add Microphone Permissions (Optional)

If you want to use realtime transcription, you need to add the microphone permission to your app.

### iOS
Add these lines to ```ios/[YOUR_APP_NAME]/Info.plist```
```xml
<key>NSMicrophoneUsageDescription</key>
<string>This app requires microphone access in order to transcribe speech</string>
```

### Android
Add the following line to ```android/app/src/main/AndroidManifest.xml```
```xml
<uses-permission android:name="android.permission.RECORD_AUDIO" />
```

## Usage

```js
Expand All @@ -30,13 +47,35 @@ const sampleFilePath = 'file://.../sample.wav'

const whisperContext = await initWhisper({ filePath })

const { result } = await whisperContext.transcribe(sampleFilePath, {
language: 'en',
// More options
})
const options = { language: 'en' }
const { stop, promise } = whisperContext.transcribe(sampleFilePath, options)

const { result } = await promise
// result: (The inference text result from audio file)
```

Use realtime transcribe:

```js
const { stop, subscribe } = whisperContext.transcribeRealtime(options)

subscribe(evt => {
const { isCapturing, data, processTime, recordingTime } = evt
console.log(
`Realtime transcribing: ${isCapturing ? 'ON' : 'OFF'}\n` +
// The inference text result from audio record:
`Result: ${data.result}\n\n` +
`Process time: ${processTime}ms\n` +
`Recording time: ${recordingTime}ms`,
)
if (!isCapturing) console.log('Finished realtime transcribing')
})
```

On Android, you may need to request the microphone permission using [`PermissionsAndroid`](https://reactnative.dev/docs/permissionsandroid).

The documentation is not ready yet, please see the comments of [index](./src/index.tsx) file for more details at the moment.

## Run with example

The example app is using [react-native-fs](https://github.com/itinance/react-native-fs) to download the model file and audio file.
Expand Down
52 changes: 43 additions & 9 deletions android/src/main/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import android.os.Build;
import android.os.Handler;
import android.os.AsyncTask;
import android.media.AudioRecord;

import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.ReactApplicationContext;
Expand Down Expand Up @@ -51,7 +52,7 @@ protected Integer doInBackground(Void... voids) {
throw new Exception("Failed to initialize context");
}
int id = Math.abs(new Random().nextInt());
WhisperContext whisperContext = new WhisperContext(context);
WhisperContext whisperContext = new WhisperContext(id, reactContext, context);
contexts.put(id, whisperContext);
return id;
} catch (Exception e) {
Expand All @@ -72,18 +73,27 @@ protected void onPostExecute(Integer id) {
}

@ReactMethod
public void transcribe(int id, int jobId, String filePath, ReadableMap options, Promise promise) {
public void transcribeFile(int id, int jobId, String filePath, ReadableMap options, Promise promise) {
final WhisperContext context = contexts.get(id);
if (context == null) {
promise.reject("Context not found");
return;
}
if (context.isCapturing()) {
promise.reject("The context is in realtime transcribe mode");
return;
}
if (context.isTranscribing()) {
promise.reject("Context is already transcribing");
return;
}
new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
WhisperContext context = contexts.get(id);
if (context == null) {
throw new Exception("Context " + id + " not found");
}
return context.transcribe(jobId, filePath, options);
return context.transcribeFile(jobId, filePath, options);
} catch (Exception e) {
exception = e;
return null;
Expand All @@ -102,8 +112,32 @@ protected void onPostExecute(WritableMap data) {
}

@ReactMethod
public void abortTranscribe(int jobId) {
WhisperContext.abortTranscribe(jobId);
/**
 * Starts realtime transcription on the given context.
 *
 * Resolves with null once the recorder has been created in
 * STATE_INITIALIZED; otherwise rejects with the recorder state
 * reported by the context.
 */
public void startRealtimeTranscribe(int id, int jobId, ReadableMap options, Promise promise) {
  final WhisperContext context = contexts.get(id);
  if (context == null) {
    promise.reject("Context not found");
    return;
  }
  if (context.isCapturing()) {
    promise.reject("Context is already in capturing");
    return;
  }
  // Guard clause on failure instead of early-return on success.
  final int recorderState = context.startRealtimeTranscribe(jobId, options);
  if (recorderState != AudioRecord.STATE_INITIALIZED) {
    promise.reject("Failed to start realtime transcribe. State: " + recorderState);
    return;
  }
  promise.resolve(null);
}

@ReactMethod
/**
 * Aborts a transcribe job (file or realtime) running on the given context.
 *
 * Rejects when the context id is unknown; resolves with null once the
 * native job has been told to stop.
 */
public void abortTranscribe(int contextId, int jobId, Promise promise) {
  WhisperContext context = contexts.get(contextId);
  if (context == null) {
    promise.reject("Context not found");
    return;
  }
  context.stopTranscribe(jobId);
  // Fix: resolve the promise — the previous version returned without
  // settling it, leaving the JS caller's await hanging forever.
  promise.resolve(null);
}

@ReactMethod
Expand Down
199 changes: 192 additions & 7 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.modules.core.DeviceEventManagerModule;

import android.util.Log;
import android.os.Build;
import android.content.res.AssetManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder.AudioSource;

import java.util.Random;
import java.lang.StringBuilder;
Expand All @@ -26,17 +31,175 @@

public class WhisperContext {
public static final String NAME = "RNWhisperContext";

private static final int SAMPLE_RATE = 16000;
private static final int CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO;
private static final int AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT;
private static final int AUDIO_SOURCE = AudioSource.VOICE_RECOGNITION;
private static final int DEFAULT_MAX_AUDIO_SEC = 30;

private int id;
private ReactApplicationContext reactContext;
private long context;

public WhisperContext(long context) {
private DeviceEventManagerModule.RCTDeviceEventEmitter eventEmitter;

private int jobId = -1;
private AudioRecord recorder = null;
private int bufferSize;
private short[] buffer16;
private int nSamples = 0;
private boolean isCapturing = false;
private boolean isTranscribing = false;
private boolean isRealtime = false;

public WhisperContext(int id, ReactApplicationContext reactContext, long context) {
this.id = id;
this.context = context;
this.reactContext = reactContext;
eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT);
}

/**
 * Starts capturing microphone audio and transcribing it in near-realtime.
 *
 * Spawns a recording thread that appends PCM16 samples into buffer16 and,
 * whenever no transcription is in flight and at least half a second of audio
 * has accumulated, spawns a worker thread that runs whisper over everything
 * captured so far and emits "@RNWhisper_onRealtimeTranscribe" events.
 * Capture stops when the buffer reaches maxAudioSec or stopTranscribe() is
 * called; "@RNWhisper_onRealtimeTranscribeEnd" is emitted at that point.
 *
 * @param jobId   id used to tag events and to abort the native job
 * @param options may contain "realtimeAudioSec" (max capture length) plus
 *                the usual whisper options forwarded to full()
 * @return -100 if a capture/transcription is already running, otherwise the
 *         AudioRecord state (STATE_INITIALIZED on success)
 */
public int startRealtimeTranscribe(int jobId, ReadableMap options) {
  if (isCapturing || isTranscribing) {
    return -100;
  }

  recorder = new AudioRecord(AUDIO_SOURCE, SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT, bufferSize);

  int state = recorder.getState();
  if (state != AudioRecord.STATE_INITIALIZED) {
    recorder.release();
    return state;
  }

  int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
  final int maxAudioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;

  // Fix: buffer16 is a short[], so its length is the sample count. The old
  // `* Short.BYTES` factor doubled the allocation (a byte-size computation
  // applied to a short array) and masked the append-offset bug below.
  buffer16 = new short[maxAudioSec * SAMPLE_RATE];

  this.jobId = jobId;
  isCapturing = true;
  isRealtime = true;
  nSamples = 0;

  recorder.startRecording();

  new Thread(new Runnable() {
    @Override
    public void run() {
      try {
        short[] buffer = new short[bufferSize];
        Thread fullHandler = null;
        while (isCapturing) {
          try {
            int n = recorder.read(buffer, 0, bufferSize);
            if (n == 0) continue;
            if (n < 0) {
              // Fix: read() returns negative error codes (e.g.
              // ERROR_INVALID_OPERATION); the old code fell through and
              // added the negative count to nSamples.
              Log.e(NAME, "Error reading audio data: " + n);
              isCapturing = false;
              if (!isTranscribing)
                emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
              break;
            }

            if (nSamples + n > maxAudioSec * SAMPLE_RATE) {
              // Full, ignore data
              isCapturing = false;
              if (!isTranscribing)
                emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
              break;
            }
            // Fix: append at the current write position *before* advancing
            // it. The old code did `nSamples += n` first, which wrote each
            // chunk n samples past its slot and left gaps of zeros in the
            // captured audio.
            for (int i = 0; i < n; i++) {
              buffer16[nSamples + i] = buffer[i];
            }
            nSamples += n;

            if (!isTranscribing && nSamples > SAMPLE_RATE / 2) {
              isTranscribing = true;
              Log.d(NAME, "Start transcribing realtime: " + nSamples);
              fullHandler = new Thread(new Runnable() {
                @Override
                public void run() {
                  if (!isCapturing) return;

                  // Fix: snapshot the sample count once so the recording
                  // thread can keep appending while we transcribe a
                  // consistent prefix (the old code re-read the mutable
                  // nSamples at several points).
                  int sampleCount = nSamples;

                  // convert I16 to F32
                  float[] buffer32 = new float[sampleCount];
                  for (int i = 0; i < sampleCount; i++) {
                    buffer32[i] = buffer16[i] / 32768.0f;
                  }

                  int timeStart = (int) System.currentTimeMillis();
                  int code = full(jobId, options, buffer32, sampleCount);
                  int timeEnd = (int) System.currentTimeMillis();
                  // Fix: multiply (in long) before dividing — the old
                  // `nSamples / SAMPLE_RATE * 1000` truncated the recording
                  // time to whole seconds.
                  int timeRecording = (int) ((long) sampleCount * 1000 / SAMPLE_RATE);

                  WritableMap payload = Arguments.createMap();
                  payload.putBoolean("isCapturing", isCapturing);
                  payload.putInt("code", code);
                  payload.putInt("processTime", timeEnd - timeStart);
                  payload.putInt("recordingTime", timeRecording);

                  if (code == 0) {
                    payload.putMap("data", getTextSegments());
                    emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload);
                  } else {
                    payload.putString("error", "Transcribe failed with code " + code);
                    emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload);
                  }

                  if (!isCapturing) {
                    emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
                  }
                  isTranscribing = false;
                }
              });
              fullHandler.start();
            }
          } catch (Exception e) {
            Log.e(NAME, "Error transcribing realtime: " + e.getMessage());
          }
        }
        if (fullHandler != null) {
          fullHandler.join(); // Wait for full transcribe to finish
        }
        recorder.stop();
      } catch (Exception e) {
        e.printStackTrace();
      } finally {
        recorder.release();
        recorder = null;
      }
    }
  }).start();

  return state;
}

/**
 * Emits a device event to JS, wrapping the payload with this context's id
 * and the current job id so listeners can filter events.
 */
private void emitTranscribeEvent(final String eventName, final WritableMap payload) {
  final WritableMap event = Arguments.createMap();
  event.putInt("contextId", this.id);
  event.putInt("jobId", this.jobId);
  event.putMap("payload", payload);
  eventEmitter.emit(eventName, event);
}

public WritableMap transcribe(int jobId, String filePath, ReadableMap options) throws IOException, Exception {
int code = fullTranscribe(
/**
 * Synchronously transcribes a WAV file with the native whisper context.
 *
 * @param jobId    id used to tag/abort the native job
 * @param filePath path of the WAV file to decode and transcribe
 * @param options  whisper options forwarded to full()
 * @return the text segments produced by the transcription
 * @throws Exception when decoding fails or the native call returns non-zero
 */
public WritableMap transcribeFile(int jobId, String filePath, ReadableMap options) throws IOException, Exception {
  this.jobId = jobId;
  isTranscribing = true;
  try {
    float[] audioData = decodeWaveFile(new File(filePath));
    int code = full(jobId, options, audioData, audioData.length);
    if (code != 0) {
      throw new Exception("Failed to transcribe the file. Code: " + code);
    }
    return getTextSegments();
  } finally {
    // Fix: always reset state. The old code only cleared these on the
    // success/non-zero-code paths, so an exception from decodeWaveFile()
    // or full() left isTranscribing=true and blocked all future jobs.
    isTranscribing = false;
    this.jobId = -1;
  }
}

private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
return fullTranscribe(
jobId,
context,
decodeWaveFile(new File(filePath)),
// jboolean realtime,
isRealtime,
// float[] audio_data,
audioData,
// jint audio_data_len,
audioDataLen,
// jint n_threads,
options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
// jint max_context,
Expand Down Expand Up @@ -70,9 +233,9 @@ public WritableMap transcribe(int jobId, String filePath, ReadableMap options) t
// jstring prompt
options.hasKey("prompt") ? options.getString("prompt") : null
);
if (code != 0) {
throw new Exception("Transcription failed with code " + code);
}
}

private WritableMap getTextSegments() {
Integer count = getTextSegmentCount(context);
StringBuilder builder = new StringBuilder();

Expand All @@ -93,7 +256,27 @@ public WritableMap transcribe(int jobId, String filePath, ReadableMap options) t
return data;
}


// Whether the realtime capture loop is currently recording audio.
public boolean isCapturing() {
  return isCapturing;
}

// Whether a transcription (file or realtime chunk) is currently in flight.
public boolean isTranscribing() {
  return isTranscribing;
}

// Aborts the native job with the given id and clears both state flags.
// Clearing isCapturing also makes the realtime recording loop exit.
public void stopTranscribe(int jobId) {
  abortTranscribe(jobId);
  isCapturing = false;
  isTranscribing = false;
}

// Stops whichever job this context last started (tracked in this.jobId).
public void stopCurrentTranscribe() {
  stopTranscribe(this.jobId);
}

// Tears the context down: stops any in-flight job, then frees the native
// whisper context handle. The context must not be used after this call.
public void release() {
  stopCurrentTranscribe();
  freeContext(context);
}

Expand Down Expand Up @@ -188,7 +371,9 @@ private static String cpuInfo() {
protected static native int fullTranscribe(
int job_id,
long context,
boolean realtime,
float[] audio_data,
int audio_data_len,
int n_threads,
int max_context,
int word_thold,
Expand Down
Loading

0 comments on commit f63c49f

Please sign in to comment.