From 420ce16807eb5ac21946e2b2dfa0a819f45b1344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 12 Dec 2024 22:15:44 -0800 Subject: [PATCH] riva: fix FastPitchTTSService audio stuttering --- CHANGELOG.md | 2 ++ src/pipecat/services/riva.py | 21 +++++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ca41b8a4..22e08ef4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an audio stuttering issue in `FastPitchTTSService`. + - Fixed a `BaseOutputTransport` issue that was causing non-audio frames being processed before the previous audio frames were played. This will allow, for example, sending a frame `A` after a `TTSSpeakFrame` and the frame `A` will diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva.py index 6be722d49..470fe6dc9 100644 --- a/src/pipecat/services/riva.py +++ b/src/pipecat/services/riva.py @@ -76,7 +76,10 @@ def __init__( ) async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - def read_audio_responses(): + def read_audio_responses(queue: asyncio.Queue): + def add_response(r): + asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop()) + try: responses = self._service.synthesize_online( text, @@ -87,26 +90,32 @@ def read_audio_responses(): quality=self._quality, custom_dictionary={}, ) - return responses + for r in responses: + add_response(r) + add_response(None) except Exception as e: logger.error(f"{self} exception: {e}") - return [] + add_response(None) await self.start_ttfb_metrics() yield TTSStartedFrame() logger.debug(f"Generating TTS: [{text}]") - responses = await asyncio.to_thread(read_audio_responses) - for resp in responses: - await self.stop_ttfb_metrics() + queue = asyncio.Queue() + await asyncio.to_thread(read_audio_responses, queue) + # Wait for the thread to start. + resp = await queue.get() + while resp: + await self.stop_ttfb_metrics() frame = TTSAudioRawFrame( audio=resp.audio, sample_rate=self._sample_rate, num_channels=1, ) yield frame + resp = await queue.get() await self.start_tts_usage_metrics(text) yield TTSStoppedFrame()