diff --git a/android/src/main/java/com/rnwhisper/RNWhisper.java b/android/src/main/java/com/rnwhisper/RNWhisper.java index 45b2c03..32c90be 100644 --- a/android/src/main/java/com/rnwhisper/RNWhisper.java +++ b/android/src/main/java/com/rnwhisper/RNWhisper.java @@ -21,6 +21,8 @@ import java.io.PushbackInputStream; public class RNWhisper implements LifecycleEventListener { + public static final String NAME = "RNWhisper"; + private ReactApplicationContext reactContext; private Downloader downloader; @@ -40,6 +42,8 @@ public HashMap getTypedExportedConstants() { return constants; } + private HashMap tasks = new HashMap<>(); + private HashMap contexts = new HashMap<>(); private int getResourceIdentifier(String filePath) { @@ -59,7 +63,7 @@ private int getResourceIdentifier(String filePath) { } public void initContext(final ReadableMap options, final Promise promise) { - new AsyncTask() { + AsyncTask task = new AsyncTask() { private Exception exception; @Override @@ -104,8 +108,10 @@ protected void onPostExecute(Integer id) { return; } promise.resolve(id); + tasks.remove(this); } }.execute(); + tasks.put(task, "initContext"); } public void transcribeFile(double id, double jobId, String filePath, ReadableMap options, Promise promise) { @@ -122,7 +128,7 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap promise.reject("Context is already transcribing"); return; } - new AsyncTask() { + AsyncTask task = new AsyncTask() { private Exception exception; @Override @@ -161,8 +167,10 @@ protected void onPostExecute(WritableMap data) { return; } promise.resolve(data); + tasks.remove(this); } }.execute(); + tasks.put(task, "transcribeFile-" + id); } public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) { @@ -183,18 +191,48 @@ public void startRealtimeTranscribe(double id, double jobId, ReadableMap options promise.reject("Failed to start realtime transcribe. State: " + state); } - public void abortTranscribe(double contextId, double jobId, Promise promise) { - WhisperContext context = contexts.get((int) contextId); + public void abortTranscribe(double id, double jobId, Promise promise) { + WhisperContext context = contexts.get((int) id); if (context == null) { promise.reject("Context not found"); return; } - context.stopTranscribe((int) jobId); + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected Void doInBackground(Void... voids) { + try { + context.stopTranscribe((int) jobId); + AsyncTask completionTask = null; + for (AsyncTask task : tasks.keySet()) { + if (tasks.get(task).equals("transcribeFile-" + id)) { + task.get(); + break; + } + } + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(Void result) { + if (exception != null) { + promise.reject(exception); + return; + } + promise.resolve(null); + tasks.remove(this); + } + }.execute(); + tasks.put(task, "abortTranscribe-" + id); } public void releaseContext(double id, Promise promise) { final int contextId = (int) id; - new AsyncTask() { + AsyncTask task = new AsyncTask() { private Exception exception; @Override @@ -204,6 +242,14 @@ protected Void doInBackground(Void... voids) { if (context == null) { throw new Exception("Context " + id + " not found"); } + context.stopCurrentTranscribe(); + AsyncTask completionTask = null; + for (AsyncTask task : tasks.keySet()) { + if (tasks.get(task).equals("transcribeFile-" + contextId)) { + task.get(); + break; + } + } context.release(); contexts.remove(contextId); } catch (Exception e) { @@ -219,12 +265,14 @@ protected void onPostExecute(Void result) { return; } promise.resolve(null); + tasks.remove(this); } }.execute(); + tasks.put(task, "releaseContext-" + id); } public void releaseAllContexts(Promise promise) { - new AsyncTask() { + AsyncTask task = new AsyncTask() { private Exception exception; @Override @@ -244,8 +292,10 @@ protected void onPostExecute(Void result) { return; } promise.resolve(null); + tasks.remove(this); } }.execute(); + tasks.put(task, "releaseAllContexts"); } @Override @@ -258,10 +308,20 @@ public void onHostPause() { @Override public void onHostDestroy() { - WhisperContext.abortAllTranscribe(); + for (WhisperContext context : contexts.values()) { + context.stopCurrentTranscribe(); + } + for (AsyncTask task : tasks.keySet()) { + try { + task.get(); + } catch (Exception e) { + Log.e(NAME, "Failed to wait for task", e); + } + } for (WhisperContext context : contexts.values()) { context.release(); } + WhisperContext.abortAllTranscribe(); // graceful abort contexts.clear(); downloader.clearCache(); } diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index 0af041b..3b73e72 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -61,6 +61,7 @@ public class WhisperContext { private boolean isCapturing = false; private boolean isStoppedByAction = false; private boolean isTranscribing = false; + private Thread rootFullHandler = null; private Thread fullHandler = null; public WhisperContext(int id, ReactApplicationContext reactContext, long context) { @@ -81,6 +82,7 @@ private void rewind() { isCapturing = false; isStoppedByAction = false; isTranscribing = false; + rootFullHandler = null; fullHandler = null; } @@ -117,7 +119,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { isCapturing = true; recorder.startRecording(); - new Thread(new Runnable() { + rootFullHandler = new Thread(new Runnable() { @Override public void run() { try { @@ -195,7 +197,8 @@ public void run() { recorder = null; } } - }).start(); + }); + rootFullHandler.start(); return state; } @@ -402,6 +405,14 @@ public void stopTranscribe(int jobId) { abortTranscribe(jobId); isCapturing = false; isStoppedByAction = true; + if (rootFullHandler != null) { + try { + rootFullHandler.join(); + } catch (Exception e) { + Log.e(NAME, "Error joining rootFullHandler: " + e.getMessage()); + } + rootFullHandler = null; + } } public void stopCurrentTranscribe() { diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index 343bf76..a94a1fd 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -759,7 +759,7 @@ PODS: - SSZipArchive (~> 2.2) - SocketRocket (0.6.0) - SSZipArchive (2.4.3) - - whisper-rn (0.3.4): + - whisper-rn (0.3.5): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -997,7 +997,7 @@ SPEC CHECKSUMS: RNZipArchive: 68a0c6db4b1c103f846f1559622050df254a3ade SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608 SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef - whisper-rn: a24af0dd79eb6dc1ebacbb110a7120fad0d64818 + whisper-rn: 6f293154b175fee138a994fa00d0f414fb1f44e9 Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa YogaKit: f782866e155069a2cca2517aafea43200b01fd5a diff --git a/example/src/App.tsx b/example/src/App.tsx index 8f1eb66..d75a013 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -250,7 +250,10 @@ export default function App() { onPress={async () => { if (!whisperContext) return log('No context') if (stopTranscribe?.stop) { - stopTranscribe?.stop() + const t0 = Date.now() + await stopTranscribe?.stop() + const t1 = Date.now() + log('Stopped transcribing in', t1 - t0, 'ms') setStopTranscribe(null) return } diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index d99ae0b..95bfaf3 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -68,13 +68,17 @@ - (NSDictionary *)constantsToExport path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil]; } - RNWhisperContext *context = [RNWhisperContext initWithModelPath:path]; + int contextId = arc4random_uniform(1000000); + + RNWhisperContext *context = [RNWhisperContext + initWithModelPath:path + contextId:contextId + ]; if ([context getContext] == NULL) { reject(@"whisper_cpp_error", @"Failed to load the model", nil); return; } - int contextId = arc4random_uniform(1000000); [contexts setObject:context forKey:[NSNumber numberWithInt:contextId]]; resolve([NSNumber numberWithInt:contextId]); @@ -122,36 +126,36 @@ - (NSArray *)supportedEvents { reject(@"whisper_error", @"Invalid file", nil); return; } - dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ - int code = [context transcribeFile:jobId - audioData:waveFile - audioDataCount:count - options:options - onProgress: ^(int progress) { - if (rn_whisper_transcribe_is_aborted(jobId)) { - return; - } - dispatch_async(dispatch_get_main_queue(), ^{ - [self sendEventWithName:@"@RNWhisper_onTranscribeProgress" - body:@{ - @"contextId": [NSNumber numberWithInt:contextId], - @"jobId": [NSNumber numberWithInt:jobId], - @"progress": [NSNumber numberWithInt:progress] - } - ]; - }); + [context transcribeFile:jobId + audioData:waveFile + audioDataCount:count + options:options + onProgress: ^(int progress) { + if (rn_whisper_transcribe_is_aborted(jobId)) { + return; + } + dispatch_async(dispatch_get_main_queue(), ^{ + [self sendEventWithName:@"@RNWhisper_onTranscribeProgress" + body:@{ + @"contextId": [NSNumber numberWithInt:contextId], + @"jobId": [NSNumber numberWithInt:jobId], + @"progress": [NSNumber numberWithInt:progress] + } + ]; + }); + } + onEnd: ^(int code) { + if (code != 0) { + free(waveFile); + reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil); + return; } - ]; - if (code != 0) { free(waveFile); - reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil); - return; + NSMutableDictionary *result = [context getTextSegments]; + result[@"isAborted"] = @([context isStoppedByAction]); + resolve(result); } - free(waveFile); - NSMutableDictionary *result = [context getTextSegments]; - result[@"isAborted"] = @([context isStoppedByAction]); - resolve(result); - }); + ]; } RCT_REMAP_METHOD(startRealtimeTranscribe, @@ -260,7 +264,7 @@ - (float *)decodeWaveFile:(NSString*)filePath count:(int *)count { } - (void)invalidate { - rn_whisper_abort_all_transcribe(); + [super invalidate]; if (contexts == nil) { return; @@ -271,6 +275,8 @@ - (void)invalidate { [context invalidate]; } + rn_whisper_abort_all_transcribe(); // graceful abort + [contexts removeAllObjects]; contexts = nil; diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index ef6f0d7..30aadec 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -36,21 +36,26 @@ typedef struct { } RNWhisperContextRecordState; @interface RNWhisperContext : NSObject { + int contextId; + dispatch_queue_t dQueue; struct whisper_context * ctx; RNWhisperContextRecordState recordState; } -+ (instancetype)initWithModelPath:(NSString *)modelPath; ++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId; - (struct whisper_context *)getContext; +- (dispatch_queue_t)getDispatchQueue; - (OSStatus)transcribeRealtime:(int)jobId options:(NSDictionary *)options onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe; -- (int)transcribeFile:(int)jobId +- (void)transcribeFile:(int)jobId audioData:(float *)audioData audioDataCount:(int)audioDataCount options:(NSDictionary *)options - onProgress:(void (^)(int))onProgress; + onProgress:(void (^)(int))onProgress + onEnd:(void (^)(int))onEnd; - (void)stopTranscribe:(int)jobId; +- (void)stopCurrentTranscribe; - (bool)isCapturing; - (bool)isTranscribing; - (bool)isStoppedByAction; diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 3c82659..57baa8f 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -4,9 +4,14 @@ @implementation RNWhisperContext -+ (instancetype)initWithModelPath:(NSString *)modelPath { ++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId { RNWhisperContext *context = [[RNWhisperContext alloc] init]; + context->contextId = contextId; context->ctx = whisper_init_from_file([modelPath UTF8String]); + context->dQueue = dispatch_queue_create( + [[NSString stringWithFormat:@"RNWhisperContext-%d", contextId] UTF8String], + DISPATCH_QUEUE_SERIAL + ); return context; } @@ -14,6 +19,10 @@ - (struct whisper_context *)getContext { return self->ctx; } +- (dispatch_queue_t)getDispatchQueue { + return self->dQueue; +} + - (void)prepareRealtime:(NSDictionary *)options { self->recordState.options = options; @@ -109,7 +118,7 @@ void AudioInputCallback(void * inUserData, nSamples != state->nSamplesTranscribing ) { state->isTranscribing = true; - dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + dispatch_async([state->mSelf getDispatchQueue], ^{ [state->mSelf fullTranscribeSamples:state]; }); } @@ -140,7 +149,7 @@ void AudioInputCallback(void * inUserData, if (!state->isTranscribing) { state->isTranscribing = true; - dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + dispatch_async([state->mSelf getDispatchQueue], ^{ [state->mSelf fullTranscribeSamples:state]; }); } @@ -263,19 +272,22 @@ - (OSStatus)transcribeRealtime:(int)jobId return status; } -- (int)transcribeFile:(int)jobId +- (void)transcribeFile:(int)jobId audioData:(float *)audioData audioDataCount:(int)audioDataCount options:(NSDictionary *)options onProgress:(void (^)(int))onProgress + onEnd:(void (^)(int))onEnd { - self->recordState.isStoppedByAction = false; - self->recordState.isTranscribing = true; - self->recordState.jobId = jobId; - int code = [self fullTranscribeWithProgress:onProgress jobId:jobId audioData:audioData audioDataCount:audioDataCount options:options]; - self->recordState.jobId = -1; - self->recordState.isTranscribing = false; - return code; + dispatch_async(dQueue, ^{ + self->recordState.isStoppedByAction = false; + self->recordState.isTranscribing = true; + self->recordState.jobId = jobId; + int code = [self fullTranscribeWithProgress:onProgress jobId:jobId audioData:audioData audioDataCount:audioDataCount options:options]; + self->recordState.jobId = -1; + self->recordState.isTranscribing = false; + onEnd(code); + }); } - (void)stopAudio { @@ -293,6 +305,7 @@ - (void)stopTranscribe:(int)jobId { } self->recordState.isCapturing = false; self->recordState.isStoppedByAction = true; + dispatch_barrier_sync(dQueue, ^{}); } - (void)stopCurrentTranscribe {