From 13137bfe240e1568b10c645542ffe160710bb774 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Sat, 9 Nov 2024 11:25:50 +0800 Subject: [PATCH] feat(android): support transcribeData & transcribeFile with base64 --- .../main/java/com/rnwhisper/AudioUtils.java | 39 ++++--- .../main/java/com/rnwhisper/RNWhisper.java | 100 ++++++++++++------ .../java/com/rnwhisper/WhisperContext.java | 3 +- .../java/com/rnwhisper/RNWhisperModule.java | 5 + .../java/com/rnwhisper/RNWhisperModule.java | 5 + 5 files changed, 104 insertions(+), 48 deletions(-) diff --git a/android/src/main/java/com/rnwhisper/AudioUtils.java b/android/src/main/java/com/rnwhisper/AudioUtils.java index b6c614dd..dab2d022 100644 --- a/android/src/main/java/com/rnwhisper/AudioUtils.java +++ b/android/src/main/java/com/rnwhisper/AudioUtils.java @@ -2,8 +2,6 @@ import android.util.Log; -import java.io.IOException; -import java.io.FileReader; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; @@ -11,23 +9,22 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.ShortBuffer; +import java.util.Base64; + +import java.util.Arrays; public class AudioUtils { private static final String NAME = "RNWhisperAudioUtils"; - public static float[] decodeWaveFile(InputStream inputStream) throws IOException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - baos.write(buffer, 0, bytesRead); - } - ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray()); + private static float[] bufferToFloatArray(byte[] buffer, Boolean cutHeader) { + ByteBuffer byteBuffer = ByteBuffer.wrap(buffer); byteBuffer.order(ByteOrder.LITTLE_ENDIAN); - byteBuffer.position(44); ShortBuffer shortBuffer = byteBuffer.asShortBuffer(); short[] shortArray = new short[shortBuffer.limit()]; shortBuffer.get(shortArray); + if (cutHeader) { + shortArray = Arrays.copyOfRange(shortArray, 44, shortArray.length); + } float[] floatArray = new float[shortArray.length]; for (int i = 0; i < shortArray.length; i++) { floatArray[i] = ((float) shortArray[i]) / 32767.0f; @@ -36,4 +33,22 @@ public static float[] decodeWaveFile(InputStream inputStream) throws IOException } return floatArray; } -} \ No newline at end of file + + public static float[] decodeWaveFile(InputStream inputStream) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + baos.write(buffer, 0, bytesRead); + } + return bufferToFloatArray(baos.toByteArray(), true); + } + + public static float[] decodeWaveData(String dataBase64) throws IOException { + return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), true); + } + + public static float[] decodePcmData(String dataBase64) { + return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), false); + } +} diff --git a/android/src/main/java/com/rnwhisper/RNWhisper.java b/android/src/main/java/com/rnwhisper/RNWhisper.java index e04a5d95..447fa67a 100644 --- a/android/src/main/java/com/rnwhisper/RNWhisper.java +++ b/android/src/main/java/com/rnwhisper/RNWhisper.java @@ -19,6 +19,7 @@ import java.util.Random; import java.io.File; import java.io.FileInputStream; +import java.io.InputStream; import java.io.PushbackInputStream; public class RNWhisper implements LifecycleEventListener { @@ -119,44 +120,16 @@ protected void onPostExecute(Integer id) { tasks.put(task, "initContext"); } - public void transcribeFile(double id, double jobId, String filePath, ReadableMap options, Promise promise) { - final WhisperContext context = contexts.get((int) id); - if (context == null) { - promise.reject("Context not found"); - return; - } - if (context.isCapturing()) { - promise.reject("The context is in realtime transcribe mode"); - return; - } - if (context.isTranscribing()) { - promise.reject("Context is already transcribing"); - return; - } + private AsyncTask transcribe(WhisperContext context, double jobId, final float[] audioData, final ReadableMap options, Promise promise) { AsyncTask task = new AsyncTask() { private Exception exception; @Override protected WritableMap doInBackground(Void... voids) { try { - String waveFilePath = filePath; - - if (filePath.startsWith("http://") || filePath.startsWith("https://")) { - waveFilePath = downloader.downloadFile(filePath); - } - - int resId = getResourceIdentifier(waveFilePath); - if (resId > 0) { - return context.transcribeInputStream( - (int) jobId, - reactContext.getResources().openRawResource(resId), - options - ); - } - - return context.transcribeInputStream( + return context.transcribe( (int) jobId, - new FileInputStream(new File(waveFilePath)), + audioData, options ); } catch (Exception e) { @@ -175,7 +148,66 @@ protected void onPostExecute(WritableMap data) { tasks.remove(this); } }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); - tasks.put(task, "transcribeFile-" + id); + return task; + } + + public void transcribeFile(double id, double jobId, String filePathOrBase64, ReadableMap options, Promise promise) { + final WhisperContext context = contexts.get((int) id); + if (context == null) { + promise.reject("Context not found"); + return; + } + if (context.isCapturing()) { + promise.reject("The context is in realtime transcribe mode"); + return; + } + if (context.isTranscribing()) { + promise.reject("Context is already transcribing"); + return; + } + + String waveFilePath = filePathOrBase64; + try { + if (filePathOrBase64.startsWith("http://") || filePathOrBase64.startsWith("https://")) { + waveFilePath = downloader.downloadFile(filePathOrBase64); + } + + float[] audioData; + int resId = getResourceIdentifier(waveFilePath); + if (resId > 0) { + audioData = AudioUtils.decodeWaveFile(reactContext.getResources().openRawResource(resId)); + } else if (filePathOrBase64.startsWith("data:audio/wav;base64,")) { + audioData = AudioUtils.decodeWaveData(filePathOrBase64); + } else { + audioData = AudioUtils.decodeWaveFile(new FileInputStream(new File(waveFilePath))); + } + + AsyncTask task = transcribe(context, jobId, audioData, options, promise); + tasks.put(task, "transcribeFile-" + id); + } catch (Exception e) { + promise.reject(e); + } + } + + public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) { + final WhisperContext context = contexts.get((int) id); + if (context == null) { + promise.reject("Context not found"); + return; + } + if (context.isCapturing()) { + promise.reject("The context is in realtime transcribe mode"); + return; + } + if (context.isTranscribing()) { + promise.reject("Context is already transcribing"); + return; + } + + float[] audioData = AudioUtils.decodePcmData(dataBase64); + AsyncTask task = transcribe(context, jobId, audioData, options, promise); + + tasks.put(task, "transcribeData-" + id); } public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) { @@ -211,7 +243,7 @@ protected Void doInBackground(Void... voids) { context.stopTranscribe((int) jobId); AsyncTask completionTask = null; for (AsyncTask task : tasks.keySet()) { - if (tasks.get(task).equals("transcribeFile-" + id)) { + if (tasks.get(task).equals("transcribeFile-" + id) || tasks.get(task).equals("transcribeData-" + id)) { task.get(); break; } @@ -259,7 +291,7 @@ protected Void doInBackground(Void... voids) { context.stopCurrentTranscribe(); AsyncTask completionTask = null; for (AsyncTask task : tasks.keySet()) { - if (tasks.get(task).equals("transcribeFile-" + contextId)) { + if (tasks.get(task).equals("transcribeFile-" + contextId) || tasks.get(task).equals("transcribeData-" + contextId)) { task.get(); break; } diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index 0b5b2be6..f3508af8 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -332,7 +332,7 @@ void onNewSegments(int nNew) { } } - public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception { + public WritableMap transcribe(int jobId, float[] audioData, ReadableMap options) throws IOException, Exception { if (isCapturing || isTranscribing) { throw new Exception("Context is already in capturing or transcribing"); } @@ -341,7 +341,6 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea this.isTdrzEnable = options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable"); isTranscribing = true; - float[] audioData = AudioUtils.decodeWaveFile(inputStream); boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress"); boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments"); diff --git a/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java b/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java index bdf6972c..a901d9f5 100644 --- a/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +++ b/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java @@ -47,6 +47,11 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap rnwhisper.transcribeFile(id, jobId, filePath, options, promise); } + @ReactMethod + public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) { + rnwhisper.transcribeData(id, jobId, dataBase64, options, promise); + } + @ReactMethod public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) { rnwhisper.startRealtimeTranscribe(id, jobId, options, promise); diff --git a/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java b/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java index e0f37c74..aba2463c 100644 --- a/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +++ b/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java @@ -47,6 +47,11 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap rnwhisper.transcribeFile(id, jobId, filePath, options, promise); } + @ReactMethod + public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) { + rnwhisper.transcribeData(id, jobId, dataBase64, options, promise); + } + @ReactMethod public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) { rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);