Skip to content

Commit

Permalink
feat(cpp): unify some platform code (audio slices, utils, ...) (#166)
Browse files Browse the repository at this point in the history
* feat(cpp): create rnwhisper_job struct

* feat(ios): update rn-whisper api

* feat(android): update rn-whisper api

* fix: user_data should not deref

* chore: revert unnecessary change

* feat(cpp): move abort handler

* feat(ios): store job in RNWhisperContext

* feat(ios): move vad params

* feat(android): create createRealtimeTranscribeJob and update vadSimple jni methods

* feat(cpp): move audio slices

* feat(cpp): move audio utils & save audio

* feat(docs): update

* feat(android): cleanup unnecessary arguments

* feat(android): keep todo

* fix(cpp): store job pointer instead

* feat(cpp): add custom log for easy debug in android

* fix(android): build

* fix(cpp): str params should not be released early

* fix(example): revert some unnecessary change
  • Loading branch information
jhen0409 authored Dec 9, 2023
1 parent ce49fce commit 913954c
Show file tree
Hide file tree
Showing 17 changed files with 560 additions and 512 deletions.
5 changes: 5 additions & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(
${RNWHISPER_LIB_DIR}/ggml-backend.c
${RNWHISPER_LIB_DIR}/ggml-quants.c
${RNWHISPER_LIB_DIR}/whisper.cpp
${RNWHISPER_LIB_DIR}/rn-audioutils.cpp
${RNWHISPER_LIB_DIR}/rn-whisper.cpp
${CMAKE_SOURCE_DIR}/jni.cpp
)
Expand All @@ -33,6 +34,10 @@ function(build_library target_name)
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
endif ()

# Add the RNWHISPER_ANDROID_ENABLE_LOGGING compile definition in Debug builds only.
# The expansion is quoted so that an empty or unset CMAKE_BUILD_TYPE compares as an
# empty string instead of leaving if() without a left-hand operand.
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
    target_compile_options(${target_name} PRIVATE -DRNWHISPER_ANDROID_ENABLE_LOGGING)
endif ()

# NOTE: To debug the native code, uncomment the `if` and `endif` lines below.
# if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")

Expand Down
80 changes: 0 additions & 80 deletions android/src/main/java/com/rnwhisper/AudioUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import android.util.Log;

import java.util.ArrayList;
import java.lang.StringBuilder;
import java.io.IOException;
import java.io.FileReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -19,82 +15,6 @@
public class AudioUtils {
private static final String NAME = "RNWhisperAudioUtils";

private static final int SAMPLE_RATE = 16000;

/**
 * Encodes 16-bit samples as bytes, two bytes per sample with the high byte first.
 *
 * @param shortInts samples to encode
 * @return a new array of length {@code shortInts.length * 2}
 */
private static byte[] shortToByte(short[] shortInts) {
    byte[] packed = new byte[shortInts.length * 2];
    // A freshly wrapped ByteBuffer is big-endian, so each short lands as (high, low).
    ByteBuffer.wrap(packed).asShortBuffer().put(shortInts);
    return packed;
}

/**
 * Joins a list of 16-bit sample slices into one byte array.
 *
 * <p>Slices are encoded in list order with {@code shortToByte}, so every sample occupies
 * two bytes in the result.
 *
 * @param buffers sample slices to concatenate
 * @return a new array holding the encoded bytes of all slices back to back
 */
public static byte[] concatShortBuffers(ArrayList<short[]> buffers) {
    int sampleCount = 0;
    for (short[] slice : buffers) {
        sampleCount += slice.length;
    }

    byte[] joined = new byte[sampleCount * 2];
    int writePos = 0;
    for (short[] slice : buffers) {
        byte[] encoded = shortToByte(slice);
        System.arraycopy(encoded, 0, joined, writePos, encoded.length);
        writePos += encoded.length;
    }
    return joined;
}

/**
 * Returns a copy of {@code audioData} without its trailing zero bytes, keeping whole
 * 16-bit samples.
 *
 * <p>The only caller in this file passes 16-bit PCM (two bytes per sample). Stripping
 * purely byte-wise could cut the last non-silent sample in half when its second byte is
 * zero, leaving an odd-length array. When that happens the zero byte that completes the
 * sample is kept.
 *
 * @param audioData raw audio bytes
 * @return a new, trimmed array; empty when {@code audioData} is empty or all zeros
 */
private static byte[] removeTrailingZeros(byte[] audioData) {
    int lastNonZero = audioData.length - 1;
    while (lastNonZero >= 0 && audioData[lastNonZero] == 0) {
        --lastNonZero;
    }
    int keptLength = lastNonZero + 1;
    // Round up to a whole sample; the byte re-included here is one of the stripped zeros.
    if (keptLength % 2 != 0 && keptLength < audioData.length) {
        keptLength++;
    }
    byte[] newData = new byte[keptLength];
    System.arraycopy(audioData, 0, newData, 0, keptLength);
    return newData;
}

/**
 * Writes 16-bit mono PCM audio to {@code audioOutputFile} as a WAV file sampled at
 * {@code SAMPLE_RATE}.
 *
 * <p>Trailing zero bytes are stripped from {@code rawData} first. The two bytes of every
 * sample are then swapped, converting the high-byte-first input into the little-endian
 * layout written after the header.
 *
 * <p>If stripping leaves an odd number of bytes, one zero byte is appended so the last
 * sample stays complete and the sizes in the header match the bytes actually written.
 *
 * @param rawData 16-bit samples, two bytes each with the high byte first
 * @param audioOutputFile path of the file to create or overwrite
 * @throws IOException if the file cannot be opened or written
 */
public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
    Log.d(NAME, "call saveWavFile");
    rawData = removeTrailingZeros(rawData);
    if (rawData.length % 2 != 0) {
        // Byte-wise trimming can drop the zero low byte of the last sample; restore it.
        rawData = java.util.Arrays.copyOf(rawData, rawData.length + 1);
    }
    // try-with-resources closes the stream on both the normal and the exceptional path.
    try (DataOutputStream output = new DataOutputStream(new FileOutputStream(audioOutputFile))) {
        // WAVE header
        // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
        // DataOutputStream writes big-endian, hence reverseBytes for every numeric field.
        output.writeBytes("RIFF"); // chunk id
        output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size
        output.writeBytes("WAVE"); // format
        output.writeBytes("fmt "); // subchunk 1 id
        output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
        output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
        output.writeShort(Short.reverseBytes((short) 1)); // number of channels
        output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
        output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
        output.writeShort(Short.reverseBytes((short) 2)); // block align
        output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
        output.writeBytes("data"); // subchunk 2 id
        output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size
        // Audio data (conversion big endian -> little endian): reading each pair as a
        // little-endian short and re-writing it big-endian swaps the two bytes.
        short[] shorts = new short[rawData.length / 2];
        ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
        ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
        for (short s : shorts) {
            bytes.putShort(s);
        }
        Log.d(NAME, "writing audio file: " + audioOutputFile);
        output.write(bytes.array());
    }
}

public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
Expand Down
134 changes: 48 additions & 86 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public class WhisperContext {
private AudioRecord recorder = null;
private int bufferSize;
private int nSamplesTranscribing = 0;
private ArrayList<short[]> shortBufferSlices;
// Remember number of samples in each slice
private ArrayList<Integer> sliceNSamples;
// Current buffer slice index
Expand All @@ -66,7 +65,6 @@ public WhisperContext(int id, ReactApplicationContext reactContext, long context
}

private void rewind() {
shortBufferSlices = null;
sliceNSamples = null;
sliceIndex = 0;
transcribeSliceIndex = 0;
Expand All @@ -79,41 +77,14 @@ private void rewind() {
fullHandler = null;
}

private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
boolean isSpeech = true;
if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000;
if (vadMs < 2000) vadMs = 2000;
int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000);
if (nSamples + n > sampleSize) {
int start = nSamples + n - sampleSize;
float[] audioData = new float[sampleSize];
for (int i = 0; i < sampleSize; i++) {
audioData[i] = shortBuffer[i + start] / 32768.0f;
}
float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f;
isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
} else {
isSpeech = false;
}
}
return isSpeech;
private boolean vad(int sliceIndex, int nSamples, int n) {
if (isTranscribing) return true;
return vadSimple(jobId, sliceIndex, nSamples, n);
}

private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
if (audioOutputPath != null) {
// TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
try {
AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath);
} catch (IOException e) {
Log.e(NAME, "Error saving wav file: " + e.getMessage());
}
}

