Skip to content

Commit

Permalink
consistent speechrecognizer support
Browse files Browse the repository at this point in the history
  • Loading branch information
crc-32 committed Nov 6, 2024
1 parent 498ca14 commit b27ecff
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[submodule "android/speex_codec/src/main/cpp/speex"]
path = android/speex_codec/src/main/cpp/speex
url = [email protected]:xiph/speex.git
[submodule "android/speex_codec/src/main/cpp/speexdsp"]
path = android/speex_codec/src/main/cpp/speexdsp
url = [email protected]:xiph/speexdsp.git
[submodule "android/speex_codec/src/main/cpp/kissfft"]
path = android/speex_codec/src/main/cpp/kissfft
url = [email protected]:mborgerding/kissfft.git
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.rebble.cobble.shared.domain.voice.speechrecognizer
import android.content.Context
import android.content.Intent
import android.media.AudioFormat
import android.media.AudioTrack
import android.os.Build.VERSION_CODES
import android.os.Bundle
import android.os.ParcelFileDescriptor
Expand All @@ -20,16 +21,29 @@ import kotlinx.coroutines.flow.*
import org.koin.core.component.KoinComponent
import org.koin.core.component.inject
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.ShortBuffer
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.time.Duration.Companion.milliseconds


@RequiresApi(VERSION_CODES.TIRAMISU)
class SpeechRecognizerDictationService: DictationService, KoinComponent {
private val context: Context by inject()
private val scope = CoroutineScope(Dispatchers.IO)
private val audioTrack = AudioTrack.Builder()
.setAudioFormat(AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
.setSampleRate(16000)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.build())
.setBufferSizeInBytes(16000)
.setTransferMode(AudioTrack.MODE_STREAM)
.build()

companion object {
private const val GAIN = 1.5f
private val AUDIO_LATENCY = 600.milliseconds
fun buildRecognizerIntent(audioSource: ParcelFileDescriptor? = null, encoding: Int = AudioFormat.ENCODING_PCM_16BIT, sampleRate: Int = 16000) = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
putExtra(RecognizerIntent.EXTRA_LANGUAGE_MODEL, RecognizerIntent.LANGUAGE_MODEL_FREE_FORM)
audioSource?.let {
Expand Down Expand Up @@ -95,8 +109,7 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
})
speechRecognizer.startListening(intent)
awaitClose {
Logging.d("Closing speech recognition listener")
speechRecognizer.cancel()

}
}.flowOn(Dispatchers.Main)

Expand All @@ -105,9 +118,10 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
return@flow
}
val decoder = SpeexCodec(speexEncoderInfo.sampleRate, speexEncoderInfo.bitRate)
val decoder = SpeexCodec(speexEncoderInfo.sampleRate, speexEncoderInfo.bitRate, speexEncoderInfo.frameSize, setOf(SpeexCodec.Preprocessor.DENOISE, SpeexCodec.Preprocessor.AGC))
val decodeBufLength = Short.SIZE_BYTES * speexEncoderInfo.frameSize
val decodedBuf = ByteBuffer.allocateDirect(decodeBufLength)
decodedBuf.order(ByteOrder.nativeOrder())
val recognizerPipes = ParcelFileDescriptor.createSocketPair()
val recognizerReadPipe = recognizerPipes[0]
val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1])
Expand All @@ -127,17 +141,22 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
return@flow
}
audioTrack.play()

