# ---- examples/foundational/01b-livekit-audio.py ----
import argparse
import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from livekit import api  # pip install livekit-api
from loguru import logger

from pipecat.frames.frames import TextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.livekit import LiveKitParams, LiveKitTransport

load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


def generate_token(room_name: str, participant_name: str, api_key: str, api_secret: str) -> str:
    """Return a JWT that lets ``participant_name`` join ``room_name``.

    The participant name is used both as the identity and as the display name
    of the token; the only grant requested is joining the given room.
    """
    join_grant = api.VideoGrants(room_join=True, room=room_name)

    token = api.AccessToken(api_key, api_secret)
    token.with_identity(participant_name).with_name(participant_name).with_grants(join_grant)

    return token.to_jwt()


async def configure_livekit():
    """Resolve the LiveKit connection settings for this example.

    The room name and server URL come from the command line (``-r``/``-u``)
    and fall back to ``LIVEKIT_ROOM_NAME`` / ``LIVEKIT_URL``; the API key and
    secret are read from ``LIVEKIT_API_KEY`` / ``LIVEKIT_API_SECRET``.

    Returns:
        A ``(url, token, room_name)`` tuple, where ``token`` is the bot's token.

    Raises:
        Exception: if the room, the URL or either API credential is missing.
    """
    parser = argparse.ArgumentParser(description="LiveKit AI SDK Bot Sample")
    parser.add_argument(
        "-r", "--room", type=str, required=False, help="Name of the LiveKit room to join"
    )
    parser.add_argument("-u", "--url", type=str, required=False, help="URL of the LiveKit server")

    # Unrecognised arguments are tolerated and ignored.
    args, _ignored = parser.parse_known_args()

    room_name = args.room or os.getenv("LIVEKIT_ROOM_NAME")
    url = args.url or os.getenv("LIVEKIT_URL")
    api_key = os.getenv("LIVEKIT_API_KEY")
    api_secret = os.getenv("LIVEKIT_API_SECRET")

    # Checked in order; the first missing setting aborts with its own message.
    requirements = (
        (
            room_name,
            "No LiveKit room specified. Use the -r/--room option from the command line, or set LIVEKIT_ROOM_NAME in your environment.",
        ),
        (
            url,
            "No LiveKit server URL specified. Use the -u/--url option from the command line, or set LIVEKIT_URL in your environment.",
        ),
        (
            api_key and api_secret,
            "LIVEKIT_API_KEY and LIVEKIT_API_SECRET must be set in environment variables.",
        ),
    )
    for setting, complaint in requirements:
        if not setting:
            raise Exception(complaint)

    token = generate_token(room_name, "Say One Thing", api_key, api_secret)

    # A second token is logged so that a person can join the same room.
    user_token = generate_token(room_name, "User", api_key, api_secret)
    logger.info(f"User token: {user_token}")

    return (url, token, room_name)


async def main():
    """Join the room and speak one sentence when the first participant joins."""
    async with aiohttp.ClientSession() as session:
        url, token, room_name = await configure_livekit()

        transport = LiveKitTransport(
            url=url,
            token=token,
            room_name=room_name,
            params=LiveKitParams(audio_out_enabled=True, audio_out_sample_rate=16000),
        )

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        runner = PipelineRunner()

        # Text goes into the TTS service; the resulting audio goes to the room.
        task = PipelineTask(Pipeline([tts, transport.output()]))

        # Register an event handler so we can play the audio when the
        # participant joins.
        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant_id):
            await asyncio.sleep(1)
            greeting = TextFrame(
                "Hello there! How are you doing today? Would you like to talk about the weather?"
            )
            await task.queue_frame(greeting)

        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())


# ---- pyproject.toml (optional-dependency change carried by this patch) ----
# -livekit = [ "livekit~=0.13.1" ]
# +livekit = [ "livekit~=0.13.1", "tenacity~=9.0.0" ]


# ---- src/pipecat/transports/services/livekit.py ----
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List