private void finishRealtimeTranscribe(WritableMap result) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray());
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
Expand All @@ -135,16 +106,12 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {

int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;

int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0;
final int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < audioSec ? realtimeAudioSliceSec : audioSec;

isUseSlices = audioSliceSec < audioSec;

String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
createRealtimeTranscribeJob(jobId, context, options);

shortBufferSlices = new ArrayList<short[]>();
shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]);
sliceNSamples = new ArrayList<Integer>();
sliceNSamples.add(0);

Expand Down Expand Up @@ -175,37 +142,29 @@ public void run() {
nSamples == nSamplesTranscribing &&
sliceIndex == transcribeSliceIndex
) {
finishRealtimeTranscribe(options, Arguments.createMap());
finishRealtimeTranscribe(Arguments.createMap());
} else if (!isTranscribing) {
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
if (!isSpeech) {
finishRealtimeTranscribe(options, Arguments.createMap());
if (!vad(sliceIndex, nSamples, 0)) {
finishRealtimeTranscribe(Arguments.createMap());
break;
}
isTranscribing = true;
fullTranscribeSamples(options, true);
fullTranscribeSamples(true);
}
break;
}

// Append to buffer
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
if (nSamples + n > audioSliceSec * SAMPLE_RATE) {
Log.d(NAME, "next slice");

sliceIndex++;
nSamples = 0;
shortBuffer = new short[audioSliceSec * SAMPLE_RATE];
shortBufferSlices.add(shortBuffer);
sliceNSamples.add(0);
}
putPcmData(jobId, buffer, sliceIndex, nSamples, n);