val audioJob = scope.launch {
audioStreamFrames
.onEach { frame ->
if (frame is AudioStreamFrame.Stop) {
//Logging.v("Stop")
withContext(Dispatchers.IO) {
// Pad with extra frame of silence
recognizerWritePipe.write(ByteArray(speexEncoderInfo.frameSize * Short.SIZE_BYTES))
}
recognizerWritePipe.flush()
delay(AUDIO_LATENCY)
withContext(Dispatchers.Main) {
//XXX: Shouldn't use main here for I/O call but recognizer has weird thread behaviour
recognizerWritePipe.close()
recognizerReadPipe.close()
speechRecognizer.stopListening()
}
} else if (frame is AudioStreamFrame.AudioData) {
Expand All @@ -146,7 +165,10 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
Logging.e("Speex decode error: ${result.name}")
}
decodedBuf.rewind()
recognizerWritePipe.write(decodedBuf.array(), decodedBuf.arrayOffset(), decodeBufLength)
withContext(Dispatchers.IO) {
audioTrack.write(decodedBuf.array(), decodedBuf.arrayOffset(), decodeBufLength)
recognizerWritePipe.write(decodedBuf.array(), decodedBuf.arrayOffset(), decodeBufLength)
}
}
}
.flowOn(Dispatchers.IO)
Expand Down Expand Up @@ -190,8 +212,10 @@ class SpeechRecognizerDictationService: DictationService, KoinComponent {
}
}
} finally {
audioTrack.stop()
audioJob.cancel()
speechRecognizer.destroy()
decoder.close()
}

}
Expand Down
33 changes: 33 additions & 0 deletions android/speex_codec/src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,30 @@
# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)

set(PYTHON_EXECUTABLE "python3")

# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("speex_codec")

add_library(kissfft STATIC
kissfft/kiss_fft.c
kissfft/kfc.c
kissfft/kiss_fftnd.c
kissfft/kiss_fftndr.c
kissfft/kiss_fftr.c
)

target_include_directories(kissfft PRIVATE
kissfft
)

target_compile_definitions(kissfft PRIVATE
KISSFFT_DATATYPE=int16_t
KISSFFT_TOOLS=0
)
# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
Expand Down Expand Up @@ -57,6 +75,18 @@ add_library(${CMAKE_PROJECT_NAME} SHARED
speex/libspeex/hexc_10_32_table.c
speex/libspeex/gain_table.c
speex/libspeex/gain_table_lbr.c

speexdsp/libspeexdsp/buffer.c
speexdsp/libspeexdsp/fftwrap.c
speexdsp/libspeexdsp/filterbank.c
speexdsp/libspeexdsp/jitter.c
speexdsp/libspeexdsp/kiss_fft.c
speexdsp/libspeexdsp/kiss_fftr.c
speexdsp/libspeexdsp/mdf.c
speexdsp/libspeexdsp/preprocess.c
speexdsp/libspeexdsp/resample.c
speexdsp/libspeexdsp/scal.c
speexdsp/libspeexdsp/smallft.c
)

target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE
Expand All @@ -68,11 +98,13 @@ target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
# List paths to include headers from
include
speex/include
speexdsp/include
)

target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
FIXED_POINT
"EXPORT=/* */"
USE_KISS_FFT
)

# Specifies libraries CMake should link to your target library. You
Expand All @@ -82,4 +114,5 @@ target_link_libraries(${CMAKE_PROJECT_NAME}
# List libraries link to the target library
android
log
kissfft
)
11 changes: 11 additions & 0 deletions android/speex_codec/src/main/cpp/include/speexdsp_config_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//
// Created by crc32 on 06/11/2024.
//

#ifndef ANDROID_SPEEXDSP_CONFIG_TYPES_H
#define ANDROID_SPEEXDSP_CONFIG_TYPES_H
typedef short spx_int16_t;
typedef unsigned short spx_uint16_t;
typedef int spx_int32_t;
typedef unsigned int spx_uint32_t;
#endif //ANDROID_SPEEXDSP_CONFIG_TYPES_H
1 change: 1 addition & 0 deletions android/speex_codec/src/main/cpp/kissfft
Submodule kissfft added at f5f2a3
45 changes: 45 additions & 0 deletions android/speex_codec/src/main/cpp/speex_codec.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#include <jni.h>
#include <string>
#include <speex/speex.h>
#include <speex/speex_preprocess.h>

static jfieldID speexDecBits;
static jfieldID speexDecState;
static jfieldID speexPreprocessState;

static const int FLAG_PREPROCESSOR_DENOISE = 1;
static const int FLAG_PREPROCESSOR_AGC = 2;
static const int FLAG_PREPROCESSOR_VAD = 4;

