Skip to content

Commit

Permalink
transport(websocket-server): add on_client_disconnected
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed May 31, 2024
1 parent 38befe1 commit 58d20ec
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/pipecat/transports/network/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class WebsocketServerParams(TransportParams):


class WebsocketServerCallbacks(BaseModel):
on_connection: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]


class WebsocketServerInputTransport(BaseInputTransport):
Expand Down Expand Up @@ -84,7 +85,7 @@ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, p
self._websocket = websocket

# Notify
await self._callbacks.on_connection(websocket)
await self._callbacks.on_client_connected(websocket)

# Handle incoming messages
async for message in websocket:
Expand All @@ -94,6 +95,9 @@ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, p
else:
await self._internal_push_frame(frame)

# Notify disconnection
await self._callbacks.on_client_disconnected(websocket)

await self._websocket.close()
self._websocket = None

Expand All @@ -111,7 +115,7 @@ def __init__(self, params: WebsocketServerParams):

self._audio_buffer = bytes()

async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol):
async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None):
if self._websocket:
await self._websocket.close()
logger.warning("Only one client allowed, using new connection")
Expand Down Expand Up @@ -164,7 +168,8 @@ def __init__(
self._params = params

self._callbacks = WebsocketServerCallbacks(
on_connection=self._on_connection
on_client_connected=self._on_client_connected,
on_client_disconnected=self._on_client_disconnected
)
self._input: WebsocketServerInputTransport | None = None
self._output: WebsocketServerOutputTransport | None = None
Expand All @@ -173,6 +178,7 @@ def __init__(
# Register supported handlers. The user will only be able to register
# these handlers.
self._register_event_handler("on_client_connected")
self._register_event_handler("on_client_disconnected")

def input(self) -> FrameProcessor:
if not self._input:
Expand All @@ -185,9 +191,16 @@ def output(self) -> FrameProcessor:
self._output = WebsocketServerOutputTransport(self._params)
return self._output

async def _on_connection(self, websocket):
async def _on_client_connected(self, websocket):
if self._output:
await self._output.set_client_connection(websocket)
await self._call_event_handler("on_client_connected", websocket)
else:
logger.error("A WebsocketServerTransport output is missing in the pipeline")

async def _on_client_disconnected(self, websocket):
if self._output:
await self._output.set_client_connection(None)
await self._call_event_handler("on_client_disconnected", websocket)
else:
logger.error("A WebsocketServerTransport output is missing in the pipeline")

0 comments on commit 58d20ec

Please sign in to comment.