for (int i = 0; i < n; i++) {
shortBuffer[nSamples + i] = buffer[i];
}

boolean isSpeech = vad(options, shortBuffer, nSamples, n);
boolean isSpeech = vad(sliceIndex, nSamples, n);

nSamples += n;
sliceNSamples.set(sliceIndex, nSamples);
Expand All @@ -217,7 +176,7 @@ public void run() {
fullHandler = new Thread(new Runnable() {
@Override
public void run() {
fullTranscribeSamples(options, false);
fullTranscribeSamples(false);
}
});
fullHandler.start();
Expand All @@ -228,7 +187,7 @@ public void run() {
}

if (!isTranscribing) {
finishRealtimeTranscribe(options, Arguments.createMap());
finishRealtimeTranscribe(Arguments.createMap());
}
if (fullHandler != null) {
fullHandler.join(); // Wait for full transcribe to finish
Expand All @@ -246,26 +205,16 @@ public void run() {
return state;
}

private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingCheck) {
private void fullTranscribeSamples(boolean skipCapturingCheck) {
int nSamplesOfIndex = sliceNSamples.get(transcribeSliceIndex);

if (!isCapturing && !skipCapturingCheck) return;

short[] shortBuffer = shortBufferSlices.get(transcribeSliceIndex);
int nSamples = sliceNSamples.get(transcribeSliceIndex);

nSamplesTranscribing = nSamplesOfIndex;

// convert I16 to F32
float[] nSamplesBuffer32 = new float[nSamplesTranscribing];
for (int i = 0; i < nSamplesTranscribing; i++) {
nSamplesBuffer32[i] = shortBuffer[i] / 32768.0f;
}

Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing);

int timeStart = (int) System.currentTimeMillis();
int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing);
int code = fullWithJob(jobId, context, transcribeSliceIndex, nSamplesTranscribing);
int timeEnd = (int) System.currentTimeMillis();
int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000);

Expand Down Expand Up @@ -302,7 +251,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe
if (isStopped && !continueNeeded) {
payload.putBoolean("isCapturing", false);
payload.putBoolean("isStoppedByAction", isStoppedByAction);
finishRealtimeTranscribe(options, payload);
finishRealtimeTranscribe(payload);
} else if (code == 0) {
payload.putBoolean("isCapturing", true);
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload);
Expand All @@ -313,7 +262,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe

if (continueNeeded) {
// If no more capturing, continue transcribing until all slices are transcribed
fullTranscribeSamples(options, true);
fullTranscribeSamples(true);
} else if (isStopped) {
// No next, cleanup
rewind();
Expand Down Expand Up @@ -383,32 +332,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
this.jobId = jobId;
isTranscribing = true;
float[] audioData = AudioUtils.decodeWaveFile(inputStream);
int code = full(jobId, options, audioData, audioData.length);
isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
return fullTranscribe(
int code = fullWithNewJob(
jobId,
context,
// float[] audio_data,
audioData,
// jint audio_data_len,
audioDataLen,
audioData.length,
// ReadableMap options,
options,
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);

isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private WritableMap getTextSegments(int start, int count) {
Expand Down Expand Up @@ -527,12 +474,13 @@ private static String cpuInfo() {
}
}


// JNI methods
protected static native long initContext(String modelPath);
protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
protected static native long initContextWithInputStream(PushbackInputStream inputStream);
protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
protected static native int fullTranscribe(
protected static native void freeContext(long contextPtr);

protected static native int fullWithNewJob(
int job_id,
long context,
float[] audio_data,
Expand All @@ -546,5 +494,19 @@ protected static native int fullTranscribe(
protected static native String getTextSegment(long context, int index);
protected static native int getTextSegmentT0(long context, int index);
protected static native int getTextSegmentT1(long context, int index);
protected static native void freeContext(long contextPtr);

protected static native void createRealtimeTranscribeJob(
int job_id,
long context,
ReadableMap options
);
protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples);
protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n);
protected static native void putPcmData(int job_id, short[] buffer, int slice_index, int n_samples, int n);
protected static native int fullWithJob(
int job_id,
long context,
int slice_index,
int n_samples
);
}
Loading

0 comments on commit 913954c

Please sign in to comment.