From d61b32ca7266d467c8fa85f0f61646bef9f0b078 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 19 Sep 2024 15:45:21 -0700 Subject: [PATCH] Update API methods to use `transcribe` method with confidence levels --- neon_speech/service.py | 22 +++++++++++++--------- tests/unit_tests.py | 6 ++++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/neon_speech/service.py b/neon_speech/service.py index 7b9c2aa..3144fdb 100644 --- a/neon_speech/service.py +++ b/neon_speech/service.py @@ -27,7 +27,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os -from typing import Dict +from typing import Dict, List, Tuple import ovos_dinkum_listener.plugins @@ -421,9 +421,11 @@ def handle_get_stt(self, message: Message): message.context['timing']['client_to_core'] = \ received_time - sent_time message.context['timing']['response_sent'] = time() + transcribed_str = [t[0] for t in transcriptions] self.bus.emit(message.reply(ident, data={"parser_data": parser_data, - "transcripts": transcriptions})) + "transcripts": transcribed_str, + "transcripts_with_conf": transcriptions})) except Exception as e: LOG.error(e) message.context['timing']['response_sent'] = time() @@ -474,8 +476,9 @@ def build_context(msg: Message): message.context.setdefault('timing', dict()) message.context['timing'] = {**timing, **message.context['timing']} context = build_context(message) + transribed_str = [t[0] for t in transcriptions] data = { - "utterances": transcriptions, + "utterances": transribed_str, "lang": message.data.get("lang", "en-us") } # Send a new message to the skills module with proper routing ctx @@ -485,7 +488,8 @@ def build_context(msg: Message): # Reply to original message with transcription/audio parser data self.bus.emit(message.reply(ident, data={"parser_data": parser_data, - "transcripts": transcriptions, + "transcripts": transribed_str, + "transcripts_with_conf": transcriptions, "skills_recv": handled})) except Exception as e: LOG.error(e) @@ -535,7 +539,7 @@ def _write_encoded_file(audio_data: str) -> str: return wav_file_path def _get_stt_from_file(self, wav_file: str, - lang: str = None) -> (AudioData, dict, list): + lang: str = None) -> (AudioData, dict, List[Tuple[str, float]]): """ Performs STT and audio processing on the specified wav_file :param wav_file: wav audio file to process @@ -569,18 +573,18 @@ def _get_stt_from_file(self, wav_file: str, self.api_stt.stream_data(data) except EOFError: break - transcriptions = self.api_stt.stream_stop() + transcriptions = self.api_stt.transcribe(None, None) self.lock.release() else: LOG.error(f"Timed out acquiring lock, not processing: {wav_file}") transcriptions = [] else: - transcriptions = self.api_stt.execute(audio_data, lang) + transcriptions = self.api_stt.transcribe(audio_data, lang) if isinstance(transcriptions, str): - LOG.warning("Transcriptions is a str, no alternatives provided") + LOG.error("Transcriptions is a str, no alternatives provided") transcriptions = [transcriptions] - transcriptions = [clean_quotes(t) for t in transcriptions] + transcriptions = [(clean_quotes(t[0]), t[1]) for t in transcriptions] get_stt = float(_stopwatch.time) with _stopwatch: diff --git a/tests/unit_tests.py b/tests/unit_tests.py index 001eee1..fc02d2d 100644 --- a/tests/unit_tests.py +++ b/tests/unit_tests.py @@ -133,7 +133,8 @@ def test_get_stt_from_file(self): self.assertIsInstance(audio, AudioData) self.assertIsInstance(context, dict) self.assertIsInstance(transcripts, list) - self.assertIn("stop", transcripts) + tr_str = [t[0] for t in transcripts] + self.assertIn("stop", tr_str) def threaded_get_stt(): audio, context, transcripts = \ @@ -141,7 +142,8 @@ def threaded_get_stt(): self.assertIsInstance(audio, AudioData) self.assertIsInstance(context, dict) self.assertIsInstance(transcripts, list) - self.assertIn("stop", transcripts) + tr_str = [t[0] for t in transcripts] + self.assertIn("stop", tr_str) threads = list() for i in range(0, 12):