import numpy as np
from loguru import logger
from pydantic import BaseModel
from scipy import signal

from pipecat.frames.frames import (
    AudioRawFrame,
    CancelFrame,
    EndFrame,
    Frame,
    MetricsFrame,
    StartFrame,
    TransportMessageFrame,
)
from pipecat.metrics.metrics import (
    LLMUsageMetricsData,
    ProcessingMetricsData,
    TTFBMetricsData,
    TTSUsageMetricsData,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.vad.vad_analyzer import VADAnalyzer

try:
    from livekit import rtc
    from tenacity import retry, stop_after_attempt, wait_exponential
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use LiveKit, you need to `pip install pipecat-ai[livekit]`.")
    raise Exception(f"Missing module: {e}")
@dataclass
class LiveKitTransportMessageFrame(TransportMessageFrame):
    """Transport message frame that can carry a LiveKit participant id."""

    # Filled with the sender by the input transport and used as the recipient
    # by the output transport; None means no specific participant.
    participant_id: str | None = None


class LiveKitParams(TransportParams):
    """Transport parameters with the defaults used by the LiveKit transport."""

    audio_out_sample_rate: int = 48000
    audio_out_channels: int = 1
    vad_enabled: bool = True
    vad_analyzer: VADAnalyzer | None = None
    audio_in_sample_rate: int = 16000


class LiveKitCallbacks(BaseModel):
    """Async callbacks the client invokes to report room activity."""

    on_connected: Callable[[], Awaitable[None]]
    on_disconnected: Callable[[], Awaitable[None]]
    # The single string argument of the callbacks below is a participant sid.
    on_participant_connected: Callable[[str], Awaitable[None]]
    on_participant_disconnected: Callable[[str], Awaitable[None]]
    on_audio_track_subscribed: Callable[[str], Awaitable[None]]
    on_audio_track_unsubscribed: Callable[[str], Awaitable[None]]
    # Called with (payload, sender sid).
    on_data_received: Callable[[bytes, str], Awaitable[None]]


class LiveKitTransportClient:
    """Owns the ``rtc.Room`` connection shared by the input and output transports.

    The client publishes one local audio track for outgoing audio, queues the
    audio frames of every subscribed remote audio track, and forwards room
    events to the supplied ``LiveKitCallbacks``.
    """

    def __init__(
        self,
        url: str,
        token: str,
        room_name: str,
        params: LiveKitParams,
        callbacks: LiveKitCallbacks,
        loop: asyncio.AbstractEventLoop,
    ):
        self._url = url
        self._token = token
        self._room_name = room_name
        self._params = params
        self._callbacks = callbacks
        self._loop = loop
        self._room = rtc.Room(loop=loop)
        self._participant_id: str = ""
        self._connected = False
        # True once disconnect() has been called on purpose; lets the owner
        # tell a requested shutdown apart from a dropped connection.
        self._disconnect_requested = False
        self._audio_source: rtc.AudioSource | None = None
        self._audio_track: rtc.LocalAudioTrack | None = None
        self._audio_tracks = {}
        # NOTE(review): the queue is unbounded; if no input transport consumes
        # it, frames from subscribed tracks accumulate for the whole session.
        self._audio_queue = asyncio.Queue()
        # The event loop only keeps weak references to tasks, so fire-and-forget
        # tasks are held here until they finish.
        self._background_tasks: set[asyncio.Task] = set()

        # Set up room event handlers
        self._room.on("participant_connected")(self._on_participant_connected_wrapper)
        self._room.on("participant_disconnected")(self._on_participant_disconnected_wrapper)
        self._room.on("track_subscribed")(self._on_track_subscribed_wrapper)
        self._room.on("track_unsubscribed")(self._on_track_unsubscribed_wrapper)
        self._room.on("data_received")(self._on_data_received_wrapper)
        self._room.on("connected")(self._on_connected_wrapper)
        self._room.on("disconnected")(self._on_disconnected_wrapper)

    @property
    def participant_id(self) -> str:
        """Sid of the local participant, or "" before the first connection."""
        return self._participant_id

    @property
    def disconnect_requested(self) -> bool:
        """Whether the most recent disconnect was requested via ``disconnect()``."""
        return self._disconnect_requested

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def connect(self):
        """Join the room and publish the outgoing audio track.

        Does nothing when already connected. Failures are logged and re-raised
        so that the ``retry`` decorator can try again (three attempts in total).
        """
        if self._connected:
            return

        logger.info(f"Connecting to {self._room_name}")
        self._disconnect_requested = False
        room_joined = False

        try:
            await self._room.connect(
                self._url,
                self._token,
                options=rtc.RoomOptions(auto_subscribe=True),
            )
            room_joined = True
            self._participant_id = self._room.local_participant.sid
            logger.info(f"Connected to {self._room_name}")

            # Set up audio source and track
            self._audio_source = rtc.AudioSource(
                self._params.audio_out_sample_rate, self._params.audio_out_channels
            )
            self._audio_track = rtc.LocalAudioTrack.create_audio_track(
                "pipecat-audio", self._audio_source
            )
            options = rtc.TrackPublishOptions()
            options.source = rtc.TrackSource.SOURCE_MICROPHONE
            await self._room.local_participant.publish_track(self._audio_track, options)

            # Only report success once the track is published; otherwise a
            # retry would hit the early return above with no audio track.
            self._connected = True

            await self._callbacks.on_connected()
        except Exception as e:
            logger.error(f"Error connecting to {self._room_name}: {e}")
            if room_joined and not self._connected:
                # Undo the partial setup so the next attempt starts clean.
                self._audio_source = None
                self._audio_track = None
                try:
                    await self._room.disconnect()
                except Exception as cleanup_error:
                    logger.warning(f"Error cleaning up failed connection: {cleanup_error}")
            raise

    async def disconnect(self):
        """Leave the room on purpose and notify ``on_disconnected``."""
        if not self._connected:
            return

        logger.info(f"Disconnecting from {self._room_name}")
        # Set before anything else so callbacks triggered by the shutdown can
        # see that it was requested.
        self._disconnect_requested = True
        await self._room.disconnect()
        self._connected = False
        logger.info(f"Disconnected from {self._room_name}")
        await self._callbacks.on_disconnected()

    async def send_data(self, data: bytes, participant_id: str | None = None):
        """Publish ``data`` reliably, to one participant or to the whole room.

        Silently does nothing while disconnected; publish errors are logged.
        """
        if not self._connected:
            return

        try:
            if participant_id:
                await self._room.local_participant.publish_data(
                    data, reliable=True, destination_identities=[participant_id]
                )
            else:
                await self._room.local_participant.publish_data(data, reliable=True)
        except Exception as e:
            logger.error(f"Error sending data: {e}")

    async def publish_audio(self, audio_frame: rtc.AudioFrame):
        """Send one audio frame through the published track (errors are logged)."""
        if not self._connected or not self._audio_source:
            return

        try:
            await self._audio_source.capture_frame(audio_frame)
        except Exception as e:
            logger.error(f"Error publishing audio: {e}")

    def get_participants(self) -> List[str]:
        """Return the sids of the remote participants currently in the room."""
        return [p.sid for p in self._room.remote_participants.values()]

    async def get_participant_metadata(self, participant_id: str) -> dict:
        """Return id, name, metadata and speaking state, or {} if unknown."""
        participant = self._room.remote_participants.get(participant_id)
        if participant:
            return {
                "id": participant.sid,
                "name": participant.name,
                "metadata": participant.metadata,
                "is_speaking": participant.is_speaking,
            }
        return {}

    async def set_participant_metadata(self, metadata: str):
        """Set the metadata of the local participant."""
        await self._room.local_participant.set_metadata(metadata)

    async def mute_participant(self, participant_id: str):
        """Disable the audio tracks of a remote participant, if present."""
        await self._set_participant_audio_enabled(participant_id, False)

    async def unmute_participant(self, participant_id: str):
        """Enable the audio tracks of a remote participant, if present."""
        await self._set_participant_audio_enabled(participant_id, True)

    async def _set_participant_audio_enabled(self, participant_id: str, enabled: bool):
        """Shared implementation of mute_participant/unmute_participant."""
        participant = self._room.remote_participants.get(participant_id)
        if not participant:
            return
        for track in participant.tracks.values():
            # NOTE(review): the rest of this file compares kinds against
            # rtc.TrackKind.KIND_AUDIO; confirm that "audio" can ever match and
            # that set_enabled() exists on these objects before relying on this.
            if track.kind == "audio":
                await track.set_enabled(enabled)

    def _spawn(self, coro) -> asyncio.Task:
        """Run ``coro`` as a task that is kept alive until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task):
        """Release a finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"Error in LiveKit background task: {error}")

    # Wrapper methods for event handlers: room events are delivered to plain
    # callables, so each wrapper schedules the matching coroutine.
    def _on_participant_connected_wrapper(self, participant: rtc.RemoteParticipant):
        self._spawn(self._async_on_participant_connected(participant))

    def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipant):
        self._spawn(self._async_on_participant_disconnected(participant))

    def _on_track_subscribed_wrapper(
        self,
        track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        self._spawn(self._async_on_track_subscribed(track, publication, participant))

    def _on_track_unsubscribed_wrapper(
        self,
        track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        self._spawn(self._async_on_track_unsubscribed(track, publication, participant))

    def _on_data_received_wrapper(self, data: rtc.DataPacket):
        self._spawn(self._async_on_data_received(data))

    def _on_connected_wrapper(self, *_args):
        # Extra event arguments, if the SDK passes any, are ignored.
        self._spawn(self._async_on_connected())

    def _on_disconnected_wrapper(self, reason=None):
        # Accepts the optional reason so the handler works whether or not the
        # SDK passes one with the event.
        self._spawn(self._async_on_disconnected(reason))

    # Async methods for event handling
    async def _async_on_participant_connected(self, participant: rtc.RemoteParticipant):
        logger.info(f"Participant connected: {participant.identity}")
        await self._callbacks.on_participant_connected(participant.sid)

    async def _async_on_participant_disconnected(self, participant: rtc.RemoteParticipant):
        logger.info(f"Participant disconnected: {participant.identity}")
        await self._callbacks.on_participant_disconnected(participant.sid)

    async def _async_on_track_subscribed(
        self,
        track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        """Start reading a newly subscribed audio track; other kinds are ignored."""
        if track.kind == rtc.TrackKind.KIND_AUDIO:
            logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}")
            self._audio_tracks[participant.sid] = track
            audio_stream = rtc.AudioStream(track)
            self._spawn(self._process_audio_stream(audio_stream, participant.sid))
            # Mirror the unsubscribe path so the registered callback fires.
            await self._callbacks.on_audio_track_subscribed(participant.sid)

    async def _async_on_track_unsubscribed(
        self,
        track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        """Forget an unsubscribed audio track and notify the callback."""
        logger.info(f"Track unsubscribed: {publication.sid} from {participant.identity}")
        if track.kind == rtc.TrackKind.KIND_AUDIO:
            self._audio_tracks.pop(participant.sid, None)
            await self._callbacks.on_audio_track_unsubscribed(participant.sid)

    async def _async_on_data_received(self, data: rtc.DataPacket):
        """Forward a data packet to the callback with the sender's sid."""
        # A packet without a sending participant is reported with an empty sid
        # instead of failing on the attribute access.
        sender = getattr(data, "participant", None)
        sender_id = sender.sid if sender is not None else ""
        await self._callbacks.on_data_received(data.data, sender_id)

    async def _async_on_connected(self):
        await self._callbacks.on_connected()

    async def _async_on_disconnected(self, reason=None):
        self._connected = False
        logger.info(f"Disconnected from {self._room_name}. Reason: {reason}")
        await self._callbacks.on_disconnected()

    async def _process_audio_stream(self, audio_stream: rtc.AudioStream, participant_id: str):
        """Queue every audio frame event of ``audio_stream`` with its participant id."""
        logger.info(f"Started processing audio stream for participant {participant_id}")
        async for event in audio_stream:
            if isinstance(event, rtc.AudioFrameEvent):
                await self._audio_queue.put((event, participant_id))
            else:
                logger.warning(f"Received unexpected event type: {type(event)}")

    async def cleanup(self):
        """Disconnect from the room."""
        await self.disconnect()

    async def get_next_audio_frame(self):
        """Wait for and return the next queued ``(frame event, participant id)``."""
        frame, participant_id = await self._audio_queue.get()
        return frame, participant_id


class LiveKitInputTransport(BaseInputTransport):
    """Pipeline input that turns remote LiveKit audio into ``AudioRawFrame``s."""

    def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
        super().__init__(params, **kwargs)
        self._client = client
        self._audio_in_task = None
        self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
        self._current_sample_rate: int = params.audio_in_sample_rate
        if params.vad_enabled and not params.vad_analyzer:
            # NOTE(review): confirm VADAnalyzer can be instantiated directly;
            # if it is abstract this default path raises.
            self._vad_analyzer = VADAnalyzer(
                sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
            )

    async def start(self, frame: StartFrame):
        """Connect the shared client and start the audio reader task if needed."""
        await super().start(frame)
        await self._client.connect()
        wants_audio = self._params.audio_in_enabled or self._params.vad_enabled
        # Guard against a second start() spawning a second reader task.
        if wants_audio and self._audio_in_task is None:
            self._audio_in_task = asyncio.create_task(self._audio_in_task_handler())
        logger.info("LiveKitInputTransport started")

    async def stop(self, frame: EndFrame):
        """Stop the audio reader task, then stop the base transport and disconnect."""
        await self._cancel_audio_task()
        await super().stop(frame)
        await self._client.disconnect()
        logger.info("LiveKitInputTransport stopped")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        # NOTE(review): EndFrame is consumed here and never handed to the base
        # class; confirm the frame does not also need to be forwarded.
        if isinstance(frame, EndFrame):
            await self.stop(frame)
        else:
            await super().process_frame(frame, direction)

    async def cancel(self, frame: CancelFrame):
        """Cancel the base transport, disconnect and stop the audio reader task."""
        await super().cancel(frame)
        await self._client.disconnect()
        await self._cancel_audio_task()

    async def _cancel_audio_task(self):
        """Cancel the audio reader task, wait for it to finish and forget it."""
        task, self._audio_in_task = self._audio_in_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def vad_analyzer(self) -> VADAnalyzer | None:
        return self._vad_analyzer

    async def push_app_message(self, message: Any, sender: str):
        """Push ``message`` into the pipeline, tagged with the sender's id."""
        frame = LiveKitTransportMessageFrame(message=message, participant_id=sender)
        await self.push_frame(frame)

    async def _audio_in_task_handler(self):
        """Read queued LiveKit audio forever and push it as pipecat frames."""
        logger.info("Audio input task started")
        while True:
            try:
                audio_data = await self._client.get_next_audio_frame()
                if audio_data:
                    audio_frame_event, participant_id = audio_data
                    pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
                    await self.push_audio_frame(pipecat_audio_frame)
                    # TODO: ensure audio frames are pushed with the default
                    # BaseInputTransport.push_audio_frame() so this second,
                    # direct push can be removed.
                    await self.push_frame(pipecat_audio_frame)
            except asyncio.CancelledError:
                logger.info("Audio input task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in audio input task: {e}")

    def _convert_livekit_audio_to_pipecat(
        self, audio_frame_event: rtc.AudioFrameEvent
    ) -> AudioRawFrame:
        """Convert a LiveKit audio frame event into an ``AudioRawFrame``.

        Audio at 8 kHz or 16 kHz is passed through; any other rate is
        resampled to 16 kHz. The frame data is read as 16-bit samples.
        """
        audio_frame = audio_frame_event.frame
        audio_data = np.frombuffer(audio_frame.data, dtype=np.int16)
        sample_rate = audio_frame.sample_rate

        # Allow 8kHz and 16kHz, convert anything else to 16kHz
        if sample_rate not in (8000, 16000):
            audio_data = self._resample_audio(
                audio_data, sample_rate, 16000, audio_frame.num_channels
            )
            sample_rate = 16000

        if sample_rate != self._current_sample_rate:
            self._current_sample_rate = sample_rate
            # NOTE(review): this replaces any analyzer supplied via params;
            # confirm that is intended.
            self._vad_analyzer = VADAnalyzer(
                sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
            )

        return AudioRawFrame(
            audio=audio_data.tobytes(),
            sample_rate=sample_rate,
            num_channels=audio_frame.num_channels,
        )

    def _resample_audio(
        self,
        audio_data: np.ndarray,
        original_rate: int,
        target_rate: int,
        num_channels: int = 1,
    ) -> np.ndarray:
        """Resample 16-bit PCM from ``original_rate`` to ``target_rate``.

        Args:
            audio_data: flat array of samples, interleaved when multi-channel.
            original_rate: sample rate of ``audio_data`` in Hz.
            target_rate: desired sample rate in Hz.
            num_channels: number of interleaved channels in ``audio_data``.

        Returns:
            A flat ``int16`` array in the same channel layout as the input.
        """
        if audio_data.size == 0:
            return audio_data.astype(np.int16)

        # Resample each channel on its own: treating interleaved samples as a
        # single signal would mix the channels together.
        interleaved = num_channels > 1 and audio_data.size % num_channels == 0
        frames = audio_data.reshape(-1, num_channels) if interleaved else audio_data

        num_samples = int(frames.shape[0] * target_rate / original_rate)
        if num_samples <= 0:
            return np.zeros(0, dtype=np.int16)

        resampled_audio = signal.resample(frames, num_samples, axis=0)
        # The resampler can overshoot the int16 range; clip so the cast
        # saturates instead of wrapping around.
        limits = np.iinfo(np.int16)
        clipped = np.clip(resampled_audio, limits.min, limits.max)
        return clipped.astype(np.int16).reshape(-1)


class LiveKitOutputTransport(BaseOutputTransport):
    """Pipeline output that publishes audio, messages and metrics to LiveKit."""

    def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwargs):
        super().__init__(params, **kwargs)
        self._client = client

    async def start(self, frame: StartFrame):
        await super().start(frame)
        await self._client.connect()
        logger.info("LiveKitOutputTransport started")

    async def stop(self, frame: EndFrame):
        await super().stop(frame)
        await self._client.disconnect()
        logger.info("LiveKitOutputTransport stopped")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        # NOTE(review): EndFrame is consumed here and never handed to the base
        # class; confirm the frame does not also need to be forwarded.
        if isinstance(frame, EndFrame):
            await self.stop(frame)
        else:
            await super().process_frame(frame, direction)

    async def cancel(self, frame: CancelFrame):
        await super().cancel(frame)
        await self._client.disconnect()

    @staticmethod
    def _encode_message(message: Any) -> bytes:
        """Encode a message payload as bytes for the data channel.

        Bytes pass through, strings are encoded, and anything else is
        serialized as JSON.

        Raises:
            TypeError, ValueError: if the payload cannot be serialized as JSON.
        """
        import json

        if isinstance(message, (bytes, bytearray)):
            return bytes(message)
        if isinstance(message, str):
            return message.encode()
        return json.dumps(message).encode()

    async def send_message(self, frame: TransportMessageFrame):
        """Send the frame's message, addressed to its participant if it has one."""
        recipient = (
            frame.participant_id if isinstance(frame, LiveKitTransportMessageFrame) else None
        )
        await self._client.send_data(self._encode_message(frame.message), recipient)

    async def send_metrics(self, frame: MetricsFrame):
        """Send the frame's metrics to the room as a "pipecat-metrics" JSON message."""
        metrics: dict[str, list] = {}
        for d in frame.data:
            if isinstance(d, TTFBMetricsData):
                metrics.setdefault("ttfb", []).append(d.model_dump())
            elif isinstance(d, ProcessingMetricsData):
                metrics.setdefault("processing", []).append(d.model_dump())
            elif isinstance(d, LLMUsageMetricsData):
                metrics.setdefault("tokens", []).append(d.value.model_dump(exclude_none=True))
            elif isinstance(d, TTSUsageMetricsData):
                metrics.setdefault("characters", []).append(d.model_dump())

        payload = {"type": "pipecat-metrics", "metrics": metrics}
        try:
            # JSON rather than str(dict): a Python repr is not parseable by
            # the receiving side.
            data = self._encode_message(payload)
        except (TypeError, ValueError) as e:
            # Metrics are best-effort; never let them break the pipeline.
            logger.error(f"Unable to serialize metrics: {e}")
            return
        await self._client.send_data(data)

    async def write_raw_audio_frames(self, frames: bytes):
        """Publish raw PCM bytes to the room."""
        livekit_audio = self._convert_pipecat_audio_to_livekit(frames)
        await self._client.publish_audio(livekit_audio)

    def _convert_pipecat_audio_to_livekit(self, pipecat_audio: bytes) -> rtc.AudioFrame:
        """Wrap raw PCM bytes in an ``rtc.AudioFrame`` using the output params."""
        bytes_per_sample = 2  # Assuming 16-bit audio
        total_samples = len(pipecat_audio) // bytes_per_sample
        samples_per_channel = total_samples // self._params.audio_out_channels

        return rtc.AudioFrame(
            data=pipecat_audio,
            sample_rate=self._params.audio_out_sample_rate,
            num_channels=self._params.audio_out_channels,
            samples_per_channel=samples_per_channel,
        )


class LiveKitTransport(BaseTransport):
    """Pipecat transport that joins a LiveKit room.

    ``input()`` and ``output()`` lazily create processors that share a single
    ``LiveKitTransportClient``. Room activity is exposed through the event
    handlers registered in ``__init__``.
    """

    def __init__(
        self,
        url: str,
        token: str,
        room_name: str,
        params: LiveKitParams | None = None,
        input_name: str | None = None,
        output_name: str | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ):
        super().__init__(input_name=input_name, output_name=output_name, loop=loop)

        self._url = url
        self._token = token
        self._room_name = room_name
        # A fresh params object per transport: a default instance in the
        # signature would be shared by every transport created without params.
        self._params = params if params is not None else LiveKitParams()

        self._client = LiveKitTransportClient(
            url, token, room_name, self._params, self._create_callbacks(), self._loop
        )
        self._input: LiveKitInputTransport | None = None
        self._output: LiveKitOutputTransport | None = None

        for event_name in (
            "on_connected",
            "on_disconnected",
            "on_participant_connected",
            "on_participant_disconnected",
            "on_audio_track_subscribed",
            "on_audio_track_unsubscribed",
            "on_data_received",
            "on_first_participant_joined",
            "on_participant_left",
            "on_call_state_updated",
        ):
            self._register_event_handler(event_name)

    def _create_callbacks(self) -> LiveKitCallbacks:
        """Bind the client's callbacks to this transport's private handlers."""
        return LiveKitCallbacks(
            on_connected=self._on_connected,
            on_disconnected=self._on_disconnected,
            on_participant_connected=self._on_participant_connected,
            on_participant_disconnected=self._on_participant_disconnected,
            on_audio_track_subscribed=self._on_audio_track_subscribed,
            on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
            on_data_received=self._on_data_received,
        )

    def input(self) -> FrameProcessor:
        """Return the input processor, creating it on first use."""
        if not self._input:
            self._input = LiveKitInputTransport(self._client, self._params, name=self._input_name)
        return self._input

    def output(self) -> FrameProcessor:
        """Return the output processor, creating it on first use."""
        if not self._output:
            self._output = LiveKitOutputTransport(
                self._client, self._params, name=self._output_name
            )
        return self._output

    @property
    def participant_id(self) -> str:
        return self._client.participant_id

    async def send_audio(self, frame: AudioRawFrame):
        """Send an audio frame through the output processor, if it exists."""
        if self._output:
            await self._output.process_frame(frame, FrameDirection.DOWNSTREAM)

    def get_participants(self) -> List[str]:
        return self._client.get_participants()

    async def get_participant_metadata(self, participant_id: str) -> dict:
        return await self._client.get_participant_metadata(participant_id)

    async def set_metadata(self, metadata: str):
        await self._client.set_participant_metadata(metadata)

    async def mute_participant(self, participant_id: str):
        await self._client.mute_participant(participant_id)

    async def unmute_participant(self, participant_id: str):
        await self._client.unmute_participant(participant_id)

    async def _on_connected(self):
        await self._call_event_handler("on_connected")

    async def _on_disconnected(self):
        """Fire ``on_disconnected`` and reconnect after an unexpected drop."""
        await self._call_event_handler("on_disconnected")
        # A requested disconnect (stop, cancel, cleanup) must stay
        # disconnected; only a dropped connection is retried.
        if self._client.disconnect_requested:
            return
        try:
            # connect() reports success through the on_connected callback, so
            # the event is not fired a second time here.
            await self._client.connect()
        except Exception as e:
            logger.error(f"Failed to reconnect: {e}")

    async def _on_participant_connected(self, participant_id: str):
        await self._call_event_handler("on_participant_connected", participant_id)
        if len(self.get_participants()) == 1:
            await self._call_event_handler("on_first_participant_joined", participant_id)

    async def _on_participant_disconnected(self, participant_id: str):
        await self._call_event_handler("on_participant_disconnected", participant_id)
        await self._call_event_handler("on_participant_left", participant_id, "disconnected")
        # NOTE(review): this ends both processors as soon as any participant
        # leaves, even if others remain in the room; confirm that is intended.
        if self._input:
            await self._input.process_frame(EndFrame(), FrameDirection.DOWNSTREAM)
        if self._output:
            await self._output.process_frame(EndFrame(), FrameDirection.DOWNSTREAM)

    async def _on_audio_track_subscribed(self, participant_id: str):
        # The client is already reading the track when it reports the
        # subscription, so only the event is fired here.
        await self._call_event_handler("on_audio_track_subscribed", participant_id)

    async def _on_audio_track_unsubscribed(self, participant_id: str):
        await self._call_event_handler("on_audio_track_unsubscribed", participant_id)

    async def _on_data_received(self, data: bytes, participant_id: str):
        """Push received data into the pipeline and fire ``on_data_received``."""
        if self._input:
            try:
                message = data.decode()
            except UnicodeDecodeError as e:
                # Undecodable payloads are skipped for the pipeline but still
                # reach the event handler below as raw bytes.
                logger.warning(f"Received data that is not valid text: {e}")
            else:
                await self._input.push_app_message(message, participant_id)
        await self._call_event_handler("on_data_received", data, participant_id)

    async def send_message(self, message: str, participant_id: str | None = None):
        """Send ``message`` to one participant, or to the room when no id is given."""
        if self._output:
            frame = LiveKitTransportMessageFrame(message=message, participant_id=participant_id)
            await self._output.send_message(frame)

    async def cleanup(self):
        if self._input:
            await self._input.cleanup()
        if self._output:
            await self._output.cleanup()
        await self._client.disconnect()

    async def on_room_event(self, event):
        # Handle room events
        pass

    async def on_participant_event(self, event):
        # Handle participant events
        pass

    async def on_track_event(self, event):
        # Handle track events
        pass

    async def _on_call_state_updated(self, state: str):
        # Called like every other event in this class: the transport itself is
        # not passed as an explicit argument.
        await self._call_event_handler("on_call_state_updated", state)