Skip to content

Commit

Permalink
format with ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
joachimchauvet committed Sep 24, 2024
1 parent fa609f1 commit b6e1d6e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 20 deletions.
19 changes: 15 additions & 4 deletions examples/foundational/01b-livekit-audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def generate_token(room_name: str, participant_name: str, api_key: str, api_secr

async def configure_livekit():
parser = argparse.ArgumentParser(description="LiveKit AI SDK Bot Sample")
parser.add_argument("-r", "--room", type=str, required=False, help="Name of the LiveKit room to join")
parser.add_argument(
"-r", "--room", type=str, required=False, help="Name of the LiveKit room to join"
)
parser.add_argument("-u", "--url", type=str, required=False, help="URL of the LiveKit server")

args, unknown = parser.parse_known_args()
Expand All @@ -56,7 +58,9 @@ async def configure_livekit():
)

if not api_key or not api_secret:
raise Exception("LIVEKIT_API_KEY and LIVEKIT_API_SECRET must be set in environment variables.")
raise Exception(
"LIVEKIT_API_KEY and LIVEKIT_API_SECRET must be set in environment variables."
)

token = generate_token(room_name, "Say One Thing", api_key, api_secret)

Expand All @@ -71,7 +75,10 @@ async def main():
(url, token, room_name) = await configure_livekit()

transport = LiveKitTransport(
url=url, token=token, room_name=room_name, params=LiveKitParams(audio_out_enabled=True, audio_out_sample_rate=16000)
url=url,
token=token,
room_name=room_name,
params=LiveKitParams(audio_out_enabled=True, audio_out_sample_rate=16000),
)

tts = CartesiaTTSService(
Expand All @@ -88,7 +95,11 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant_id):
await asyncio.sleep(1)
await task.queue_frame(TextFrame("Hello there! How are you doing today? Would you like to talk about the weather?"))
await task.queue_frame(
TextFrame(
"Hello there! How are you doing today? Would you like to talk about the weather?"
)
)

await runner.run(task)

Expand Down
70 changes: 54 additions & 16 deletions src/pipecat/transports/services/livekit.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,12 @@ async def connect(self):
logger.info(f"Connected to {self._room_name}")

# Set up audio source and track
self._audio_source = rtc.AudioSource(self._params.audio_out_sample_rate, self._params.audio_out_channels)
self._audio_track = rtc.LocalAudioTrack.create_audio_track("pipecat-audio", self._audio_source)
self._audio_source = rtc.AudioSource(
self._params.audio_out_sample_rate, self._params.audio_out_channels
)
self._audio_track = rtc.LocalAudioTrack.create_audio_track(
"pipecat-audio", self._audio_source
)
options = rtc.TrackPublishOptions()
options.source = rtc.TrackSource.SOURCE_MICROPHONE
await self._room.local_participant.publish_track(self._audio_track, options)
Expand All @@ -136,7 +140,9 @@ async def send_data(self, data: bytes, participant_id: str | None = None):

try:
if participant_id:
await self._room.local_participant.publish_data(data, reliable=True, destination_identities=[participant_id])
await self._room.local_participant.publish_data(
data, reliable=True, destination_identities=[participant_id]
)
else:
await self._room.local_participant.publish_data(data, reliable=True)
except Exception as e:
Expand Down Expand Up @@ -190,12 +196,18 @@ def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipan
asyncio.create_task(self._async_on_participant_disconnected(participant))