extern "C"
JNIEXPORT jint JNICALL
Expand All @@ -18,6 +24,11 @@ Java_com_example_speex_1codec_SpeexCodec_decode(JNIEnv *env, jobject thiz,
speex_bits_read_from(bits, reinterpret_cast<char *>(encoded_frame_data)+offset, encoded_frame_length-offset);
int result = speex_decode_int(dec_state, bits, out_frame_data);
env->ReleaseByteArrayElements(encoded_frame, encoded_frame_data, 0);

auto *preprocess_state = reinterpret_cast<SpeexPreprocessState *>(env->GetLongField(thiz, speexPreprocessState));
if (preprocess_state != nullptr) {
speex_preprocess_run(preprocess_state, out_frame_data);
}
return result;
}
extern "C"
Expand Down Expand Up @@ -55,4 +66,38 @@ Java_com_example_speex_1codec_SpeexCodec_initNative(JNIEnv *env, jobject thiz) {
jclass clazz = env->GetObjectClass(thiz);
speexDecBits = env->GetFieldID(clazz, "speexDecBits", "J");
speexDecState = env->GetFieldID(clazz, "speexDecState", "J");
speexPreprocessState = env->GetFieldID(clazz, "speexPreprocessState", "J");
}
extern "C"
JNIEXPORT void JNICALL
Java_com_example_speex_1codec_SpeexCodec_destroyPreprocessState(JNIEnv *env, jobject thiz,
jlong preprocess_state) {
if (preprocess_state == 0) {
return;
}
auto *state = reinterpret_cast<SpeexPreprocessState *>(preprocess_state);
speex_preprocess_state_destroy(state);
}
extern "C"
JNIEXPORT jlong JNICALL
Java_com_example_speex_1codec_SpeexCodec_initPreprocessState(JNIEnv *env, jobject thiz,
jint preprocessors, jint sample_rate,
jint frame_size) {
if (preprocessors == 0) {
return 0;
}
auto *preprocess_state = speex_preprocess_state_init(frame_size, sample_rate);
if (preprocessors & FLAG_PREPROCESSOR_DENOISE) {
int denoise = 1;
speex_preprocess_ctl(preprocess_state, SPEEX_PREPROCESS_SET_DENOISE, &denoise);
}
if (preprocessors & FLAG_PREPROCESSOR_AGC) {
int agc = 1;
speex_preprocess_ctl(preprocess_state, SPEEX_PREPROCESS_SET_AGC, &agc);
}
if (preprocessors & FLAG_PREPROCESSOR_VAD) {
int vad = 1;
speex_preprocess_ctl(preprocess_state, SPEEX_PREPROCESS_SET_VAD, &vad);
}
return reinterpret_cast<jlong>(preprocess_state);
}
1 change: 1 addition & 0 deletions android/speex_codec/src/main/cpp/speexdsp
Submodule speexdsp added at dbd421
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@ package com.example.speex_codec
import android.media.MediaCodec
import java.nio.ByteBuffer

class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCloseable {
class SpeexCodec(private val sampleRate: Long, private val bitRate: Int, private val frameSize: Int, private val preprocessors: Set<Preprocessor> = emptySet()): AutoCloseable {
enum class Preprocessor(val flagValue: Int) {
DENOISE(1),
AGC(2),
VAD(4)
}

init {
initNative()
}
private val speexDecBits: Long = initSpeexBits()
private val speexDecState: Long = initDecState(sampleRate, bitRate)
private val speexPreprocessState: Long = initPreprocessState(preprocessors.fold(0) { acc, preprocessor -> acc or preprocessor.flagValue }, sampleRate.toInt(), frameSize)

/**
* Decode a frame of audio data.
Expand All @@ -23,12 +30,15 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl
override fun close() {
destroySpeexBits(speexDecBits)
destroyDecState(speexDecState)
destroyPreprocessState(speexPreprocessState)
}

private external fun initNative()
private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteBuffer, hasHeaderByte: Boolean): Int
private external fun initSpeexBits(): Long
private external fun initDecState(sampleRate: Long, bitRate: Int): Long
private external fun initPreprocessState(preprocessors: Int, sampleRate: Int, frameSize: Int): Long
private external fun destroyPreprocessState(preprocessState: Long)
private external fun destroySpeexBits(speexBits: Long)
private external fun destroyDecState(decState: Long)

Expand Down

0 comments on commit b27ecff

Please sign in to comment.