From b6e1d6e6aedd61831e6b6cef6cb1a1b462891bb9 Mon Sep 17 00:00:00 2001 From: joachimchauvet Date: Tue, 24 Sep 2024 10:21:02 +0300 Subject: [PATCH] format with ruff --- examples/foundational/01b-livekit-audio.py | 19 ++++-- src/pipecat/transports/services/livekit.py | 70 +++++++++++++++++----- 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/examples/foundational/01b-livekit-audio.py b/examples/foundational/01b-livekit-audio.py index 8b3f38cd8..68e0d2803 100644 --- a/examples/foundational/01b-livekit-audio.py +++ b/examples/foundational/01b-livekit-audio.py @@ -35,7 +35,9 @@ def generate_token(room_name: str, participant_name: str, api_key: str, api_secr async def configure_livekit(): parser = argparse.ArgumentParser(description="LiveKit AI SDK Bot Sample") - parser.add_argument("-r", "--room", type=str, required=False, help="Name of the LiveKit room to join") + parser.add_argument( + "-r", "--room", type=str, required=False, help="Name of the LiveKit room to join" + ) parser.add_argument("-u", "--url", type=str, required=False, help="URL of the LiveKit server") args, unknown = parser.parse_known_args() @@ -56,7 +58,9 @@ async def configure_livekit(): ) if not api_key or not api_secret: - raise Exception("LIVEKIT_API_KEY and LIVEKIT_API_SECRET must be set in environment variables.") + raise Exception( + "LIVEKIT_API_KEY and LIVEKIT_API_SECRET must be set in environment variables." + ) token = generate_token(room_name, "Say One Thing", api_key, api_secret) @@ -71,7 +75,10 @@ async def main(): (url, token, room_name) = await configure_livekit() transport = LiveKitTransport( - url=url, token=token, room_name=room_name, params=LiveKitParams(audio_out_enabled=True, audio_out_sample_rate=16000) + url=url, + token=token, + room_name=room_name, + params=LiveKitParams(audio_out_enabled=True, audio_out_sample_rate=16000), ) tts = CartesiaTTSService( @@ -88,7 +95,11 @@ async def main(): @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant_id): await asyncio.sleep(1) - await task.queue_frame(TextFrame("Hello there! How are you doing today? Would you like to talk about the weather?")) + await task.queue_frame( + TextFrame( + "Hello there! How are you doing today? Would you like to talk about the weather?" + ) + ) await runner.run(task) diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index b0c630c90..1dcb069ee 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -109,8 +109,12 @@ async def connect(self): logger.info(f"Connected to {self._room_name}") # Set up audio source and track - self._audio_source = rtc.AudioSource(self._params.audio_out_sample_rate, self._params.audio_out_channels) - self._audio_track = rtc.LocalAudioTrack.create_audio_track("pipecat-audio", self._audio_source) + self._audio_source = rtc.AudioSource( + self._params.audio_out_sample_rate, self._params.audio_out_channels + ) + self._audio_track = rtc.LocalAudioTrack.create_audio_track( + "pipecat-audio", self._audio_source + ) options = rtc.TrackPublishOptions() options.source = rtc.TrackSource.SOURCE_MICROPHONE await self._room.local_participant.publish_track(self._audio_track, options) @@ -136,7 +140,9 @@ async def send_data(self, data: bytes, participant_id: str | None = None): try: if participant_id: - await self._room.local_participant.publish_data(data, reliable=True, destination_identities=[participant_id]) + await self._room.local_participant.publish_data( + data, reliable=True, destination_identities=[participant_id] + ) else: await self._room.local_participant.publish_data(data, reliable=True) except Exception as e: @@ -190,12 +196,18 @@ def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipan asyncio.create_task(self._async_on_participant_disconnected(participant)) def _on_track_subscribed_wrapper( - self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant + self, + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + participant: rtc.RemoteParticipant, ): asyncio.create_task(self._async_on_track_subscribed(track, publication, participant)) def _on_track_unsubscribed_wrapper( - self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant + self, + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + participant: rtc.RemoteParticipant, ): asyncio.create_task(self._async_on_track_unsubscribed(track, publication, participant)) @@ -218,7 +230,10 @@ async def _async_on_participant_disconnected(self, participant: rtc.RemotePartic await self._callbacks.on_participant_disconnected(participant.sid) async def _async_on_track_subscribed( - self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant + self, + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + participant: rtc.RemoteParticipant, ): if track.kind == rtc.TrackKind.KIND_AUDIO: logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}") @@ -227,7 +242,10 @@ async def _async_on_track_subscribed( asyncio.create_task(self._process_audio_stream(audio_stream, participant.sid)) async def _async_on_track_unsubscribed( - self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant + self, + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + participant: rtc.RemoteParticipant, ): logger.info(f"Track unsubscribed: {publication.sid} from {participant.identity}") if track.kind == rtc.TrackKind.KIND_AUDIO: @@ -268,7 +286,9 @@ def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwar self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer self._current_sample_rate: int = params.audio_in_sample_rate if params.vad_enabled and not params.vad_analyzer: - self._vad_analyzer = VADAnalyzer(sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels) + self._vad_analyzer = VADAnalyzer( + sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels + ) async def start(self, frame: StartFrame): await super().start(frame) @@ -326,7 +346,9 @@ async def _audio_in_task_handler(self): except Exception as e: logger.error(f"Error in audio input task: {e}") - def _convert_livekit_audio_to_pipecat(self, audio_frame_event: rtc.AudioFrameEvent) -> AudioRawFrame: + def _convert_livekit_audio_to_pipecat( + self, audio_frame_event: rtc.AudioFrameEvent + ) -> AudioRawFrame: audio_frame = audio_frame_event.frame audio_data = np.frombuffer(audio_frame.data, dtype=np.int16) original_sample_rate = audio_frame.sample_rate @@ -340,11 +362,19 @@ def _convert_livekit_audio_to_pipecat(self, audio_frame_event: rtc.AudioFrameEve if sample_rate != self._current_sample_rate: self._current_sample_rate = sample_rate - self._vad_analyzer = VADAnalyzer(sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels) + self._vad_analyzer = VADAnalyzer( + sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels + ) - return AudioRawFrame(audio=audio_data.tobytes(), sample_rate=sample_rate, num_channels=audio_frame.num_channels) + return AudioRawFrame( + audio=audio_data.tobytes(), + sample_rate=sample_rate, + num_channels=audio_frame.num_channels, + ) - def _resample_audio(self, audio_data: np.ndarray, original_rate: int, target_rate: int) -> np.ndarray: + def _resample_audio( + self, audio_data: np.ndarray, original_rate: int, target_rate: int + ) -> np.ndarray: num_samples = int(len(audio_data) * target_rate / original_rate) resampled_audio = signal.resample(audio_data, num_samples) return resampled_audio.astype(np.int16) @@ -392,7 +422,9 @@ async def send_metrics(self, frame: MetricsFrame): if hasattr(frame, "characters"): metrics["characters"] = frame.characters - message = LiveKitTransportMessageFrame(message={"type": "pipecat-metrics", "metrics": metrics}) + message = LiveKitTransportMessageFrame( + message={"type": "pipecat-metrics", "metrics": metrics} + ) await self._client.send_data(str(message.message).encode()) async def write_raw_audio_frames(self, frames: bytes): @@ -430,7 +462,9 @@ def __init__( self._room_name = room_name self._params = params - self._client = LiveKitTransportClient(url, token, room_name, self._params, self._create_callbacks(), self._loop) + self._client = LiveKitTransportClient( + url, token, room_name, self._params, self._create_callbacks(), self._loop + ) self._input: LiveKitInputTransport | None = None self._output: LiveKitOutputTransport | None = None @@ -463,7 +497,9 @@ def input(self) -> FrameProcessor: def output(self) -> FrameProcessor: if not self._output: - self._output = LiveKitOutputTransport(self._client, self._params, name=self._output_name) + self._output = LiveKitOutputTransport( + self._client, self._params, name=self._output_name + ) return self._output @property @@ -519,7 +555,9 @@ async def _on_audio_track_subscribed(self, participant_id: str): participant = self._client._room.remote_participants.get(participant_id) if participant: for publication in participant.audio_tracks.values(): - self._client._on_track_subscribed_wrapper(publication.track, publication, participant) + self._client._on_track_subscribed_wrapper( + publication.track, publication, participant + ) async def _on_audio_track_unsubscribed(self, participant_id: str): await self._call_event_handler("on_audio_track_unsubscribed", participant_id)