def _on_track_subscribed_wrapper(
self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant
self,
track: rtc.Track,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
asyncio.create_task(self._async_on_track_subscribed(track, publication, participant))

def _on_track_unsubscribed_wrapper(
self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant
self,
track: rtc.Track,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
asyncio.create_task(self._async_on_track_unsubscribed(track, publication, participant))

Expand All @@ -218,7 +230,10 @@ async def _async_on_participant_disconnected(self, participant: rtc.RemotePartic
await self._callbacks.on_participant_disconnected(participant.sid)

async def _async_on_track_subscribed(
self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant
self,
track: rtc.Track,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
if track.kind == rtc.TrackKind.KIND_AUDIO:
logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}")
Expand All @@ -227,7 +242,10 @@ async def _async_on_track_subscribed(
asyncio.create_task(self._process_audio_stream(audio_stream, participant.sid))

async def _async_on_track_unsubscribed(
self, track: rtc.Track, publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant
self,
track: rtc.Track,
publication: rtc.RemoteTrackPublication,
participant: rtc.RemoteParticipant,
):
logger.info(f"Track unsubscribed: {publication.sid} from {participant.identity}")
if track.kind == rtc.TrackKind.KIND_AUDIO:
Expand Down Expand Up @@ -268,7 +286,9 @@ def __init__(self, client: LiveKitTransportClient, params: LiveKitParams, **kwar
self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
self._current_sample_rate: int = params.audio_in_sample_rate
if params.vad_enabled and not params.vad_analyzer:
self._vad_analyzer = VADAnalyzer(sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels)
self._vad_analyzer = VADAnalyzer(
sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
)

async def start(self, frame: StartFrame):
await super().start(frame)
Expand Down Expand Up @@ -326,7 +346,9 @@ async def _audio_in_task_handler(self):
except Exception as e:
logger.error(f"Error in audio input task: {e}")

def _convert_livekit_audio_to_pipecat(self, audio_frame_event: rtc.AudioFrameEvent) -> AudioRawFrame:
def _convert_livekit_audio_to_pipecat(
self, audio_frame_event: rtc.AudioFrameEvent
) -> AudioRawFrame:
audio_frame = audio_frame_event.frame
audio_data = np.frombuffer(audio_frame.data, dtype=np.int16)
original_sample_rate = audio_frame.sample_rate
Expand All @@ -340,11 +362,19 @@ def _convert_livekit_audio_to_pipecat(self, audio_frame_event: rtc.AudioFrameEve

if sample_rate != self._current_sample_rate:
self._current_sample_rate = sample_rate
self._vad_analyzer = VADAnalyzer(sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels)
self._vad_analyzer = VADAnalyzer(
sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
)

return AudioRawFrame(audio=audio_data.tobytes(), sample_rate=sample_rate, num_channels=audio_frame.num_channels)
return AudioRawFrame(
audio=audio_data.tobytes(),
sample_rate=sample_rate,
num_channels=audio_frame.num_channels,
)

def _resample_audio(self, audio_data: np.ndarray, original_rate: int, target_rate: int) -> np.ndarray:
def _resample_audio(
self, audio_data: np.ndarray, original_rate: int, target_rate: int
) -> np.ndarray:
num_samples = int(len(audio_data) * target_rate / original_rate)
resampled_audio = signal.resample(audio_data, num_samples)
return resampled_audio.astype(np.int16)
Expand Down Expand Up @@ -392,7 +422,9 @@ async def send_metrics(self, frame: MetricsFrame):
if hasattr(frame, "characters"):
metrics["characters"] = frame.characters

message = LiveKitTransportMessageFrame(message={"type": "pipecat-metrics", "metrics": metrics})
message = LiveKitTransportMessageFrame(
message={"type": "pipecat-metrics", "metrics": metrics}
)
await self._client.send_data(str(message.message).encode())

async def write_raw_audio_frames(self, frames: bytes):
Expand Down Expand Up @@ -430,7 +462,9 @@ def __init__(
self._room_name = room_name
self._params = params

self._client = LiveKitTransportClient(url, token, room_name, self._params, self._create_callbacks(), self._loop)
self._client = LiveKitTransportClient(
url, token, room_name, self._params, self._create_callbacks(), self._loop
)
self._input: LiveKitInputTransport | None = None
self._output: LiveKitOutputTransport | None = None

Expand Down Expand Up @@ -463,7 +497,9 @@ def input(self) -> FrameProcessor:

def output(self) -> FrameProcessor:
if not self._output:
self._output = LiveKitOutputTransport(self._client, self._params, name=self._output_name)
self._output = LiveKitOutputTransport(
self._client, self._params, name=self._output_name
)
return self._output

@property
Expand Down Expand Up @@ -519,7 +555,9 @@ async def _on_audio_track_subscribed(self, participant_id: str):
participant = self._client._room.remote_participants.get(participant_id)
if participant:
for publication in participant.audio_tracks.values():
self._client._on_track_subscribed_wrapper(publication.track, publication, participant)
self._client._on_track_subscribed_wrapper(
publication.track, publication, participant
)

async def _on_audio_track_unsubscribed(self, participant_id: str):
await self._call_event_handler("on_audio_track_unsubscribed", participant_id)
Expand Down

0 comments on commit b6e1d6e

Please sign in to comment.