Skip to content

Commit

Permalink
Merge pull request #466 from pipecat-ai/aleix/elevenlabs-cartesia-clo…
Browse files Browse the repository at this point in the history
…se-websocket-first

services(cartesia,elevenlabs): close websocket before the receiving task
  • Loading branch information
aconchillo authored Sep 17, 2024
2 parents d9d6571 + 20c019a commit 13a4a05
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
39 changes: 20 additions & 19 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,24 +136,25 @@ async def _connect(self):
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
except Exception as e:
logger.exception(f"{self} initialization error: {e}")
logger.error(f"{self} initialization error: {e}")
self._websocket = None

async def _disconnect(self):
try:
await self.stop_all_metrics()

if self._websocket:
await self._websocket.close()
self._websocket = None

if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None
if self._websocket:
await self._websocket.close()
self._websocket = None

self._context_id = None
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")
logger.error(f"{self} error closing websocket: {e}")

async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
Expand All @@ -166,18 +167,18 @@ async def flush_audio(self):
return
logger.debug("Flushing audio")
msg = {
"transcript": "",
"continue": False,
"context_id": self._context_id,
"model_id": self._model_id,
"voice": {
"mode": "id",
"id": self._voice_id
},
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
}
"transcript": "",
"continue": False,
"context_id": self._context_id,
"model_id": self._model_id,
"voice": {
"mode": "id",
"id": self._voice_id
},
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
}
await self._websocket.send(json.dumps(msg))

async def _receive_task_handler(self):
Expand Down Expand Up @@ -217,7 +218,7 @@ async def _receive_task_handler(self):
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
Expand Down Expand Up @@ -255,4 +256,4 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
return
yield None
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")
20 changes: 10 additions & 10 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,18 @@ async def _connect(self):
}
await self._websocket.send(json.dumps(msg))
except Exception as e:
logger.exception(f"{self} initialization error: {e}")
logger.error(f"{self} initialization error: {e}")
self._websocket = None

async def _disconnect(self):
try:
await self.stop_all_metrics()

if self._websocket:
await self._websocket.send(json.dumps({"text": ""}))
await self._websocket.close()
self._websocket = None

if self._receive_task:
self._receive_task.cancel()
await self._receive_task
Expand All @@ -191,13 +196,9 @@ async def _disconnect(self):
await self._keepalive_task
self._keepalive_task = None

if self._websocket:
await self._websocket.close()
self._websocket = None

self._started = False
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")
logger.error(f"{self} error closing websocket: {e}")

async def _receive_task_handler(self):
try:
Expand All @@ -215,11 +216,10 @@ async def _receive_task_handler(self):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]

except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")

async def _keepalive_task_handler(self):
while True:
Expand All @@ -229,7 +229,7 @@ async def _keepalive_task_handler(self):
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")

async def _send_text(self, text: str):
if self._websocket:
Expand Down Expand Up @@ -260,4 +260,4 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
return
yield None
except Exception as e:
logger.exception(f"{self} exception: {e}")
logger.error(f"{self} exception: {e}")

0 comments on commit 13a4a05

Please sign in to comment.