diff --git a/README.md b/README.md
index 7ec8a31..b27f521 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,23 @@ npm install whisper.rn
 
 Then re-run `npx pod-install` again for iOS.
 
+## Add Microphone Permissions (Optional)
+
+If you want to use realtime transcription, you need to add the microphone permission to your app.
+
+### iOS
+Add these lines to `ios/[YOUR_APP_NAME]/Info.plist`:
+```xml
+<key>NSMicrophoneUsageDescription</key>
+<string>This app requires microphone access in order to transcribe speech</string>
+```
+
+### Android
+Add the following line to `android/app/src/main/AndroidManifest.xml`:
+```xml
+<uses-permission android:name="android.permission.RECORD_AUDIO" />
+```
+
 ## Usage
 
 ```js
@@ -30,13 +47,35 @@
 const sampleFilePath = 'file://.../sample.wav'
 const whisperContext = await initWhisper({ filePath })
 
-const { result } = await whisperContext.transcribe(sampleFilePath, {
-  language: 'en',
-  // More options
-})
+const options = { language: 'en' }
+const { stop, promise } = whisperContext.transcribe(sampleFilePath, options)
+
+const { result } = await promise
 // result: (The inference text result from audio file)
 ```
 
+Use realtime transcription:
+
+```js
+const { stop, subscribe } = whisperContext.transcribeRealtime(options)
+
+subscribe(evt => {
+  const { isCapturing, data, processTime, recordingTime } = evt
+  console.log(
+    `Realtime transcribing: ${isCapturing ? 'ON' : 'OFF'}\n` +
+      // The inference text result from the audio recording:
+      `Result: ${data.result}\n\n` +
+      `Process time: ${processTime}ms\n` +
+      `Recording time: ${recordingTime}ms`,
+  )
+  if (!isCapturing) console.log('Finished realtime transcribing')
+})
+```
+
+On Android, you may need to request the microphone permission with [`PermissionsAndroid`](https://reactnative.dev/docs/permissionsandroid) (see the example app for a sample request).
+
+The documentation is not ready yet; please see the comments in the [index](./src/index.tsx) file for more details for now.
+
 ## Run with example
 
 The example app is using [react-native-fs](https://github.com/itinance/react-native-fs) to download the model file and audio file.
diff --git a/android/src/main/java/com/rnwhisper/RNWhisperModule.java b/android/src/main/java/com/rnwhisper/RNWhisperModule.java
index 7eebdd7..0f10f69 100644
--- a/android/src/main/java/com/rnwhisper/RNWhisperModule.java
+++ b/android/src/main/java/com/rnwhisper/RNWhisperModule.java
@@ -5,6 +5,7 @@
 import android.os.Build;
 import android.os.Handler;
 import android.os.AsyncTask;
+import android.media.AudioRecord;
 
 import com.facebook.react.bridge.Promise;
 import com.facebook.react.bridge.ReactApplicationContext;
@@ -51,7 +52,7 @@ protected Integer doInBackground(Void...
voids) { throw new Exception("Failed to initialize context"); } int id = Math.abs(new Random().nextInt()); - WhisperContext whisperContext = new WhisperContext(context); + WhisperContext whisperContext = new WhisperContext(id, reactContext, context); contexts.put(id, whisperContext); return id; } catch (Exception e) { @@ -72,18 +73,27 @@ protected void onPostExecute(Integer id) { } @ReactMethod - public void transcribe(int id, int jobId, String filePath, ReadableMap options, Promise promise) { + public void transcribeFile(int id, int jobId, String filePath, ReadableMap options, Promise promise) { + final WhisperContext context = contexts.get(id); + if (context == null) { + promise.reject("Context not found"); + return; + } + if (context.isCapturing()) { + promise.reject("The context is in realtime transcribe mode"); + return; + } + if (context.isTranscribing()) { + promise.reject("Context is already transcribing"); + return; + } new AsyncTask() { private Exception exception; @Override protected WritableMap doInBackground(Void... voids) { try { - WhisperContext context = contexts.get(id); - if (context == null) { - throw new Exception("Context " + id + " not found"); - } - return context.transcribe(jobId, filePath, options); + return context.transcribeFile(jobId, filePath, options); } catch (Exception e) { exception = e; return null; @@ -102,8 +112,32 @@ protected void onPostExecute(WritableMap data) { } @ReactMethod - public void abortTranscribe(int jobId) { - WhisperContext.abortTranscribe(jobId); + public void startRealtimeTranscribe(int id, int jobId, ReadableMap options, Promise promise) { + final WhisperContext context = contexts.get(id); + if (context == null) { + promise.reject("Context not found"); + return; + } + if (context.isCapturing()) { + promise.reject("Context is already in capturing"); + return; + } + int state = context.startRealtimeTranscribe(jobId, options); + if (state == AudioRecord.STATE_INITIALIZED) { + promise.resolve(null); + return; + } + promise.reject("Failed to start realtime transcribe. 
State: " + state); + } + + @ReactMethod + public void abortTranscribe(int contextId, int jobId, Promise promise) { + WhisperContext context = contexts.get(contextId); + if (context == null) { + promise.reject("Context not found"); + return; + } + context.stopTranscribe(jobId); } @ReactMethod diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index 8c72983..0a9d2e6 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -4,10 +4,15 @@ import com.facebook.react.bridge.WritableArray; import com.facebook.react.bridge.WritableMap; import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.bridge.ReactApplicationContext; +import com.facebook.react.modules.core.DeviceEventManagerModule; import android.util.Log; import android.os.Build; import android.content.res.AssetManager; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder.AudioSource; import java.util.Random; import java.lang.StringBuilder; @@ -26,17 +31,175 @@ public class WhisperContext { public static final String NAME = "RNWhisperContext"; + + private static final int SAMPLE_RATE = 16000; + private static final int CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO; + private static final int AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT; + private static final int AUDIO_SOURCE = AudioSource.VOICE_RECOGNITION; + private static final int DEFAULT_MAX_AUDIO_SEC = 30; + + private int id; + private ReactApplicationContext reactContext; private long context; - public WhisperContext(long context) { + private DeviceEventManagerModule.RCTDeviceEventEmitter eventEmitter; + + private int jobId = -1; + private AudioRecord recorder = null; + private int bufferSize; + private short[] buffer16; + private int nSamples = 0; + private boolean isCapturing = false; + private boolean isTranscribing = false; + private boolean isRealtime = false; + + public WhisperContext(int id, ReactApplicationContext reactContext, long context) { + this.id = id; this.context = context; + this.reactContext = reactContext; + eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class); + bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT); + } + + public int startRealtimeTranscribe(int jobId, ReadableMap options) { + if (isCapturing || isTranscribing) { + return -100; + } + + recorder = new AudioRecord(AUDIO_SOURCE, SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT, bufferSize); + + int state = recorder.getState(); + if (state != AudioRecord.STATE_INITIALIZED) { + recorder.release(); + return state; + } + + int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; + final int maxAudioSec = realtimeAudioSec > 0 ? 
realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; + + buffer16 = new short[maxAudioSec * SAMPLE_RATE * Short.BYTES]; + + this.jobId = jobId; + isCapturing = true; + isRealtime = true; + nSamples = 0; + + recorder.startRecording(); + + new Thread(new Runnable() { + @Override + public void run() { + try { + short[] buffer = new short[bufferSize]; + Thread fullHandler = null; + while (isCapturing) { + try { + int n = recorder.read(buffer, 0, bufferSize); + if (n == 0) continue; + + if (nSamples + n > maxAudioSec * SAMPLE_RATE) { + // Full, ignore data + isCapturing = false; + if (!isTranscribing) + emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); + break; + } + nSamples += n; + for (int i = 0; i < n; i++) { + buffer16[nSamples + i] = buffer[i]; + } + if (!isTranscribing && nSamples > SAMPLE_RATE / 2) { + isTranscribing = true; + Log.d(NAME, "Start transcribing realtime: " + nSamples); + fullHandler = new Thread(new Runnable() { + @Override + public void run() { + if (!isCapturing) return; + + // convert I16 to F32 + float[] nSamplesBuffer32 = new float[nSamples]; + for (int i = 0; i < nSamples; i++) { + nSamplesBuffer32[i] = buffer16[i] / 32768.0f; + } + + int timeStart = (int) System.currentTimeMillis(); + int code = full(jobId, options, nSamplesBuffer32, nSamples); + int timeEnd = (int) System.currentTimeMillis(); + int timeRecording = (int) (nSamples / SAMPLE_RATE * 1000); + + WritableMap payload = Arguments.createMap(); + payload.putBoolean("isCapturing", isCapturing); + payload.putInt("code", code); + payload.putInt("processTime", timeEnd - timeStart); + payload.putInt("recordingTime", timeRecording); + + if (code == 0) { + payload.putMap("data", getTextSegments()); + emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload); + } else { + payload.putString("error", "Transcribe failed with code " + code); + emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload); + } + + if (!isCapturing) { + emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap()); + } + isTranscribing = false; + } + }); + fullHandler.start(); + } + } catch (Exception e) { + Log.e(NAME, "Error transcribing realtime: " + e.getMessage()); + } + } + if (fullHandler != null) { + fullHandler.join(); // Wait for full transcribe to finish + } + recorder.stop(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + recorder.release(); + recorder = null; + } + } + }).start(); + + return state; + } + + private void emitTranscribeEvent(final String eventName, final WritableMap payload) { + WritableMap event = Arguments.createMap(); + event.putInt("contextId", WhisperContext.this.id); + event.putInt("jobId", jobId); + event.putMap("payload", payload); + eventEmitter.emit(eventName, event); } - public WritableMap transcribe(int jobId, String filePath, ReadableMap options) throws IOException, Exception { - int code = fullTranscribe( + public WritableMap transcribeFile(int jobId, String filePath, ReadableMap options) throws IOException, Exception { + this.jobId = jobId; + isTranscribing = true; + float[] audioData = decodeWaveFile(new File(filePath)); + int code = full(jobId, options, audioData, audioData.length); + isTranscribing = false; + this.jobId = -1; + if (code != 0) { + throw new Exception("Failed to transcribe the file. 
Code: " + code); + } + return getTextSegments(); + } + + private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) { + return fullTranscribe( jobId, context, - decodeWaveFile(new File(filePath)), + // jboolean realtime, + isRealtime, + // float[] audio_data, + audioData, + // jint audio_data_len, + audioDataLen, // jint n_threads, options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1, // jint max_context, @@ -70,9 +233,9 @@ public WritableMap transcribe(int jobId, String filePath, ReadableMap options) t // jstring prompt options.hasKey("prompt") ? options.getString("prompt") : null ); - if (code != 0) { - throw new Exception("Transcription failed with code " + code); - } + } + + private WritableMap getTextSegments() { Integer count = getTextSegmentCount(context); StringBuilder builder = new StringBuilder(); @@ -93,7 +256,27 @@ public WritableMap transcribe(int jobId, String filePath, ReadableMap options) t return data; } + + public boolean isCapturing() { + return isCapturing; + } + + public boolean isTranscribing() { + return isTranscribing; + } + + public void stopTranscribe(int jobId) { + abortTranscribe(jobId); + isCapturing = false; + isTranscribing = false; + } + + public void stopCurrentTranscribe() { + stopTranscribe(this.jobId); + } + public void release() { + stopCurrentTranscribe(); freeContext(context); } @@ -188,7 +371,9 @@ private static String cpuInfo() { protected static native int fullTranscribe( int job_id, long context, + boolean realtime, float[] audio_data, + int audio_data_len, int n_threads, int max_context, int word_thold, diff --git a/android/src/main/jni/whisper/jni.cpp b/android/src/main/jni/whisper/jni.cpp index c494f4f..ae17dde 100644 --- a/android/src/main/jni/whisper/jni.cpp +++ b/android/src/main/jni/whisper/jni.cpp @@ -38,7 +38,9 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( jobject thiz, jint job_id, jlong context_ptr, + jboolean realtime, jfloatArray audio_data, + jint audio_data_len, jint n_threads, jint max_context, int word_thold, @@ -58,7 +60,6 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( UNUSED(thiz); struct whisper_context *context = reinterpret_cast(context_ptr); jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr); - const jsize audio_data_length = env->GetArrayLength(audio_data); int max_threads = min(4, get_nprocs()); @@ -82,7 +83,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.speed_up = speed_up; params.offset_ms = 0; params.no_context = true; - params.single_segment = false; + params.single_segment = realtime; if (max_len > -1) { params.max_len = max_len; @@ -128,7 +129,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( whisper_reset_timings(context); LOGI("About to run whisper_full"); - int code = whisper_full(context, params, audio_data_arr, audio_data_length); + int code = whisper_full(context, params, audio_data_arr, audio_data_len); if (code == 0) { // whisper_print_timings(context); } diff --git a/example/android/app/src/main/AndroidManifest.xml b/example/android/app/src/main/AndroidManifest.xml index 4122f36..bdbed10 100644 --- a/example/android/app/src/main/AndroidManifest.xml +++ b/example/android/app/src/main/AndroidManifest.xml @@ -1,6 +1,7 @@ + UIViewControllerBasedStatusBarAppearance + NSMicrophoneUsageDescription + This app requires microphone access in order to transcribe speech diff --git a/example/src/App.js b/example/src/App.js index 7393794..9e160cd 100644 --- a/example/src/App.js +++ b/example/src/App.js @@ -6,11 +6,24 @@ import { 
Text, TouchableOpacity, SafeAreaView, + Platform, + PermissionsAndroid, } from 'react-native' import RNFS from 'react-native-fs' // eslint-disable-next-line import/no-unresolved import { initWhisper } from 'whisper.rn' +if (Platform.OS === 'android') { + // Request record audio permission + PermissionsAndroid.request(PermissionsAndroid.PERMISSIONS.RECORD_AUDIO, { + title: 'Whisper Audio Permission', + message: 'Whisper needs access to your microphone', + buttonNeutral: 'Ask Me Later', + buttonNegative: 'Cancel', + buttonPositive: 'OK', + }) +} + const styles = StyleSheet.create({ container: { flex: 1 }, content: { @@ -73,6 +86,8 @@ const filterPath = (path) => export default function App() { const [whisperContext, setWhisperContext] = useState(null) const [logs, setLogs] = useState([]) + const [transcibeResult, setTranscibeResult] = useState(null) + const [stopTranscribe, setStopTranscribe] = useState(null) const log = useCallback((...messages) => { setLogs((prev) => [...prev, messages.join(' ')]) @@ -136,6 +151,7 @@ export default function App() { { if (!whisperContext) { log('No context') @@ -163,41 +179,77 @@ export default function App() { const startTime = Date.now() const { // stop, - promise - } = await whisperContext.transcribe( - sampleFilePath, - { - language: 'en', - maxLen: 1, - tokenTimestamps: true, - }, - ) + promise, + } = whisperContext.transcribe(sampleFilePath, { + language: 'en', + maxLen: 1, + tokenTimestamps: true, + }) const { result, segments } = await promise const endTime = Date.now() - log('Transcribed result:', result) - log('Transcribed in', endTime - startTime, `ms in ${mode} mode`) - log('Segments:') - segments.forEach((segment) => { - log( - `[${toTimestamp(segment.t0)} --> ${toTimestamp( - segment.t1, - )}] ${segment.text}`, - ) - }) + setTranscibeResult( + `Transcribed result: ${result}\n` + + `Transcribed in ${endTime - startTime}ms in ${mode} mode` + + `\n` + + `Segments:` + + `\n${segments + .map( + (segment) => + `[${toTimestamp(segment.t0)} --> ${toTimestamp( + segment.t1, + )}] ${segment.text}`, + ) + .join('\n')}`, + ) + log('Finished transcribing') }} > - Transcribe + Transcribe File { - if (!whisperContext) return - await whisperContext.release() - setWhisperContext(null) - log('Released context') + if (!whisperContext) { + log('No context') + return + } + if (stopTranscribe?.stop) { + stopTranscribe?.stop() + setStopTranscribe(null) + return + } + log('Start realtime transcribing...') + try { + const { stop, subscribe } = + await whisperContext.transcribeRealtime({ + language: 'en', + realtimeAudioSec: 10, + }) + setStopTranscribe({ stop }) + subscribe((evt) => { + const { isCapturing, data, processTime, recordingTime } = evt + setTranscibeResult( + `Realtime transcribing: ${isCapturing ? 'ON' : 'OFF'}\n` + + `Result: ${data.result}\n\n` + + `Process time: ${processTime}ms\n` + + `Recording time: ${recordingTime}ms`, + ) + if (!isCapturing) { + setStopTranscribe(null) + log('Finished realtime transcribing') + } + }) + } catch (e) { + log('Error:', e) + } }} > - Release Context + + {stopTranscribe?.stop ? 
'Stop' : 'Realtime'} + @@ -207,9 +259,29 @@ export default function App() { ))} + {transcibeResult && ( + + {transcibeResult} + + )} + setLogs([])} + onPress={async () => { + if (!whisperContext) return + await whisperContext.release() + setWhisperContext(null) + log('Released context') + }} + > + Release Context + + { + setLogs([]) + setTranscibeResult('') + }} > Clear Logs diff --git a/ios/RNWhisper.h b/ios/RNWhisper.h index 68e4a5d..a134c5f 100644 --- a/ios/RNWhisper.h +++ b/ios/RNWhisper.h @@ -3,9 +3,9 @@ #import "rn-whisper.h" #endif - #import +#import -@interface RNWhisper : NSObject +@interface RNWhisper : RCTEventEmitter @end diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index e15b898..b2a8f39 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -1,23 +1,8 @@ - #import "RNWhisper.h" +#import "RNWhisperContext.h" #include #include -@interface WhisperContext : NSObject { -} - -@property struct whisper_context * ctx; - -@end - -@implementation WhisperContext - -- (void)invalidate { - whisper_free(self.ctx); -} - -@end - @implementation RNWhisper NSMutableDictionary *contexts; @@ -33,10 +18,8 @@ @implementation RNWhisper contexts = [[NSMutableDictionary alloc] init]; } - WhisperContext *context = [[WhisperContext alloc] init]; - context.ctx = whisper_init_from_file([modelPath UTF8String]); - - if (context.ctx == NULL) { + RNWhisperContext *context = [RNWhisperContext initWithModelPath:modelPath]; + if ([context getContext] == NULL) { reject(@"whisper_cpp_error", @"Failed to load the model", nil); return; } @@ -47,138 +30,105 @@ @implementation RNWhisper resolve([NSNumber numberWithInt:contextId]); } -RCT_REMAP_METHOD(transcribe, +RCT_REMAP_METHOD(transcribeFile, withContextId:(int)contextId - withJobId:(int)job_id + withJobId:(int)jobId withWaveFile:(NSString *)waveFilePath withOptions:(NSDictionary *)options withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) { - WhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; + RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; if (context == nil) { reject(@"whisper_error", @"Context not found", nil); return; } + if ([context isCapturing]) { + reject(@"whisper_error", @"The context is in realtime transcribe mode", nil); + return; + } + if ([context isTranscribing]) { + reject(@"whisper_error", @"Context is already transcribing", nil); + return; + } NSURL *url = [NSURL fileURLWithPath:waveFilePath]; int count = 0; float *waveFile = [self decodeWaveFile:url count:&count]; - if (waveFile == nil) { reject(@"whisper_error", @"Invalid file", nil); return; } - - struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - - const int max_threads = options[@"maxThreads"] != nil ? - [options[@"maxThreads"] intValue] : - MIN(4, (int)[[NSProcessInfo processInfo] processorCount]); - - if (options[@"beamSize"] != nil) { - params.strategy = WHISPER_SAMPLING_BEAM_SEARCH; - params.beam_search.beam_size = [options[@"beamSize"] intValue]; - } - - params.print_realtime = false; - params.print_progress = false; - params.print_timestamps = false; - params.print_special = false; - params.speed_up = options[@"speedUp"] != nil ? [options[@"speedUp"] boolValue] : false; - params.translate = options[@"translate"] != nil ? [options[@"translate"] boolValue] : false; - params.language = options[@"language"] != nil ? 
[options[@"language"] UTF8String] : "auto"; - params.n_threads = max_threads; - params.offset_ms = 0; - params.no_context = true; - params.single_segment = false; - - if (options[@"maxLen"] != nil) { - params.max_len = [options[@"maxLen"] intValue]; - } - params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false; - - if (options[@"bestOf"] != nil) { - params.greedy.best_of = [options[@"bestOf"] intValue]; - } - if (options[@"maxContext"] != nil) { - params.n_max_text_ctx = [options[@"maxContext"] intValue]; - } - - if (options[@"offset"] != nil) { - params.offset_ms = [options[@"offset"] intValue]; - } - if (options[@"duration"] != nil) { - params.duration_ms = [options[@"duration"] intValue]; - } - if (options[@"wordThold"] != nil) { - params.thold_pt = [options[@"wordThold"] intValue]; - } - if (options[@"temperature"] != nil) { - params.temperature = [options[@"temperature"] floatValue]; - } - if (options[@"temperatureInc"] != nil) { - params.temperature_inc = [options[@"temperature_inc"] floatValue]; - } - - if (options[@"prompt"] != nil) { - std::string *prompt = new std::string([options[@"prompt"] UTF8String]); - rn_whisper_convert_prompt( - context.ctx, - params, - prompt - ); - } - - params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; - }; - params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id); - - whisper_reset_timings(context.ctx); - int code = whisper_full(context.ctx, params, waveFile, count); + int code = [context transcribeFile:jobId audioData:waveFile audioDataCount:count options:options]; if (code != 0) { - NSLog(@"Failed to run the model"); free(waveFile); - reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to run the model. Code: %d", code], nil); + reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. 
Code: %d", code], nil); return; } - - // whisper_print_timings(context.ctx); free(waveFile); + resolve([context getTextSegments]); +} - rn_whisper_remove_abort_map(job_id); +- (NSArray *)supportedEvents { + return@[ + @"@RNWhisper_onRealtimeTranscribe", + @"@RNWhisper_onRealtimeTranscribeEnd", + ]; +} - NSString *result = @""; - int n_segments = whisper_full_n_segments(context.ctx); +RCT_REMAP_METHOD(startRealtimeTranscribe, + withContextId:(int)contextId + withJobId:(int)jobId + withOptions:(NSDictionary *)options + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; - NSMutableArray *segments = [[NSMutableArray alloc] init]; - for (int i = 0; i < n_segments; i++) { - const char * text_cur = whisper_full_get_segment_text(context.ctx, i); - result = [result stringByAppendingString:[NSString stringWithUTF8String:text_cur]]; + if (context == nil) { + reject(@"whisper_error", @"Context not found", nil); + return; + } + if ([context isCapturing]) { + reject(@"whisper_error", @"The context is already capturing", nil); + return; + } - const int64_t t0 = whisper_full_get_segment_t0(context.ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(context.ctx, i); - NSDictionary *segment = @{ - @"text": [NSString stringWithUTF8String:text_cur], - @"t0": [NSNumber numberWithLongLong:t0], - @"t1": [NSNumber numberWithLongLong:t1] - }; - [segments addObject:segment]; + OSStatus status = [context transcribeRealtime:jobId + options:options + onTranscribe:^(int _jobId, NSString *type, NSDictionary *payload) { + NSString *eventName = nil; + if ([type isEqual:@"transcribe"]) { + eventName = @"@RNWhisper_onRealtimeTranscribe"; + } else if ([type isEqual:@"end"]) { + eventName = @"@RNWhisper_onRealtimeTranscribeEnd"; + } + if (eventName == nil) { + return; + } + [self sendEventWithName:eventName + body:@{ + @"contextId": [NSNumber numberWithInt:contextId], + @"jobId": [NSNumber numberWithInt:jobId], + @"payload": payload + } + ]; + } + ]; + if (status == 0) { + resolve(nil); + return; } - resolve(@{ - @"result": result, - @"segments": segments - }); + reject(@"whisper_error", [NSString stringWithFormat:@"Failed to start realtime transcribe. 
Status: %d", status], nil); } - RCT_REMAP_METHOD(abortTranscribe, - withJobId:(int)job_id) + withContextId:(int)contextId + withJobId:(int)jobId) { - rn_whisper_abort_transcribe(job_id); + RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; + [context stopTranscribe:jobId]; } RCT_REMAP_METHOD(releaseContext, @@ -186,7 +136,7 @@ @implementation RNWhisper withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) { - WhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; + RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]]; if (context == nil) { reject(@"whisper_error", @"Context not found", nil); return; @@ -232,7 +182,7 @@ - (void)invalidate { } for (NSNumber *contextId in contexts) { - WhisperContext *context = contexts[contextId]; + RNWhisperContext *context = contexts[contextId]; [context invalidate]; } diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h new file mode 100644 index 0000000..6faa608 --- /dev/null +++ b/ios/RNWhisperContext.h @@ -0,0 +1,53 @@ +#ifdef __cplusplus +#import "whisper.h" +#import "rn-whisper.h" +#endif + +#import +#import + +#define NUM_BUFFERS 3 +#define DEFAULT_MAX_AUDIO_SEC 30 + +typedef struct { + __unsafe_unretained id mSelf; + + int jobId; + NSDictionary* options; + + bool isTranscribing; + bool isRealtime; + bool isCapturing; + int maxAudioSec; + int nSamples; + int16_t* audioBufferI16; + float* audioBufferF32; + + AudioQueueRef queue; + AudioStreamBasicDescription dataFormat; + AudioQueueBufferRef buffers[NUM_BUFFERS]; + + void (^transcribeHandler)(int, NSString *, NSDictionary *); +} RNWhisperContextRecordState; + +@interface RNWhisperContext : NSObject { + struct whisper_context * ctx; + RNWhisperContextRecordState recordState; +} + ++ (instancetype)initWithModelPath:(NSString *)modelPath; +- (struct whisper_context *)getContext; +- (OSStatus)transcribeRealtime:(int)jobId + options:(NSDictionary *)options + onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe; +- (int)transcribeFile:(int)jobId + audioData:(float *)audioData + audioDataCount:(int)audioDataCount + options:(NSDictionary *)options; +- (void)stopTranscribe:(int)jobId; +- (bool)isCapturing; +- (bool)isTranscribing; +- (NSDictionary *)getTextSegments; +- (void)invalidate; + +@end diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm new file mode 100644 index 0000000..3adcfe4 --- /dev/null +++ b/ios/RNWhisperContext.mm @@ -0,0 +1,303 @@ +#import "RNWhisperContext.h" + +#define NUM_BYTES_PER_BUFFER 16 * 1024 + +@implementation RNWhisperContext + ++ (instancetype)initWithModelPath:(NSString *)modelPath { + RNWhisperContext *context = [[RNWhisperContext alloc] init]; + context->ctx = whisper_init_from_file([modelPath UTF8String]); + return context; +} + +- (struct whisper_context *)getContext { + return self->ctx; +} + +- (void)prepareRealtime:(NSDictionary *)options { + self->recordState.options = options; + + self->recordState.dataFormat.mSampleRate = WHISPER_SAMPLE_RATE; // 16000 + self->recordState.dataFormat.mFormatID = kAudioFormatLinearPCM; + self->recordState.dataFormat.mFramesPerPacket = 1; + self->recordState.dataFormat.mChannelsPerFrame = 1; // mono + self->recordState.dataFormat.mBytesPerFrame = 2; + self->recordState.dataFormat.mBytesPerPacket = 2; + self->recordState.dataFormat.mBitsPerChannel = 16; + self->recordState.dataFormat.mReserved = 0; + self->recordState.dataFormat.mFormatFlags = kLinearPCMFormatFlagIsSignedInteger; + + self->recordState.nSamples = 
0; + + int maxAudioSecOpt = options[@"realtimeAudioSec"] != nil ? [options[@"realtimeAudioSec"] intValue] : 0; + int maxAudioSec = maxAudioSecOpt > 0 ? maxAudioSecOpt : DEFAULT_MAX_AUDIO_SEC; + self->recordState.maxAudioSec = maxAudioSec; + self->recordState.audioBufferI16 = (int16_t *) malloc(maxAudioSec * WHISPER_SAMPLE_RATE * sizeof(int16_t)); + self->recordState.audioBufferF32 = (float *) malloc(maxAudioSec * WHISPER_SAMPLE_RATE * sizeof(float)); + + self->recordState.isRealtime = true; + self->recordState.isTranscribing = false; + self->recordState.isCapturing = false; + + self->recordState.mSelf = self; +} + +void AudioInputCallback(void * inUserData, + AudioQueueRef inAQ, + AudioQueueBufferRef inBuffer, + const AudioTimeStamp * inStartTime, + UInt32 inNumberPacketDescriptions, + const AudioStreamPacketDescription * inPacketDescs) +{ + RNWhisperContextRecordState *state = (RNWhisperContextRecordState *)inUserData; + + if (!state->isCapturing) { + NSLog(@"[RNWhisper] Not capturing, ignoring audio"); + return; + } + + const int n = inBuffer->mAudioDataByteSize / 2; + NSLog(@"[RNWhisper] Captured %d new samples", n); + + if (state->nSamples + n > state->maxAudioSec * WHISPER_SAMPLE_RATE) { + NSLog(@"[RNWhisper] Audio buffer is full, ignoring audio"); + state->isCapturing = false; + if (!state->isTranscribing) { + state->transcribeHandler(state->jobId, @"end", @{}); + } + [state->mSelf stopAudio]; + return; + } + + for (int i = 0; i < n; i++) { + state->audioBufferI16[state->nSamples + i] = ((short*)inBuffer->mAudioData)[i]; + } + state->nSamples += n; + + AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL); + + if (!state->isTranscribing) { + state->isTranscribing = true; + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + NSLog(@"[RNWhisper] Transcribing %d samples", state->nSamples); + // convert I16 to F32 + for (int i = 0; i < state->nSamples; i++) { + state->audioBufferF32[i] = (float)state->audioBufferI16[i] / 32768.0f; + } + CFTimeInterval timeStart = CACurrentMediaTime(); + + int code = [state->mSelf fullTranscribe:state->jobId audioData:state->audioBufferF32 audioDataCount:state->nSamples options:state->options]; + + CFTimeInterval timeEnd = CACurrentMediaTime(); + const float timeRecording = (float) state->nSamples / (float) state->dataFormat.mSampleRate; + if (code == 0) { + state->transcribeHandler(state->jobId, @"transcribe", @{ + @"isCapturing": @(state->isCapturing), + @"code": [NSNumber numberWithInt:code], + @"data": [state->mSelf getTextSegments], + @"processTime": [NSNumber numberWithInt:(timeEnd - timeStart) * 1E3], + @"recordingTime": [NSNumber numberWithInt:timeRecording * 1E3], + }); + state->isTranscribing = false; + return; + } + state->transcribeHandler(state->jobId, @"transcribe", @{ + @"isCapturing": @(state->isCapturing), + @"code": [NSNumber numberWithInt:code], + @"error": [NSString stringWithFormat:@"Transcribe failed with code %d", code], + @"processTime": [NSNumber numberWithDouble:timeEnd - timeStart], + @"recordingTime": [NSNumber numberWithFloat:timeRecording], + }); + if (!state->isCapturing) { + NSLog(@"[RNWhisper] Transcribe end"); + state->transcribeHandler(state->jobId, @"end", @{}); + } + state->isTranscribing = false; + }); + } +} + +- (bool)isCapturing { + return self->recordState.isCapturing; +} + +- (bool)isTranscribing { + return self->recordState.isTranscribing; +} + +- (OSStatus)transcribeRealtime:(int)jobId + options:(NSDictionary *)options + onTranscribe:(void (^)(int, NSString *, NSDictionary 
*))onTranscribe +{ + self->recordState.transcribeHandler = onTranscribe; + self->recordState.jobId = jobId; + [self prepareRealtime:options]; + self->recordState.nSamples = 0; + + OSStatus status = AudioQueueNewInput( + &self->recordState.dataFormat, + AudioInputCallback, + &self->recordState, + NULL, + kCFRunLoopCommonModes, + 0, + &self->recordState.queue + ); + + if (status == 0) { + for (int i = 0; i < NUM_BUFFERS; i++) { + AudioQueueAllocateBuffer(self->recordState.queue, NUM_BYTES_PER_BUFFER, &self->recordState.buffers[i]); + AudioQueueEnqueueBuffer(self->recordState.queue, self->recordState.buffers[i], 0, NULL); + } + status = AudioQueueStart(self->recordState.queue, NULL); + if (status == 0) { + self->recordState.isCapturing = true; + } + } + return status; +} + +- (int)transcribeFile:(int)jobId + audioData:(float *)audioData + audioDataCount:(int)audioDataCount + options:(NSDictionary *)options +{ + self->recordState.isTranscribing = true; + self->recordState.jobId = jobId; + int code = [self fullTranscribe:jobId audioData:audioData audioDataCount:audioDataCount options:options]; + self->recordState.jobId = -1; + self->recordState.isTranscribing = false; + return code; +} + +- (void)stopAudio { + AudioQueueStop(self->recordState.queue, true); + for (int i = 0; i < NUM_BUFFERS; i++) { + AudioQueueFreeBuffer(self->recordState.queue, self->recordState.buffers[i]); + } + AudioQueueDispose(self->recordState.queue, true); +} + +- (void)stopTranscribe:(int)jobId { + rn_whisper_abort_transcribe(jobId); + if (!self->recordState.isRealtime || !self->recordState.isCapturing) { + return; + } + self->recordState.isCapturing = false; + [self stopAudio]; +} + +- (void)stopCurrentTranscribe { + if (!self->recordState.jobId) { + return; + } + [self stopTranscribe:self->recordState.jobId]; +} + +- (int)fullTranscribe:(int)jobId audioData:(float *)audioData audioDataCount:(int)audioDataCount options:(NSDictionary *)options { + struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + const int max_threads = options[@"maxThreads"] != nil ? + [options[@"maxThreads"] intValue] : + MIN(4, (int)[[NSProcessInfo processInfo] processorCount]); + + if (options[@"beamSize"] != nil) { + params.strategy = WHISPER_SAMPLING_BEAM_SEARCH; + params.beam_search.beam_size = [options[@"beamSize"] intValue]; + } + + params.print_realtime = false; + params.print_progress = false; + params.print_timestamps = false; + params.print_special = false; + params.speed_up = options[@"speedUp"] != nil ? [options[@"speedUp"] boolValue] : false; + params.translate = options[@"translate"] != nil ? [options[@"translate"] boolValue] : false; + params.language = options[@"language"] != nil ? [options[@"language"] UTF8String] : "auto"; + params.n_threads = max_threads; + params.offset_ms = 0; + params.no_context = true; + params.single_segment = self->recordState.isRealtime; + + if (options[@"maxLen"] != nil) { + params.max_len = [options[@"maxLen"] intValue]; + } + params.token_timestamps = options[@"tokenTimestamps"] != nil ? 
[options[@"tokenTimestamps"] boolValue] : false; + + if (options[@"bestOf"] != nil) { + params.greedy.best_of = [options[@"bestOf"] intValue]; + } + if (options[@"maxContext"] != nil) { + params.n_max_text_ctx = [options[@"maxContext"] intValue]; + } + + if (options[@"offset"] != nil) { + params.offset_ms = [options[@"offset"] intValue]; + } + if (options[@"duration"] != nil) { + params.duration_ms = [options[@"duration"] intValue]; + } + if (options[@"wordThold"] != nil) { + params.thold_pt = [options[@"wordThold"] intValue]; + } + if (options[@"temperature"] != nil) { + params.temperature = [options[@"temperature"] floatValue]; + } + if (options[@"temperatureInc"] != nil) { + params.temperature_inc = [options[@"temperature_inc"] floatValue]; + } + + if (options[@"prompt"] != nil) { + std::string *prompt = new std::string([options[@"prompt"] UTF8String]); + rn_whisper_convert_prompt( + self->ctx, + params, + prompt + ); + } + + params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(jobId); + + whisper_reset_timings(self->ctx); + + int code = whisper_full(self->ctx, params, audioData, audioDataCount); + rn_whisper_remove_abort_map(jobId); + // if (code == 0) { + // whisper_print_timings(self->ctx); + // } + return code; +} + +- (NSDictionary *)getTextSegments { + NSString *result = @""; + int n_segments = whisper_full_n_segments(self->ctx); + + NSMutableArray *segments = [[NSMutableArray alloc] init]; + for (int i = 0; i < n_segments; i++) { + const char * text_cur = whisper_full_get_segment_text(self->ctx, i); + result = [result stringByAppendingString:[NSString stringWithUTF8String:text_cur]]; + + const int64_t t0 = whisper_full_get_segment_t0(self->ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(self->ctx, i); + NSDictionary *segment = @{ + @"text": [NSString stringWithUTF8String:text_cur], + @"t0": [NSNumber numberWithLongLong:t0], + @"t1": [NSNumber numberWithLongLong:t1] + }; + [segments addObject:segment]; + } + return @{ + @"result": result, + @"segments": segments + }; +} + +- (void)invalidate { + [self stopCurrentTranscribe]; + whisper_free(self->ctx); +} + +@end diff --git a/jest/mock.js b/jest/mock.js index 9edea9a..2574b3d 100644 --- a/jest/mock.js +++ b/jest/mock.js @@ -1,14 +1,50 @@ -const { NativeModules } = require('react-native') +const { NativeModules, DeviceEventEmitter } = require('react-native') if (!NativeModules.RNWhisper) { NativeModules.RNWhisper = { initContext: jest.fn(() => Promise.resolve(1)), - transcribe: jest.fn(() => Promise.resolve({ + transcribeFile: jest.fn(() => Promise.resolve({ result: ' Test', segments: [{ text: ' Test', t0: 0, t1: 33 }], })), + startRealtimeTranscribe: jest.fn((contextId, jobId) => { + setTimeout(() => { + // Start + DeviceEventEmitter.emit('@RNWhisper_onRealtimeTranscribe', { + contextId, + jobId, + payload: { + isCapturing: true, + data: { + result: ' Test', + segments: [{ text: ' Test', t0: 0, t1: 33 }], + }, + processTime: 100, + recordingTime: 1000, + }, + }) + // End + DeviceEventEmitter.emit('@RNWhisper_onRealtimeTranscribe', { + contextId, + jobId, + payload: { + isCapturing: false, + data: { + result: ' Test', + segments: [{ text: ' Test', t0: 0, t1: 33 }], + }, + processTime: 100, + recordingTime: 2000, + }, + }) + }) + }), releaseContext: jest.fn(() => Promise.resolve()), releaseAllContexts: jest.fn(() => 
Promise.resolve()), + + // For NativeEventEmitter + addListener: jest.fn(), + removeListeners: jest.fn(), } } diff --git a/src/__tests__/index.test.tsx b/src/__tests__/index.test.tsx index cfb4d5f..9104655 100644 --- a/src/__tests__/index.test.tsx +++ b/src/__tests__/index.test.tsx @@ -10,6 +10,45 @@ test('Mock', async () => { result: ' Test', segments: [{ text: ' Test', t0: 0, t1: 33 }], }) + + const { subscribe } = await context.transcribeRealtime() + const events: any[] = [] + subscribe((event) => events.push(event)) + await new Promise((resolve) => setTimeout(resolve, 0)) + expect(events).toMatchObject([ + { + contextId: 1, + data: { + result: ' Test', + segments: [ + { + t0: 0, + t1: 33, + text: ' Test', + }, + ], + }, + isCapturing: true, + processTime: 100, + recordingTime: 1000, + }, + { + contextId: 1, + data: { + result: ' Test', + segments: [ + { + t0: 0, + t1: 33, + text: ' Test', + }, + ], + }, + isCapturing: false, + processTime: 100, + recordingTime: 2000, + }, + ]) await context.release() await releaseAllWhisper() }) diff --git a/src/index.tsx b/src/index.tsx index a4d0bcf..cf42a6c 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,4 +1,10 @@ -import { NativeModules, Platform } from 'react-native' +import { + NativeEventEmitter, + DeviceEventEmitter, + NativeModules, + Platform, + DeviceEventEmitterStatic, +} from 'react-native' const LINKING_ERROR = `The package 'whisper.rn' doesn't seem to be linked. Make sure: \n\n${Platform.select({ ios: "- You have run 'pod install'\n", default: '' }) @@ -15,24 +21,58 @@ const RNWhisper = NativeModules.RNWhisper }, ) +let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic +if (Platform.OS === 'ios') { + EventEmitter = new NativeEventEmitter(RNWhisper) +} +if (Platform.OS === 'android') { + EventEmitter = DeviceEventEmitter +} + +const EVENT_ON_REALTIME_TRANSCRIBE = '@RNWhisper_onRealtimeTranscribe' +const EVENT_ON_REALTIME_TRANSCRIBE_END = '@RNWhisper_onRealtimeTranscribeEnd' + export type TranscribeOptions = { + /** Spoken language (Default: 'auto' for auto-detect) */ language?: string, + /** Translate from source language to english (Default: false) */ translate?: boolean, + /** Number of threads to use during computation (Default: 4) */ maxThreads?: number, + /** Maximum number of text context tokens to store */ maxContext?: number, + /** Maximum segment length in characters */ maxLen?: number, + /** Enable token-level timestamps */ tokenTimestamps?: boolean, + /** Word timestamp probability threshold */ + wordThold?: number, + /** Time offset in milliseconds */ offset?: number, + /** Duration of audio to process in milliseconds */ duration?: number, - wordThold?: number, + /** Tnitial decoding temperature */ temperature?: number, temperatureInc?: number, + /** Beam size for beam search */ beamSize?: number, + /** Number of best candidates to keep */ bestOf?: number, + /** Speed up audio by x2 (reduced accuracy) */ speedUp?: boolean, + /** Initial Prompt */ prompt?: string, } +export type TranscribeRealtimeOptions = TranscribeOptions & { + /** + * Realtime record max duration in seconds. + * Due to the whisper.cpp hard constraint - processes the audio in chunks of 30 seconds, + * the recommended value will be <= 30 seconds. 
(Default: 30) + */ + realtimeAudioSec?: number, +} + export type TranscribeResult = { result: string, segments: Array<{ @@ -42,6 +82,32 @@ export type TranscribeResult = { }>, } +export type TranscribeRealtimeEvent = { + contextId: number, + jobId: number, + /** Is capturing audio, when false, the event is the final result */ + isCapturing: boolean, + code: number, + processTime: number, + recordingTime: number, + data?: TranscribeResult, + error?: string, +} + +export type TranscribeRealtimeNativeEvent = { + contextId: number, + jobId: number, + payload: { + /** Is capturing audio, when false, the event is the final result */ + isCapturing: boolean, + code: number, + processTime: number, + recordingTime: number, + data?: TranscribeResult, + error?: string, + }, +} + class WhisperContext { id: number @@ -49,14 +115,58 @@ class WhisperContext { this.id = id } + /** Transcribe audio file */ transcribe(path: string, options: TranscribeOptions = {}): { + /** Stop the transcribe */ stop: () => void, + /** Transcribe result promise */ promise: Promise, } { const jobId: number = Math.floor(Math.random() * 10000) return { - stop: () => RNWhisper.abortTranscribe(jobId), - promise: RNWhisper.transcribe(this.id, jobId, path, options), + stop: () => RNWhisper.abortTranscribe(this.id, jobId), + promise: RNWhisper.transcribeFile(this.id, jobId, path, options), + } + } + + /** Transcribe the microphone audio stream, the microphone user permission is required */ + async transcribeRealtime(options: TranscribeRealtimeOptions = {}): Promise<{ + /** Stop the realtime transcribe */ + stop: () => void, + /** Subscribe to realtime transcribe events */ + subscribe: (callback: (event: TranscribeRealtimeEvent) => void) => void, + }> { + const jobId: number = Math.floor(Math.random() * 10000) + await RNWhisper.startRealtimeTranscribe(this.id, jobId, options) + let removeTranscribe: () => void + let removeEnd: () => void + let lastTranscribePayload: TranscribeRealtimeNativeEvent['payload'] + return { + stop: () => RNWhisper.abortTranscribe(this.id, jobId), + subscribe: (callback: (event: TranscribeRealtimeEvent) => void) => { + const transcribeListener = EventEmitter.addListener( + EVENT_ON_REALTIME_TRANSCRIBE, + (evt: TranscribeRealtimeNativeEvent) => { + const { contextId, payload } = evt + if (contextId !== this.id || evt.jobId !== jobId) return + lastTranscribePayload = payload + callback({ contextId, jobId: evt.jobId, ...payload }) + if (!payload.isCapturing) removeTranscribe() + } + ) + removeTranscribe = transcribeListener.remove + const endListener = EventEmitter.addListener( + EVENT_ON_REALTIME_TRANSCRIBE_END, + (evt: TranscribeRealtimeNativeEvent) => { + const { contextId } = evt + if (contextId !== this.id || evt.jobId !== jobId) return + callback({ contextId, jobId: evt.jobId, ...lastTranscribePayload, isCapturing: false }) + removeTranscribe?.() + removeEnd() + } + ) + removeEnd = endListener.remove + }, } }
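
For quick reference, here is a minimal consumer-side sketch of the realtime API introduced in this change, combining the `PermissionsAndroid` request from the example app with `transcribeRealtime` (the `transcribeFromMicrophone` helper and its `modelFilePath` argument are illustrative only, not part of the library):

```js
import { Platform, PermissionsAndroid } from 'react-native'
import { initWhisper } from 'whisper.rn'

// Illustrative helper: request the microphone permission on Android, then
// stream microphone audio through the realtime transcriber added above.
async function transcribeFromMicrophone(modelFilePath) {
  if (Platform.OS === 'android') {
    const granted = await PermissionsAndroid.request(
      PermissionsAndroid.PERMISSIONS.RECORD_AUDIO,
    )
    if (granted !== PermissionsAndroid.RESULTS.GRANTED) {
      throw new Error('Microphone permission denied')
    }
  }

  const whisperContext = await initWhisper({ filePath: modelFilePath })

  // stop() aborts the job; subscribe() receives progress events until
  // isCapturing turns false on the final event.
  const { stop, subscribe } = await whisperContext.transcribeRealtime({
    language: 'en',
    realtimeAudioSec: 10,
  })

  subscribe(({ isCapturing, data, error }) => {
    if (error) console.warn('Realtime transcribe error:', error)
    else console.log('Partial result:', data.result)
    if (!isCapturing) console.log('Finished realtime transcribing')
  })

  return { whisperContext, stop }
}
```

The final event arrives with `isCapturing: false` either when `stop()` is called or when the `realtimeAudioSec` window fills up, matching the end-event flow implemented in the native modules above.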