Skip to content

Commit

Permalink
feat: add new abort callback for stop transcribe (#143)
Browse files Browse the repository at this point in the history
* feat: add abort callback

* feat: sync whisper.cpp
  • Loading branch information
jhen0409 authored Oct 10, 2023
1 parent 8a58082 commit f044d6b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
8 changes: 7 additions & 1 deletion android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ static void input_stream_close(void *ctx) {
JNIEnv *env = context->env;
jobject input_stream = context->input_stream;
jclass input_stream_class = env->GetObjectClass(input_stream);

env->CallVoidMethod(
input_stream,
env->GetMethodID(input_stream_class, "close", "()V")
Expand Down Expand Up @@ -296,11 +296,17 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
}

// abort handlers
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
params.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
params.abort_callback_user_data = rn_whisper_assign_abort_map(job_id);

if (callback_instance != nullptr) {
callback_context *cb_ctx = new callback_context;
Expand Down
3 changes: 3 additions & 0 deletions cpp/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3773,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,

/*.abort_callback =*/ nullptr,
/*.abort_callback_user_data =*/ nullptr,

/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};
Expand Down
12 changes: 9 additions & 3 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ - (void)transcribeFile:(int)jobId
params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
struct rnwhisper_segments_callback_data *data = (struct rnwhisper_segments_callback_data *)user_data;
data->total_n_new += n_new;

NSString *text = @"";
NSMutableArray *segments = [[NSMutableArray alloc] init];
for (int i = data->total_n_new - n_new; i < data->total_n_new; i++) {
Expand Down Expand Up @@ -451,7 +451,7 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId
if (options[@"maxContext"] != nil) {
params.n_max_text_ctx = [options[@"maxContext"] intValue];
}

if (options[@"offset"] != nil) {
params.offset_ms = [options[@"offset"] intValue];
}
Expand All @@ -467,16 +467,22 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId
if (options[@"temperatureInc"] != nil) {
params.temperature_inc = [options[@"temperature_inc"] floatValue];
}

if (options[@"prompt"] != nil) {
params.initial_prompt = [options[@"prompt"] UTF8String];
}

// abort handler
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(jobId);
params.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
params.abort_callback_user_data = rn_whisper_assign_abort_map(jobId);

return params;
}
Expand Down
2 changes: 1 addition & 1 deletion whisper.cpp

0 comments on commit f044d6b

Please sign in to comment.