diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index 1efc13baf..8456c4536 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -32,7 +32,8 @@ class WebsocketServerParams(TransportParams): class WebsocketServerCallbacks(BaseModel): - on_connection: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] class WebsocketServerInputTransport(BaseInputTransport): @@ -84,7 +85,7 @@ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, p self._websocket = websocket # Notify - await self._callbacks.on_connection(websocket) + await self._callbacks.on_client_connected(websocket) # Handle incoming messages async for message in websocket: @@ -94,6 +95,9 @@ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, p else: await self._internal_push_frame(frame) + # Notify disconnection + await self._callbacks.on_client_disconnected(websocket) + await self._websocket.close() self._websocket = None @@ -111,7 +115,7 @@ def __init__(self, params: WebsocketServerParams): self._audio_buffer = bytes() - async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol): + async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None): if self._websocket: await self._websocket.close() logger.warning("Only one client allowed, using new connection") @@ -164,7 +168,8 @@ def __init__( self._params = params self._callbacks = WebsocketServerCallbacks( - on_connection=self._on_connection + on_client_connected=self._on_client_connected, + on_client_disconnected=self._on_client_disconnected ) self._input: WebsocketServerInputTransport | None = None self._output: WebsocketServerOutputTransport | None = None @@ -173,6 +178,7 @@ def __init__( # Register supported handlers. The user will only be able to register # these handlers. self._register_event_handler("on_client_connected") + self._register_event_handler("on_client_disconnected") def input(self) -> FrameProcessor: if not self._input: @@ -185,9 +191,16 @@ def output(self) -> FrameProcessor: self._output = WebsocketServerOutputTransport(self._params) return self._output - async def _on_connection(self, websocket): + async def _on_client_connected(self, websocket): if self._output: await self._output.set_client_connection(websocket) await self._call_event_handler("on_client_connected", websocket) else: logger.error("A WebsocketServerTransport output is missing in the pipeline") + + async def _on_client_disconnected(self, websocket): + if self._output: + await self._output.set_client_connection(None) + await self._call_event_handler("on_client_disconnected", websocket) + else: + logger.error("A WebsocketServerTransport output is missing in the pipeline")