Skip to content

Commit

Permalink
riva: make sure we don't block on fastpitch
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed Dec 13, 2024
1 parent 8f24ca4 commit aac907a
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/pipecat/services/riva.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
)
raise Exception(f"Missing module: {e}")

FASTPITCH_TIMEOUT_SECS = 5


class FastPitchTTSService(TTSService):
class InputParams(BaseModel):
Expand Down Expand Up @@ -102,20 +104,23 @@ def add_response(r):

logger.debug(f"Generating TTS: [{text}]")

queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)

# Wait for the thread to start.
resp = await queue.get()
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self._sample_rate,
num_channels=1,
)
yield frame
resp = await queue.get()
try:
queue = asyncio.Queue()
await asyncio.to_thread(read_audio_responses, queue)

# Wait for the thread to start.
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
while resp:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=resp.audio,
sample_rate=self._sample_rate,
num_channels=1,
)
yield frame
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")

await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()
Expand Down

0 comments on commit aac907a

Please sign in to comment.