Skip to content

Commit

Permalink
fix: wait queue/task finished on stopTranscribe (#119)
Browse files Browse the repository at this point in the history
* feat(ios): move dispatch_async to transcribeFile self

* feat(ios): create own dispatch_queue_t by context self

* fix(ios): wait dispatch_queue_t on stopTranscribe

* chore(ios): update lockfile

* fix(android): wait task finished on stopTranscribe
  • Loading branch information
jhen0409 authored Sep 7, 2023
1 parent a1f2fe0 commit ec68cba
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 57 deletions.
76 changes: 68 additions & 8 deletions android/src/main/java/com/rnwhisper/RNWhisper.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.PushbackInputStream;

public class RNWhisper implements LifecycleEventListener {
public static final String NAME = "RNWhisper";

private ReactApplicationContext reactContext;
private Downloader downloader;

Expand All @@ -40,6 +42,8 @@ public HashMap<String, Object> getTypedExportedConstants() {
return constants;
}

private HashMap<AsyncTask, String> tasks = new HashMap<>();

private HashMap<Integer, WhisperContext> contexts = new HashMap<>();

private int getResourceIdentifier(String filePath) {
Expand All @@ -59,7 +63,7 @@ private int getResourceIdentifier(String filePath) {
}

public void initContext(final ReadableMap options, final Promise promise) {
new AsyncTask<Void, Void, Integer>() {
AsyncTask task = new AsyncTask<Void, Void, Integer>() {
private Exception exception;

@Override
Expand Down Expand Up @@ -104,8 +108,10 @@ protected void onPostExecute(Integer id) {
return;
}
promise.resolve(id);
tasks.remove(this);
}
}.execute();
tasks.put(task, "initContext");
}

public void transcribeFile(double id, double jobId, String filePath, ReadableMap options, Promise promise) {
Expand All @@ -122,7 +128,7 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap
promise.reject("Context is already transcribing");
return;
}
new AsyncTask<Void, Void, WritableMap>() {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
Expand Down Expand Up @@ -161,8 +167,10 @@ protected void onPostExecute(WritableMap data) {
return;
}
promise.resolve(data);
tasks.remove(this);
}
}.execute();
tasks.put(task, "transcribeFile-" + id);
}

public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
Expand All @@ -183,18 +191,48 @@ public void startRealtimeTranscribe(double id, double jobId, ReadableMap options
promise.reject("Failed to start realtime transcribe. State: " + state);
}

public void abortTranscribe(double contextId, double jobId, Promise promise) {
WhisperContext context = contexts.get((int) contextId);
public void abortTranscribe(double id, double jobId, Promise promise) {
WhisperContext context = contexts.get((int) id);
if (context == null) {
promise.reject("Context not found");
return;
}
context.stopTranscribe((int) jobId);
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
protected Void doInBackground(Void... voids) {
try {
context.stopTranscribe((int) jobId);
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("transcribeFile-" + id)) {
task.get();
break;
}
}
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(Void result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.execute();
tasks.put(task, "abortTranscribe-" + id);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
new AsyncTask<Void, Void, Void>() {
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
Expand All @@ -204,6 +242,14 @@ protected Void doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context " + id + " not found");
}
context.stopCurrentTranscribe();
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("transcribeFile-" + contextId)) {
task.get();
break;
}
}
context.release();
contexts.remove(contextId);
} catch (Exception e) {
Expand All @@ -219,12 +265,14 @@ protected void onPostExecute(Void result) {
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.execute();
tasks.put(task, "releaseContext-" + id);
}

public void releaseAllContexts(Promise promise) {
new AsyncTask<Void, Void, Void>() {
AsyncTask task = new AsyncTask<Void, Void, Void>() {
private Exception exception;

@Override
Expand All @@ -244,8 +292,10 @@ protected void onPostExecute(Void result) {
return;
}
promise.resolve(null);
tasks.remove(this);
}
}.execute();
tasks.put(task, "releaseAllContexts");
}

@Override
Expand All @@ -258,10 +308,20 @@ public void onHostPause() {

@Override
public void onHostDestroy() {
WhisperContext.abortAllTranscribe();
for (WhisperContext context : contexts.values()) {
context.stopCurrentTranscribe();
}
for (AsyncTask task : tasks.keySet()) {
try {
task.get();
} catch (Exception e) {
Log.e(NAME, "Failed to wait for task", e);
}
}
for (WhisperContext context : contexts.values()) {
context.release();
}
WhisperContext.abortAllTranscribe(); // graceful abort
contexts.clear();
downloader.clearCache();
}
Expand Down
15 changes: 13 additions & 2 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class WhisperContext {
private boolean isCapturing = false;
private boolean isStoppedByAction = false;
private boolean isTranscribing = false;
private Thread rootFullHandler = null;
private Thread fullHandler = null;

public WhisperContext(int id, ReactApplicationContext reactContext, long context) {
Expand All @@ -81,6 +82,7 @@ private void rewind() {
isCapturing = false;
isStoppedByAction = false;
isTranscribing = false;
rootFullHandler = null;
fullHandler = null;
}

Expand Down Expand Up @@ -117,7 +119,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {
isCapturing = true;
recorder.startRecording();

new Thread(new Runnable() {
rootFullHandler = new Thread(new Runnable() {
@Override
public void run() {
try {
Expand Down Expand Up @@ -195,7 +197,8 @@ public void run() {
recorder = null;
}
}
}).start();
});
rootFullHandler.start();
return state;
}

Expand Down Expand Up @@ -402,6 +405,14 @@ public void stopTranscribe(int jobId) {
abortTranscribe(jobId);
isCapturing = false;
isStoppedByAction = true;
if (rootFullHandler != null) {
try {
rootFullHandler.join();
} catch (Exception e) {
Log.e(NAME, "Error joining rootFullHandler: " + e.getMessage());
}
rootFullHandler = null;
}
}

public void stopCurrentTranscribe() {
Expand Down
4 changes: 2 additions & 2 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ PODS:
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.4):
- whisper-rn (0.3.5):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -997,7 +997,7 @@ SPEC CHECKSUMS:
RNZipArchive: 68a0c6db4b1c103f846f1559622050df254a3ade
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: a24af0dd79eb6dc1ebacbb110a7120fad0d64818
whisper-rn: 6f293154b175fee138a994fa00d0f414fb1f44e9
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

Expand Down
5 changes: 4 additions & 1 deletion example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ export default function App() {
onPress={async () => {
if (!whisperContext) return log('No context')
if (stopTranscribe?.stop) {
stopTranscribe?.stop()
const t0 = Date.now()
await stopTranscribe?.stop()
const t1 = Date.now()
log('Stopped transcribing in', t1 - t0, 'ms')
setStopTranscribe(null)
return
}
Expand Down
66 changes: 36 additions & 30 deletions ios/RNWhisper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,17 @@ - (NSDictionary *)constantsToExport
path = [[NSBundle mainBundle] pathForResource:modelPath ofType:nil];
}

RNWhisperContext *context = [RNWhisperContext initWithModelPath:path];
int contextId = arc4random_uniform(1000000);

RNWhisperContext *context = [RNWhisperContext
initWithModelPath:path
contextId:contextId
];
if ([context getContext] == NULL) {
reject(@"whisper_cpp_error", @"Failed to load the model", nil);
return;
}

int contextId = arc4random_uniform(1000000);
[contexts setObject:context forKey:[NSNumber numberWithInt:contextId]];

resolve([NSNumber numberWithInt:contextId]);
Expand Down Expand Up @@ -122,36 +126,36 @@ - (NSArray *)supportedEvents {
reject(@"whisper_error", @"Invalid file", nil);
return;
}
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
int code = [context transcribeFile:jobId
audioData:waveFile
audioDataCount:count
options:options
onProgress: ^(int progress) {
if (rn_whisper_transcribe_is_aborted(jobId)) {
return;
}
dispatch_async(dispatch_get_main_queue(), ^{
[self sendEventWithName:@"@RNWhisper_onTranscribeProgress"
body:@{
@"contextId": [NSNumber numberWithInt:contextId],
@"jobId": [NSNumber numberWithInt:jobId],
@"progress": [NSNumber numberWithInt:progress]
}
];
});
[context transcribeFile:jobId
audioData:waveFile
audioDataCount:count
options:options
onProgress: ^(int progress) {
if (rn_whisper_transcribe_is_aborted(jobId)) {
return;
}
dispatch_async(dispatch_get_main_queue(), ^{
[self sendEventWithName:@"@RNWhisper_onTranscribeProgress"
body:@{
@"contextId": [NSNumber numberWithInt:contextId],
@"jobId": [NSNumber numberWithInt:jobId],
@"progress": [NSNumber numberWithInt:progress]
}
];
});
}
onEnd: ^(int code) {
if (code != 0) {
free(waveFile);
reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
return;
}
];
if (code != 0) {
free(waveFile);
reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
return;
NSMutableDictionary *result = [context getTextSegments];
result[@"isAborted"] = @([context isStoppedByAction]);
resolve(result);
}
free(waveFile);
NSMutableDictionary *result = [context getTextSegments];
result[@"isAborted"] = @([context isStoppedByAction]);
resolve(result);
});
];
}

RCT_REMAP_METHOD(startRealtimeTranscribe,
Expand Down Expand Up @@ -260,7 +264,7 @@ - (float *)decodeWaveFile:(NSString*)filePath count:(int *)count {
}

- (void)invalidate {
rn_whisper_abort_all_transcribe();
[super invalidate];

if (contexts == nil) {
return;
Expand All @@ -271,6 +275,8 @@ - (void)invalidate {
[context invalidate];
}

rn_whisper_abort_all_transcribe(); // graceful abort

[contexts removeAllObjects];
contexts = nil;

Expand Down
11 changes: 8 additions & 3 deletions ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,26 @@ typedef struct {
} RNWhisperContextRecordState;

@interface RNWhisperContext : NSObject {
int contextId;
dispatch_queue_t dQueue;
struct whisper_context * ctx;
RNWhisperContextRecordState recordState;
}

+ (instancetype)initWithModelPath:(NSString *)modelPath;
+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId;
- (struct whisper_context *)getContext;
- (dispatch_queue_t)getDispatchQueue;
- (OSStatus)transcribeRealtime:(int)jobId
options:(NSDictionary *)options
onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe;
- (int)transcribeFile:(int)jobId
- (void)transcribeFile:(int)jobId
audioData:(float *)audioData
audioDataCount:(int)audioDataCount
options:(NSDictionary *)options
onProgress:(void (^)(int))onProgress;
onProgress:(void (^)(int))onProgress
onEnd:(void (^)(int))onEnd;
- (void)stopTranscribe:(int)jobId;
- (void)stopCurrentTranscribe;
- (bool)isCapturing;
- (bool)isTranscribing;
- (bool)isStoppedByAction;
Expand Down
Loading

0 comments on commit ec68cba

Please sign in to comment.