From 973c992c450d5ee7e760e17666c7b18abefcfd43 Mon Sep 17 00:00:00 2001 From: joachimchauvet Date: Tue, 1 Oct 2024 10:48:34 +0300 Subject: [PATCH] match behavior of Daily's on_first_participant_joined --- src/pipecat/transports/services/livekit.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 52bbbf89d..ef4a6bc0a 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -58,6 +58,7 @@ class LiveKitCallbacks(BaseModel): on_audio_track_subscribed: Callable[[str], Awaitable[None]] on_audio_track_unsubscribed: Callable[[str], Awaitable[None]] on_data_received: Callable[[bytes, str], Awaitable[None]] + on_first_participant_joined: Callable[[str], Awaitable[None]] class LiveKitTransportClient: @@ -83,6 +84,7 @@ def __init__( self._audio_track: rtc.LocalAudioTrack | None = None self._audio_tracks = {} self._audio_queue = asyncio.Queue() + self._other_participant_has_joined = False # Set up room event handlers self._room.on("participant_connected")(self._on_participant_connected_wrapper) @@ -126,6 +128,12 @@ async def connect(self): await self._room.local_participant.publish_track(self._audio_track, options) await self._callbacks.on_connected() + + # Check if there are already participants in the room + participants = self.get_participants() + if participants and not self._other_participant_has_joined: + self._other_participant_has_joined = True + await self._callbacks.on_first_participant_joined(participants[0]) except Exception as e: logger.error(f"Error connecting to {self._room_name}: {e}") raise @@ -230,10 +238,15 @@ def _on_disconnected_wrapper(self): async def _async_on_participant_connected(self, participant: rtc.RemoteParticipant): logger.info(f"Participant connected: {participant.identity}") await self._callbacks.on_participant_connected(participant.sid) + if not self._other_participant_has_joined: + self._other_participant_has_joined = True + await self._callbacks.on_first_participant_joined(participant.sid) async def _async_on_participant_disconnected(self, participant: rtc.RemoteParticipant): logger.info(f"Participant disconnected: {participant.identity}") await self._callbacks.on_participant_disconnected(participant.sid) + if len(self.get_participants()) == 0: + self._other_participant_has_joined = False async def _async_on_track_subscribed( self, @@ -503,6 +516,7 @@ def _create_callbacks(self) -> LiveKitCallbacks: on_audio_track_subscribed=self._on_audio_track_subscribed, on_audio_track_unsubscribed=self._on_audio_track_unsubscribed, on_data_received=self._on_data_received, + on_first_participant_joined=self._on_first_participant_joined, ) def input(self) -> FrameProcessor: @@ -554,8 +568,6 @@ async def _on_disconnected(self): async def _on_participant_connected(self, participant_id: str): await self._call_event_handler("on_participant_connected", participant_id) - if len(self.get_participants()) == 1: - await self._call_event_handler("on_first_participant_joined", participant_id) async def _on_participant_disconnected(self, participant_id: str): await self._call_event_handler("on_participant_disconnected", participant_id) @@ -608,3 +620,6 @@ async def on_track_event(self, event): async def _on_call_state_updated(self, state: str): await self._call_event_handler("on_call_state_updated", self, state) + + async def _on_first_participant_joined(self, participant_id: str): + await self._call_event_handler("on_first_participant_joined", participant_id)