Skip to content

Commit

Permalink
feat(ios): move vad params
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Dec 7, 2023
1 parent da21edb commit 50f8713
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 86 deletions.
126 changes: 73 additions & 53 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,58 +6,6 @@

namespace rnwhisper {

job::~job() {
fprintf(stderr, "%s: job_id: %d\n", __func__, job_id);
}

bool job::is_aborted() {
return aborted;
}

void job::abort() {
aborted = true;
}

std::unordered_map<int, job> job_map;

void job_abort_all() {
for (auto it = job_map.begin(); it != job_map.end(); ++it) {
it->second.abort();
}
}

job* job_new(int job_id, struct whisper_full_params params) {
job ctx;
ctx.job_id = job_id;
ctx.params = params;

// Abort handler
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
job *j = (job*)user_data;
return !j->is_aborted();
};
params.encoder_begin_callback_user_data = &ctx;
params.abort_callback = [](void * user_data) {
job *j = (job*)user_data;
return j->is_aborted();
};
params.abort_callback_user_data = &ctx;

job_map[job_id] = ctx;
return &job_map[job_id];
}

void job_remove(int job_id) {
job_map.erase(job_id);
}

job* job_get(int job_id) {
if (job_map.find(job_id) != job_map.end()) {
return &job_map[job_id];
}
return nullptr;
}

void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
Expand All @@ -71,7 +19,7 @@ void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate
}
}

bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
bool vad_simple_impl(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;

Expand Down Expand Up @@ -109,4 +57,76 @@ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float
return true;
}

job::~job() {
fprintf(stderr, "%s: job_id: %d\n", __func__, job_id);
}

void job::set_vad_params(vad_params params) {
vad = params;
if (vad.vad_ms < 2000) vad.vad_ms = 2000;
}

bool job::vad_simple(short* pcm, int n_samples, int n) {
if (!vad.use_vad) return true;

int sample_size = (int) (WHISPER_SAMPLE_RATE * vad.vad_ms / 1000);
if (n_samples + n > sample_size) {
int start = n_samples + n - sample_size;
std::vector<float> pcmf32(sample_size);
for (int i = 0; i < sample_size; i++) {
pcmf32[i] = (float)pcm[i + start] / 32768.0f;
}
return vad_simple_impl(pcmf32, WHISPER_SAMPLE_RATE, vad.last_ms, vad.vad_thold, vad.freq_thold, vad.verbose);
}
return false;
}

bool job::is_aborted() {
return aborted;
}

void job::abort() {
aborted = true;
}

std::unordered_map<int, job> job_map;

void job_abort_all() {
for (auto it = job_map.begin(); it != job_map.end(); ++it) {
it->second.abort();
}
}

job* job_new(int job_id, struct whisper_full_params params) {
job ctx;
ctx.job_id = job_id;
ctx.params = params;

// Abort handler
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
job *j = (job*)user_data;
return !j->is_aborted();
};
params.encoder_begin_callback_user_data = &ctx;
params.abort_callback = [](void * user_data) {
job *j = (job*)user_data;
return j->is_aborted();
};
params.abort_callback_user_data = &ctx;

job_map[job_id] = ctx;
return &job_map[job_id];
}

job* job_get(int job_id) {
if (job_map.find(job_id) != job_map.end()) {
return &job_map[job_id];
}
return nullptr;
}

void job_remove(int job_id) {
job_map.erase(job_id);
}

}
18 changes: 14 additions & 4 deletions cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,24 @@

namespace rnwhisper {

struct vad_params {
bool use_vad = false;
float vad_thold = 0.1;
float freq_thold = 0.1;
int vad_ms = 2000;
int last_ms = 1000;
bool verbose = false;
};

struct job {
int job_id;
whisper_full_params params;
bool aborted = false;
whisper_full_params params;
vad_params vad; // Realtime transcription only

~job();
void set_vad_params(vad_params vad);
bool vad_simple(short* pcm, int n_samples, int n);
bool is_aborted();
void abort();
};
Expand All @@ -21,9 +34,6 @@ job* job_new(int job_id, struct whisper_full_params params);
void job_remove(int job_id);
job* job_get(int job_id);

void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate);
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);

} // namespace rnwhisper

#endif // RNWHISPER_H
5 changes: 0 additions & 5 deletions ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ typedef struct {
int audioSliceSec;
NSString* audioOutputPath;

bool useVad;
int vadMs;
float vadThold;
float vadFreqThold;

AudioQueueRef queue;
AudioStreamBasicDescription dataFormat;
AudioQueueBufferRef buffers[NUM_BUFFERS];
Expand Down
35 changes: 11 additions & 24 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,6 @@ - (void)prepareRealtime:(NSDictionary *)options {

self->recordState.audioOutputPath = options[@"audioOutputPath"];

self->recordState.useVad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false;
self->recordState.vadMs = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000;
if (self->recordState.vadMs < 2000) self->recordState.vadMs = 2000;

self->recordState.vadThold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f;
self->recordState.vadFreqThold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f;

self->recordState.audioSliceSec = audioSliceSec;
self->recordState.isUseSlices = audioSliceSec < maxAudioSec;

Expand Down Expand Up @@ -158,24 +151,10 @@ - (void)freeBufferIfNeeded {
}
}

bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSamples, int n)
bool vad(RNWhisperContextRecordState *state, short* pcm, int nSamples, int n)
{
bool isSpeech = true;
if (!state->isTranscribing && state->useVad) {
int sampleSize = (int) (WHISPER_SAMPLE_RATE * state->vadMs / 1000);
if (nSamples + n > sampleSize) {
int start = nSamples + n - sampleSize;
std::vector<float> audioBufferF32Vec(sampleSize);
for (int i = 0; i < sampleSize; i++) {
audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f;
}
isSpeech = rnwhisper::vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, state->vadThold, state->vadFreqThold, false);
NSLog(@"[RNWhisper] VAD result: %d", isSpeech);
} else {
isSpeech = false;
}
}
return isSpeech;
if (state->isTranscribing) return true;
return state->job->vad_simple(pcm, nSamples, n);
}

void AudioInputCallback(void * inUserData,
Expand Down Expand Up @@ -376,6 +355,14 @@ - (OSStatus)transcribeRealtime:(int)jobId
[self prepareRealtime:options];
self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]);

rnwhisper::vad_params vad = {
.use_vad = options[@"useVad"] != nil ? [options[@"useVad"] boolValue] : false,
.vad_ms = options[@"vadMs"] != nil ? [options[@"vadMs"] intValue] : 2000,
.vad_thold = options[@"vadThold"] != nil ? [options[@"vadThold"] floatValue] : 0.6f,
.freq_thold = options[@"vadFreqThold"] != nil ? [options[@"vadFreqThold"] floatValue] : 100.0f
};
self->recordState.job->set_vad_params(vad);

OSStatus status = AudioQueueNewInput(
&self->recordState.dataFormat,
AudioInputCallback,
Expand Down

0 comments on commit 50f8713

Please sign in to comment.