From 20c019ae1612a7aac76994a1842b3a80d407970b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 16 Sep 2024 23:54:21 -0700 Subject: [PATCH] services(cartesia,elevenlabs): close websocket before the receiving task --- src/pipecat/services/cartesia.py | 39 +++++++++++++++--------------- src/pipecat/services/elevenlabs.py | 20 +++++++-------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index a9d5aae67..7b4463812 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -136,24 +136,25 @@ async def _connect(self): ) self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) except Exception as e: - logger.exception(f"{self} initialization error: {e}") + logger.error(f"{self} initialization error: {e}") self._websocket = None async def _disconnect(self): try: await self.stop_all_metrics() + if self._websocket: + await self._websocket.close() + self._websocket = None + if self._receive_task: self._receive_task.cancel() await self._receive_task self._receive_task = None - if self._websocket: - await self._websocket.close() - self._websocket = None self._context_id = None except Exception as e: - logger.exception(f"{self} error closing websocket: {e}") + logger.error(f"{self} error closing websocket: {e}") async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): await super()._handle_interruption(frame, direction) @@ -166,18 +167,18 @@ async def flush_audio(self): return logger.debug("Flushing audio") msg = { - "transcript": "", - "continue": False, - "context_id": self._context_id, - "model_id": self._model_id, - "voice": { - "mode": "id", - "id": self._voice_id - }, - "output_format": self._output_format, - "language": self._language, - "add_timestamps": True, - } + "transcript": "", + "continue": False, + "context_id": self._context_id, + "model_id": self._model_id, + "voice": { + "mode": "id", + "id": self._voice_id + }, + "output_format": self._output_format, + "language": self._language, + "add_timestamps": True, + } await self._websocket.send(json.dumps(msg)) async def _receive_task_handler(self): @@ -217,7 +218,7 @@ async def _receive_task_handler(self): except asyncio.CancelledError: pass except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") @@ -255,4 +256,4 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: return yield None except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 6a3b85914..a7a80033e 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -174,13 +174,18 @@ async def _connect(self): } await self._websocket.send(json.dumps(msg)) except Exception as e: - logger.exception(f"{self} initialization error: {e}") + logger.error(f"{self} initialization error: {e}") self._websocket = None async def _disconnect(self): try: await self.stop_all_metrics() + if self._websocket: + await self._websocket.send(json.dumps({"text": ""})) + await self._websocket.close() + self._websocket = None + if self._receive_task: self._receive_task.cancel() await self._receive_task @@ -191,13 +196,9 @@ async def _disconnect(self): await self._keepalive_task self._keepalive_task = None - if self._websocket: - await self._websocket.close() - self._websocket = None - self._started = False except Exception as e: - logger.exception(f"{self} error closing websocket: {e}") + logger.error(f"{self} error closing websocket: {e}") async def _receive_task_handler(self): try: @@ -215,11 +216,10 @@ async def _receive_task_handler(self): word_times = calculate_word_times(msg["alignment"], self._cumulative_time) await self.add_word_timestamps(word_times) self._cumulative_time = word_times[-1][1] - except asyncio.CancelledError: pass except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") async def _keepalive_task_handler(self): while True: @@ -229,7 +229,7 @@ async def _keepalive_task_handler(self): except asyncio.CancelledError: break except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") async def _send_text(self, text: str): if self._websocket: @@ -260,4 +260,4 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: return yield None except Exception as e: - logger.exception(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}")