From 9590cc2fbc0d5ad6a784e2a3b743005eb7beeb17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 5 Apr 2024 13:34:34 -0700 Subject: [PATCH] examples: fix whisper examples --- .../foundational/13-whisper-transcription.py | 26 ++++++++++++++----- examples/foundational/13a-whisper-local.py | 10 +++---- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/foundational/13-whisper-transcription.py b/examples/foundational/13-whisper-transcription.py index c634a9c0d..054cf8450 100644 --- a/examples/foundational/13-whisper-transcription.py +++ b/examples/foundational/13-whisper-transcription.py @@ -1,12 +1,16 @@ import asyncio import logging +from dailyai.pipeline.frames import EndFrame, TranscriptionFrame from dailyai.transports.daily_transport import DailyTransport from dailyai.services.whisper_ai_services import WhisperSTTService from dailyai.pipeline.pipeline import Pipeline from runner import configure +from dotenv import load_dotenv +load_dotenv(override=True) + logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") logger = logging.getLogger("dailyai") logger.setLevel(logging.DEBUG) @@ -26,17 +30,27 @@ async def main(room_url: str): stt = WhisperSTTService() transcription_output_queue = asyncio.Queue() + transport_done = asyncio.Event() - pipeline = Pipeline([stt]) - pipeline.set_sink(transcription_output_queue) + pipeline = Pipeline([stt], source=transport.receive_queue, sink=transcription_output_queue) async def handle_transcription(): print("`````````TRANSCRIPTION`````````") - while True: + while not transport_done.is_set(): item = await transcription_output_queue.get() - print(item.text) - - await asyncio.gather(transport.run(pipeline), handle_transcription()) + print("got item from queue", item) + if isinstance(item, TranscriptionFrame): + print(item.text) + elif isinstance(item, EndFrame): + break + print("handle_transcription done") + + async def run_until_done(): + await transport.run() + transport_done.set() + print("run_until_done done") + + await asyncio.gather(run_until_done(), pipeline.run_pipeline(), handle_transcription()) if __name__ == "__main__": diff --git a/examples/foundational/13a-whisper-local.py b/examples/foundational/13a-whisper-local.py index 00562b402..93ba93e4b 100644 --- a/examples/foundational/13a-whisper-local.py +++ b/examples/foundational/13a-whisper-local.py @@ -15,11 +15,10 @@ async def main(): meeting_duration_minutes = 1 transport = LocalTransport( - mic_enabled=False, + mic_enabled=True, camera_enabled=False, speaker_enabled=True, duration_minutes=meeting_duration_minutes, - start_transcription=False, ) stt = WhisperSTTService() @@ -27,8 +26,7 @@ async def main(): transcription_output_queue = asyncio.Queue() transport_done = asyncio.Event() - pipeline = Pipeline([stt]) - pipeline.set_sink(transcription_output_queue) + pipeline = Pipeline([stt], source=transport.receive_queue, sink=transcription_output_queue) async def handle_transcription(): print("`````````TRANSCRIPTION`````````") @@ -42,11 +40,11 @@ async def handle_transcription(): print("handle_transcription done") async def run_until_done(): - await transport.run(pipeline) + await transport.run() transport_done.set() print("run_until_done done") - await asyncio.gather(run_until_done(), handle_transcription()) + await asyncio.gather(run_until_done(), pipeline.run_pipeline(), handle_transcription()) if __name__ == "__main__":