From 420bb2c48ba3ceee551b28f2c0f7d1640634245f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 24 Apr 2024 18:29:24 -0700 Subject: [PATCH] wip proposal: initial commit --- examples/foundational/01-say-one-thing.py | 35 +- examples/foundational/01a-local-transport.py | 2 +- examples/foundational/02-llm-say-one-thing.py | 31 +- examples/foundational/03-still-frame.py | 23 +- .../foundational/04-utterance-and-speech.py | 2 +- .../foundational/05-sync-speech-and-image.py | 122 +-- .../05a-local-sync-speech-and-text.py | 2 +- .../foundational/06-listen-and-respond.py | 2 +- examples/foundational/06a-image-sync.py | 2 +- examples/foundational/07-interruptible.py | 2 +- examples/foundational/08-bots-arguing.py | 2 +- examples/foundational/10-wake-word.py | 2 +- examples/foundational/11-sound-effects.py | 2 +- examples/foundational/12-describe-video.py | 2 +- examples/foundational/new-test-2.py | 71 ++ examples/foundational/new-test.py | 66 ++ .../foundational/websocket-server/sample.py | 2 +- examples/image-gen.py | 3 +- examples/starter-apps/chatbot.py | 2 +- examples/starter-apps/patient-intake.py | 2 +- examples/starter-apps/storybot.py | 2 +- src/dailyai/aggregators/__init__.py | 0 src/dailyai/aggregators/gated.py | 67 ++ src/dailyai/aggregators/llm_context.py | 89 ++ src/dailyai/aggregators/llm_response.py | 188 ++++ .../openai_llm_context.py} | 75 +- src/dailyai/aggregators/sentence.py | 49 ++ src/dailyai/aggregators/user_response.py | 138 +++ src/dailyai/aggregators/vision_image_frame.py | 42 + src/dailyai/frames/__init__.py | 0 src/dailyai/{pipeline => frames}/frames.proto | 6 + src/dailyai/frames/frames.py | 406 +++++++++ src/dailyai/frames/openai_frames.py | 15 + .../protobufs/frames_pb2.py | 0 src/dailyai/pipeline/aggregators.py | 549 ------------ src/dailyai/pipeline/frame_processor.py | 34 - src/dailyai/pipeline/frames.py | 253 ------ src/dailyai/pipeline/openai_frames.py | 12 - src/dailyai/pipeline/parallel_pipeline.py | 124 +++ src/dailyai/pipeline/pipeline.py | 207 ++--- src/dailyai/processors/__init__.py | 0 src/dailyai/processors/frame_processor.py | 58 ++ src/dailyai/processors/passthrough.py | 36 + src/dailyai/processors/text_transformer.py | 36 + src/dailyai/services/ai_services.py | 193 ++-- src/dailyai/services/anthropic.py | 49 ++ src/dailyai/services/anthropic_llm_service.py | 44 - .../{azure_ai_services.py => azure.py} | 0 .../{deepgram_ai_services.py => deepgram.py} | 23 +- src/dailyai/services/deepgram_ai_service.py | 36 - src/dailyai/services/elevenlabs.py | 58 ++ src/dailyai/services/elevenlabs_ai_service.py | 46 - .../services/{fal_ai_services.py => fal.py} | 28 +- ...{fireworks_ai_services.py => fireworks.py} | 16 +- src/dailyai/services/live_stream.py | 323 +++++++ .../{moondream_ai_service.py => moondream.py} | 25 +- .../{ollama_ai_services.py => ollama.py} | 8 +- src/dailyai/services/open_ai_services.py | 58 -- .../{openai_api_llm_service.py => openai.py} | 115 ++- src/dailyai/services/openai_llm_context.py | 61 -- .../{playht_ai_service.py => playht.py} | 44 +- .../{whisper_ai_services.py => whisper.py} | 0 src/dailyai/transports/abstract_transport.py | 42 - src/dailyai/transports/daily_transport.py | 822 ++++++++++++------ .../transports/live_stream_transport.py | 196 +++++ src/dailyai/utils/__init__.py | 0 src/dailyai/utils/utils.py | 21 + src/dailyai/vad/__init__.py | 0 src/dailyai/vad/silero_vad.py | 65 ++ src/dailyai/vad/vad_analyzer.py | 104 +++ 70 files changed, 3301 insertions(+), 1839 deletions(-) create mode 100644 
examples/foundational/new-test-2.py create mode 100644 examples/foundational/new-test.py create mode 100644 src/dailyai/aggregators/__init__.py create mode 100644 src/dailyai/aggregators/gated.py create mode 100644 src/dailyai/aggregators/llm_context.py create mode 100644 src/dailyai/aggregators/llm_response.py rename src/dailyai/{pipeline/opeanai_llm_aggregator.py => aggregators/openai_llm_context.py} (63%) create mode 100644 src/dailyai/aggregators/sentence.py create mode 100644 src/dailyai/aggregators/user_response.py create mode 100644 src/dailyai/aggregators/vision_image_frame.py create mode 100644 src/dailyai/frames/__init__.py rename src/dailyai/{pipeline => frames}/frames.proto (81%) create mode 100644 src/dailyai/frames/frames.py create mode 100644 src/dailyai/frames/openai_frames.py rename src/dailyai/{pipeline => frames}/protobufs/frames_pb2.py (100%) delete mode 100644 src/dailyai/pipeline/aggregators.py delete mode 100644 src/dailyai/pipeline/frame_processor.py delete mode 100644 src/dailyai/pipeline/frames.py delete mode 100644 src/dailyai/pipeline/openai_frames.py create mode 100644 src/dailyai/pipeline/parallel_pipeline.py create mode 100644 src/dailyai/processors/__init__.py create mode 100644 src/dailyai/processors/frame_processor.py create mode 100644 src/dailyai/processors/passthrough.py create mode 100644 src/dailyai/processors/text_transformer.py create mode 100644 src/dailyai/services/anthropic.py delete mode 100644 src/dailyai/services/anthropic_llm_service.py rename src/dailyai/services/{azure_ai_services.py => azure.py} (100%) rename src/dailyai/services/{deepgram_ai_services.py => deepgram.py} (65%) delete mode 100644 src/dailyai/services/deepgram_ai_service.py create mode 100644 src/dailyai/services/elevenlabs.py delete mode 100644 src/dailyai/services/elevenlabs_ai_service.py rename src/dailyai/services/{fal_ai_services.py => fal.py} (72%) rename src/dailyai/services/{fireworks_ai_services.py => fireworks.py} (50%) create mode 100644 src/dailyai/services/live_stream.py rename src/dailyai/services/{moondream_ai_service.py => moondream.py} (68%) rename src/dailyai/services/{ollama_ai_services.py => ollama.py} (59%) delete mode 100644 src/dailyai/services/open_ai_services.py rename src/dailyai/services/{openai_api_llm_service.py => openai.py} (60%) delete mode 100644 src/dailyai/services/openai_llm_context.py rename src/dailyai/services/{playht_ai_service.py => playht.py} (65%) rename src/dailyai/services/{whisper_ai_services.py => whisper.py} (100%) delete mode 100644 src/dailyai/transports/abstract_transport.py create mode 100644 src/dailyai/transports/live_stream_transport.py create mode 100644 src/dailyai/utils/__init__.py create mode 100644 src/dailyai/utils/utils.py create mode 100644 src/dailyai/vad/__init__.py create mode 100644 src/dailyai/vad/silero_vad.py create mode 100644 src/dailyai/vad/vad_analyzer.py diff --git a/examples/foundational/01-say-one-thing.py b/examples/foundational/01-say-one-thing.py index aecda2963..93196454c 100644 --- a/examples/foundational/01-say-one-thing.py +++ b/examples/foundational/01-say-one-thing.py @@ -1,22 +1,24 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio import aiohttp -import logging import os -from dailyai.pipeline.frames import EndFrame, TextFrame -from dailyai.pipeline.pipeline import Pipeline +from dailyai.frames.frames import TextFrame +from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.live_stream import LiveStream +from 
dailyai.services.elevenlabs import ElevenLabsTTSService from dailyai.transports.daily_transport import DailyTransport -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService from runner import configure from dotenv import load_dotenv load_dotenv(override=True) -logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") -logger = logging.getLogger("dailyai") -logger.setLevel(logging.DEBUG) - async def main(room_url): async with aiohttp.ClientSession() as session: @@ -27,26 +29,25 @@ async def main(room_url): mic_enabled=True, ) + livestream = LiveStream(transport, mic_enabled=True) + tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), voice_id=os.getenv("ELEVENLABS_VOICE_ID"), ) - pipeline = Pipeline([tts]) + pipeline = Pipeline([tts, livestream]) # Register an event handler so we can play the audio when the # participant joins. - @transport.event_handler("on_participant_joined") - async def on_participant_joined(transport, participant): - if participant["info"]["isLocal"]: - return - + @livestream.event_handler("on_participant_joined") + async def on_participant_joined(livestream, participant): participant_name = participant["info"]["userName"] or '' - await pipeline.queue_frames([TextFrame("Hello there, " + participant_name + "!"), EndFrame()]) + await pipeline.queue_frames([TextFrame("Hello there, " + participant_name + "!")]) + await pipeline.stop() - await transport.run(pipeline) - del tts + await pipeline.run() if __name__ == "__main__": diff --git a/examples/foundational/01a-local-transport.py b/examples/foundational/01a-local-transport.py index 617459590..7be101f56 100644 --- a/examples/foundational/01a-local-transport.py +++ b/examples/foundational/01a-local-transport.py @@ -3,7 +3,7 @@ import logging import os -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.transports.local_transport import LocalTransport from dotenv import load_dotenv diff --git a/examples/foundational/02-llm-say-one-thing.py b/examples/foundational/02-llm-say-one-thing.py index a98815f1d..1a687511a 100644 --- a/examples/foundational/02-llm-say-one-thing.py +++ b/examples/foundational/02-llm-say-one-thing.py @@ -1,24 +1,26 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio import os -import logging import aiohttp -from dailyai.pipeline.frames import EndFrame, LLMMessagesFrame +from dailyai.frames.frames import LLMMessagesFrame from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.live_stream import LiveStream +from dailyai.services.elevenlabs import ElevenLabsTTSService +from dailyai.services.openai import OpenAILLMService from dailyai.transports.daily_transport import DailyTransport -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService -from dailyai.services.open_ai_services import OpenAILLMService from runner import configure from dotenv import load_dotenv load_dotenv(override=True) -logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") -logger = logging.getLogger("dailyai") -logger.setLevel(logging.DEBUG) - async def main(room_url): async with aiohttp.ClientSession() as session: @@ -29,6 +31,8 @@ async def main(room_url): mic_enabled=True, ) + livestream = LiveStream(transport, mic_enabled=True) + tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -45,13 +49,14 @@ async def 
main(room_url): "content": "You are an LLM in a WebRTC session, and this is a 'hello world' demo. Say hello to the world.", }] - pipeline = Pipeline([llm, tts]) + pipeline = Pipeline([llm, tts, livestream]) - @transport.event_handler("on_first_other_participant_joined") - async def on_first_other_participant_joined(transport, participant): - await pipeline.queue_frames([LLMMessagesFrame(messages), EndFrame()]) + @livestream.event_handler("on_first_participant_joined") + async def on_first_participant_joined(livestream, participant): + await pipeline.queue_frames([LLMMessagesFrame(messages)]) + await pipeline.stop() - await transport.run(pipeline) + await pipeline.run() if __name__ == "__main__": diff --git a/examples/foundational/03-still-frame.py b/examples/foundational/03-still-frame.py index 51ef47de8..d800e861c 100644 --- a/examples/foundational/03-still-frame.py +++ b/examples/foundational/03-still-frame.py @@ -1,12 +1,19 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio import aiohttp import logging import os -from dailyai.pipeline.frames import TextFrame +from dailyai.frames.frames import TextFrame from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.live_stream import LiveStream +from dailyai.services.fal import FalImageGenService from dailyai.transports.daily_transport import DailyTransport -from dailyai.services.fal_ai_services import FalImageGenService from runner import configure @@ -30,6 +37,8 @@ async def main(room_url): duration_minutes=1 ) + livestream = LiveStream(transport, camera_enabled=True) + imagegen = FalImageGenService( params=FalImageGenService.InputParams( image_size="square_hd" @@ -38,19 +47,17 @@ async def main(room_url): key=os.getenv("FAL_KEY"), ) - pipeline = Pipeline([imagegen]) + pipeline = Pipeline([imagegen, livestream]) - @transport.event_handler("on_first_other_participant_joined") + @livestream.event_handler("on_first_participant_joined") async def on_first_other_participant_joined(transport, participant): # Note that we do not put an EndFrame() item in the pipeline for this demo. # This means that the bot will stay in the channel until it times out. # An EndFrame() in the pipeline would cause the transport to shut # down. 
- await pipeline.queue_frames( - [TextFrame("a cat in the style of picasso")] - ) + await pipeline.queue_frames([TextFrame("a cat in the style of picasso")]) - await transport.run(pipeline) + await pipeline.run() if __name__ == "__main__": diff --git a/examples/foundational/04-utterance-and-speech.py b/examples/foundational/04-utterance-and-speech.py index 908be03b4..c52e5c960 100644 --- a/examples/foundational/04-utterance-and-speech.py +++ b/examples/foundational/04-utterance-and-speech.py @@ -10,7 +10,7 @@ from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService from dailyai.services.deepgram_ai_services import DeepgramTTSService from dailyai.pipeline.frames import EndPipeFrame, LLMMessagesFrame, TextFrame -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from runner import configure diff --git a/examples/foundational/05-sync-speech-and-image.py b/examples/foundational/05-sync-speech-and-image.py index 377e8579b..591d17600 100644 --- a/examples/foundational/05-sync-speech-and-image.py +++ b/examples/foundational/05-sync-speech-and-image.py @@ -1,64 +1,79 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio import aiohttp import os -import logging +import sys -from dataclasses import dataclass -from typing import AsyncGenerator +from typing import List -from dailyai.pipeline.aggregators import ( - GatedAggregator, - LLMFullResponseAggregator, - ParallelPipeline, - SentenceAggregator, -) -from dailyai.pipeline.frames import ( +from dailyai.frames.frames import ( + ControlFrame, Frame, + ImageRawFrame, TextFrame, EndFrame, - ImageFrame, LLMMessagesFrame, LLMResponseStartFrame, ) -from dailyai.pipeline.frame_processor import FrameProcessor - from dailyai.pipeline.pipeline import Pipeline +from dailyai.pipeline.parallel_pipeline import ParallelPipeline +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.aggregators.gated import GatedAggregator +from dailyai.aggregators.llm_response import LLMFullResponseAggregator +from dailyai.aggregators.sentence import SentenceAggregator +from dailyai.services.live_stream import LiveStream +from dailyai.services.openai import OpenAILLMService +from dailyai.services.elevenlabs import ElevenLabsTTSService +from dailyai.services.fal import FalImageGenService from dailyai.transports.daily_transport import DailyTransport -from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService -from dailyai.services.fal_ai_services import FalImageGenService from runner import configure +from loguru import logger + from dotenv import load_dotenv load_dotenv(override=True) -logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") -logger = logging.getLogger("dailyai") -logger.setLevel(logging.DEBUG) +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +class MonthFrame(ControlFrame): + def __init__(self, month): + super().__init__() + self.metadata["month"] = month + + @ property + def month(self) -> str: + return self.metadata["month"] + def __str__(self): + return f"{self.name}(month: {self.month})" -@dataclass -class MonthFrame(Frame): month: str class MonthPrepender(FrameProcessor): def __init__(self): + super().__init__() self.most_recent_month = "Placeholder, month frame not yet received" self.prepend_to_next_text_frame = False - async def process_frame(self, frame: 
Frame) -> AsyncGenerator[Frame, None]: + async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, MonthFrame): self.most_recent_month = frame.month elif self.prepend_to_next_text_frame and isinstance(frame, TextFrame): - yield TextFrame(f"{self.most_recent_month}: {frame.text}") + await self.push_frame(TextFrame(f"{self.most_recent_month}: {frame.data}")) self.prepend_to_next_text_frame = False elif isinstance(frame, LLMResponseStartFrame): self.prepend_to_next_text_frame = True - yield frame - else: - yield frame + + await super().process_frame(frame, direction) async def main(room_url): @@ -69,11 +84,12 @@ async def main(room_url): "Month Narration Bot", mic_enabled=True, camera_enabled=True, - mic_sample_rate=16000, camera_width=1024, camera_height=1024, ) + livestream = LiveStream(transport, mic_enabled=True, camera_enabled=True) + tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -93,39 +109,40 @@ async def main(room_url): ) gated_aggregator = GatedAggregator( - gate_open_fn=lambda frame: isinstance( - frame, ImageFrame), gate_close_fn=lambda frame: isinstance( - frame, LLMResponseStartFrame), start_open=False, ) + gate_open_fn=lambda frame: isinstance(frame, ImageRawFrame), + gate_close_fn=lambda frame: isinstance(frame, LLMResponseStartFrame), + start_open=False + ) sentence_aggregator = SentenceAggregator() month_prepender = MonthPrepender() llm_full_response_aggregator = LLMFullResponseAggregator() - pipeline = Pipeline( - processors=[ - llm, - sentence_aggregator, - ParallelPipeline( - [[month_prepender, tts], [llm_full_response_aggregator, imagegen]] - ), - gated_aggregator, - ], - ) + pipeline = Pipeline([ + llm, + sentence_aggregator, + ParallelPipeline( + [month_prepender, tts], + [llm_full_response_aggregator, imagegen] + ), + gated_aggregator, + livestream + ]) frames = [] for month in [ "January", "February", - "March", - "April", - "May", - "June", - "July", - "August", - "September", - "October", - "November", - "December", + # "March", + # "April", + # "May", + # "June", + # "July", + # "August", + # "September", + # "October", + # "November", + # "December", ]: messages = [ { @@ -136,10 +153,11 @@ async def main(room_url): frames.append(MonthFrame(month)) frames.append(LLMMessagesFrame(messages)) - frames.append(EndFrame()) +# frames.append(EndFrame()) + await pipeline.queue_frames(frames) - await transport.run(pipeline, override_pipeline_source_queue=False) + await pipeline.run() if __name__ == "__main__": diff --git a/examples/foundational/05a-local-sync-speech-and-text.py b/examples/foundational/05a-local-sync-speech-and-text.py index 7c4cf0186..de4dee7b4 100644 --- a/examples/foundational/05a-local-sync-speech-and-text.py +++ b/examples/foundational/05a-local-sync-speech-and-text.py @@ -7,7 +7,7 @@ from dailyai.pipeline.frames import AudioFrame, URLImageFrame, LLMMessagesFrame, TextFrame from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.services.fal_ai_services import FalImageGenService from dailyai.transports.local_transport import LocalTransport diff --git a/examples/foundational/06-listen-and-respond.py b/examples/foundational/06-listen-and-respond.py index 0de16b270..b5a66a3d5 100644 --- a/examples/foundational/06-listen-and-respond.py +++ b/examples/foundational/06-listen-and-respond.py @@ -6,7 +6,7 @@ from 
dailyai.pipeline.pipeline import Pipeline from dailyai.transports.daily_transport import DailyTransport -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.services.open_ai_services import OpenAILLMService from dailyai.services.ai_services import FrameLogger from dailyai.pipeline.aggregators import ( diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 912586ec4..9431d1320 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -14,7 +14,7 @@ LLMUserContextAggregator, ) from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from runner import configure diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index 3f35a3536..1f83e424b 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -11,7 +11,7 @@ from dailyai.services.ai_services import FrameLogger from dailyai.transports.daily_transport import DailyTransport from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from runner import configure diff --git a/examples/foundational/08-bots-arguing.py b/examples/foundational/08-bots-arguing.py index ea6208827..7c95e77ef 100644 --- a/examples/foundational/08-bots-arguing.py +++ b/examples/foundational/08-bots-arguing.py @@ -8,7 +8,7 @@ from dailyai.transports.daily_transport import DailyTransport from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.services.fal_ai_services import FalImageGenService from dailyai.pipeline.frames import AudioFrame, EndFrame, ImageFrame, LLMMessagesFrame, TextFrame diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py index 4d997153a..775ac26e8 100644 --- a/examples/foundational/10-wake-word.py +++ b/examples/foundational/10-wake-word.py @@ -9,7 +9,7 @@ from dailyai.transports.daily_transport import DailyTransport from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.pipeline.aggregators import ( LLMUserContextAggregator, LLMAssistantContextAggregator, diff --git a/examples/foundational/11-sound-effects.py b/examples/foundational/11-sound-effects.py index ee8a29ce3..04fe55be9 100644 --- a/examples/foundational/11-sound-effects.py +++ b/examples/foundational/11-sound-effects.py @@ -7,7 +7,7 @@ from dailyai.transports.daily_transport import DailyTransport from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.pipeline.aggregators import ( LLMUserContextAggregator, LLMAssistantContextAggregator, diff --git a/examples/foundational/12-describe-video.py b/examples/foundational/12-describe-video.py index 62e116020..a62e97533 
100644 --- a/examples/foundational/12-describe-video.py +++ b/examples/foundational/12-describe-video.py @@ -9,7 +9,7 @@ from dailyai.pipeline.frames import Frame, TextFrame, UserImageRequestFrame from dailyai.pipeline.pipeline import Pipeline -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.services.moondream_ai_service import MoondreamService from dailyai.transports.daily_transport import DailyTransport diff --git a/examples/foundational/new-test-2.py b/examples/foundational/new-test-2.py new file mode 100644 index 000000000..86a3cad07 --- /dev/null +++ b/examples/foundational/new-test-2.py @@ -0,0 +1,71 @@ +import asyncio +import aiohttp +import os +import sys + +from dailyai.pipeline.frames import FrameTypes +from dailyai.pipeline.pipeline import Pipeline +from dailyai.processors.passthrough import Passthrough +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService +from dailyai.services.live_stream import LiveStream +from dailyai.services.open_ai_services import OpenAILLMService +from dailyai.transports.daily_transport import DailyTransport +from dailyai.processors.demuxer import Demuxer +from dailyai.processors.llm_response_aggregator import LLMUserResponseAggregator + +from dailyai.vad.silero_vad import SileroVADAnalyzer +from runner import configure + +from loguru import logger + +from dotenv import load_dotenv +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="TRACE") + + +async def main(room_url, token): + async with aiohttp.ClientSession() as session: + transport = DailyTransport( + room_url, token, + mic_enabled=True, + speaker_enabled=True, + transcription_enabled=True, + vad_analyzer=SileroVADAnalyzer() + ) + + livestream_source = LiveStream(transport, speaker_enabled=True) + livestream_sink = LiveStream(transport, mic_enabled=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. 
Respond to what the user said in a creative and helpful way.", }, ] + + llm_user_response = LLMUserResponseAggregator(messages) + + llm = OpenAILLMService( + api_key=os.getenv("OPENAI_API_KEY"), + model="gpt-4-turbo-preview") + + tts = ElevenLabsTTSService( + aiohttp_session=session, + api_key=os.getenv("ELEVENLABS_API_KEY"), + voice_id=os.getenv("ELEVENLABS_VOICE_ID"), + ) + + @livestream_source.event_handler("on_first_participant_joined") + async def on_first_participant_joined(livestream, participant): + livestream_source.capture_participant_transcription(participant["id"]) + + pipeline = Pipeline([livestream_source, llm_user_response, llm, tts, livestream_sink]) + + await pipeline.run() + + +if __name__ == "__main__": + (url, token) = configure() + asyncio.run(main(url, token)) diff --git a/examples/foundational/new-test.py b/examples/foundational/new-test.py new file mode 100644 index 000000000..78bb444cd --- /dev/null +++ b/examples/foundational/new-test.py @@ -0,0 +1,66 @@ +import asyncio +import sys + +from dailyai.pipeline.frames import FrameTypes +from dailyai.pipeline.pipeline import Pipeline +from dailyai.processors.passthrough import Passthrough +from dailyai.services.live_stream import LiveStream +from dailyai.transports.daily_transport import DailyTransport +from dailyai.processors.demuxer import Demuxer +from dailyai.processors.llm_response_aggregator import LLMUserResponseAggregator + +from dailyai.vad.silero_vad import SileroVADAnalyzer +from runner import configure + +from loguru import logger + +from dotenv import load_dotenv +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="TRACE") + + +async def main(room_url, token): + transport = DailyTransport( + room_url, token, + camera_enabled=False, + camera_width=1280, + camera_height=720, + mic_enabled=False, + speaker_enabled=True, + video_capture_enabled=False, + transcription_enabled=True, + vad_analyzer=SileroVADAnalyzer() + ) + + media_passthrough = Passthrough([FrameTypes.AUDIO_RAW, FrameTypes.IMAGE_RAW]) + + livestream_source = LiveStream(transport, speaker_enabled=True) + livestream_sink = LiveStream(transport, camera_enabled=True, mic_enabled=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio.
Respond to what the user said in a creative and helpful way.", + }, + ] + + llm_user_response = LLMUserResponseAggregator(messages) + + @livestream_source.event_handler("on_first_participant_joined") + async def on_first_participant_joined(livestream, participant): + livestream_source.capture_participant_transcription(participant["id"]) + # livestream_source.capture_participant_video(participant["id"]) + + pipeline = Pipeline([livestream_source, + Demuxer([llm_user_response], + [media_passthrough]), + livestream_sink]) + + await pipeline.run() + + +if __name__ == "__main__": + (url, token) = configure() + asyncio.run(main(url, token)) diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py index 22792270e..ab91aaae4 100644 --- a/examples/foundational/websocket-server/sample.py +++ b/examples/foundational/websocket-server/sample.py @@ -5,7 +5,7 @@ from dailyai.pipeline.frame_processor import FrameProcessor from dailyai.pipeline.frames import TextFrame, TranscriptionFrame from dailyai.pipeline.pipeline import Pipeline -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.transports.websocket_transport import WebsocketTransport from dailyai.services.whisper_ai_services import WhisperSTTService diff --git a/examples/image-gen.py b/examples/image-gen.py index 30d207447..ffa621e2b 100644 --- a/examples/image-gen.py +++ b/examples/image-gen.py @@ -7,9 +7,8 @@ from dailyai.transports.daily_transport import DailyTransport from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService -from dailyai.pipeline.frames import Frame, FrameType +from dailyai.pipeline.frames import Frame from dailyai.services.fal_ai_services import FalImageGenService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService async def main(room_url: str, token): diff --git a/examples/starter-apps/chatbot.py b/examples/starter-apps/chatbot.py index a46f54c5a..02e483d50 100644 --- a/examples/starter-apps/chatbot.py +++ b/examples/starter-apps/chatbot.py @@ -22,7 +22,7 @@ from dailyai.pipeline.pipeline import Pipeline from dailyai.transports.daily_transport import DailyTransport from dailyai.services.open_ai_services import OpenAILLMService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from runner import configure diff --git a/examples/starter-apps/patient-intake.py b/examples/starter-apps/patient-intake.py index d8b11a93f..9b94b13b5 100644 --- a/examples/starter-apps/patient-intake.py +++ b/examples/starter-apps/patient-intake.py @@ -17,7 +17,7 @@ from dailyai.services.openai_llm_context import OpenAILLMContext from dailyai.services.open_ai_services import OpenAILLMService # from dailyai.services.deepgram_ai_services import DeepgramTTSService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.services.fireworks_ai_services import FireworksLLMService from dailyai.pipeline.frames import ( Frame, diff --git a/examples/starter-apps/storybot.py b/examples/starter-apps/storybot.py index 69be94095..cd94f9c7c 100644 --- a/examples/starter-apps/storybot.py +++ b/examples/starter-apps/storybot.py @@ -16,7 +16,7 @@ from dailyai.services.fal_ai_services import FalImageGenService from dailyai.services.open_ai_services import OpenAILLMService 
from dailyai.services.deepgram_ai_services import DeepgramTTSService -from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.elevenlabs_ai_services import ElevenLabsTTSService from dailyai.pipeline.aggregators import ( LLMAssistantContextAggregator, LLMAssistantResponseAggregator, diff --git a/src/dailyai/aggregators/__init__.py b/src/dailyai/aggregators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dailyai/aggregators/gated.py b/src/dailyai/aggregators/gated.py new file mode 100644 index 000000000..4ed1edffa --- /dev/null +++ b/src/dailyai/aggregators/gated.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from dailyai.frames.frames import Frame, PipelineFrame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor + + +class GatedAggregator(FrameProcessor): + """Accumulate frames, with custom functions to start and stop accumulation. + Yields gate-opening frame before any accumulated frames, then ensuing frames + until and not including the gate-closed frame. + + >>> from dailyai.pipeline.frames import ImageFrame + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + ... else: + ... print(frame.__class__.__name__) + + >>> aggregator = GatedAggregator( + ... gate_close_fn=lambda x: isinstance(x, LLMResponseStartFrame), + ... gate_open_fn=lambda x: isinstance(x, ImageFrame), + ... start_open=False) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello again."))) + >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) + ImageFrame + Hello + Hello again. + >>> asyncio.run(print_frames(aggregator, TextFrame("Goodbye."))) + Goodbye. + """ + + def __init__(self, gate_open_fn, gate_close_fn, start_open): + super().__init__() + self._gate_open_fn = gate_open_fn + self._gate_close_fn = gate_close_fn + self._gate_open = start_open + self._accumulator: List[Frame] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + # We must not block pipeline control frames. 
+ if isinstance(frame, PipelineFrame): + await super().process_frame(frame, direction) + return + + if self._gate_open: + if self._gate_close_fn(frame): + self._gate_open = False + else: + if self._gate_open_fn(frame): + self._gate_open = True + + if self._gate_open: + await self.push_frame(frame, direction) + for frame in self._accumulator: + await self.push_frame(frame, direction) + self._accumulator = [] + else: + self._accumulator.append(frame) diff --git a/src/dailyai/aggregators/llm_context.py b/src/dailyai/aggregators/llm_context.py new file mode 100644 index 000000000..45f6fc3cf --- /dev/null +++ b/src/dailyai/aggregators/llm_context.py @@ -0,0 +1,89 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.frames.frames import (Frame, LLMMessagesFrame, TextFrame, TranscriptionFrame) +from dailyai.processors.frame_processor import FrameProcessor + + +class LLMContextAggregator(FrameProcessor): + def __init__( + self, + messages: list[dict], + role: str, + bot_participant_id=None, + complete_sentences=True, + pass_through=True, + ): + super().__init__() + self.messages = messages + self.bot_participant_id = bot_participant_id + self.role = role + self.sentence = "" + self.complete_sentences = complete_sentences + self.pass_through = pass_through + + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: + # We don't do anything with non-text frames, pass it along to next in + # the pipeline. + if not isinstance(frame, TextFrame): + yield frame + return + + # Ignore transcription frames from the bot + if isinstance(frame, TranscriptionFrame): + if frame.participantId == self.bot_participant_id: + return + + # The common case for "pass through" is receiving frames from the LLM that we'll + # use to update the "assistant" LLM messages, but also passing the text frames + # along to a TTS service to be spoken to the user. 
+ if self.pass_through: + yield frame + + # TODO: split up transcription by participant + if self.complete_sentences: + # type: ignore -- the linter thinks this isn't a TextFrame, even + # though we check it above + self.sentence += frame.text + if self.sentence.endswith((".", "?", "!")): + self.messages.append( + {"role": self.role, "content": self.sentence}) + self.sentence = "" + yield LLMMessagesFrame(self.messages) + else: + # type: ignore -- the linter thinks this isn't a TextFrame, even + # though we check it above + self.messages.append({"role": self.role, "content": frame.text}) + yield LLMMessagesFrame(self.messages) + + +class LLMUserContextAggregator(LLMContextAggregator): + def __init__( + self, + messages: list[dict], + bot_participant_id=None, + complete_sentences=True): + super().__init__( + messages, + "user", + bot_participant_id, + complete_sentences, + pass_through=False) + + +class LLMAssistantContextAggregator(LLMContextAggregator): + def __init__( + self, + messages: list[dict], + bot_participant_id=None, + complete_sentences=True): + super().__init__( + messages, + "assistant", + bot_participant_id, + complete_sentences, + pass_through=True, + ) diff --git a/src/dailyai/aggregators/llm_response.py b/src/dailyai/aggregators/llm_response.py new file mode 100644 index 000000000..baa025dd8 --- /dev/null +++ b/src/dailyai/aggregators/llm_response.py @@ -0,0 +1,188 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.frames.frames import ( + Frame, + InterimTranscriptionFrame, + LLMMessagesFrame, + LLMResponseStartFrame, + TextFrame, + LLMResponseEndFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) + + +class LLMResponseAggregator(FrameProcessor): + + def __init__( + self, + *, + messages: list[dict] | None, + role: str, + start_frame, + end_frame, + accumulator_frame, + interim_accumulator_frame=None, + **kwargs + ): + super().__init__(**kwargs) + + self._messages = messages + self._role = role + self._start_frame = start_frame + self._end_frame = end_frame + self._accumulator_frame = accumulator_frame + self._interim_accumulator_frame = interim_accumulator_frame + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + self._aggregation = "" + self._aggregating = False + + # + # Frame processor + # + + # Use cases implemented: + # + # S: Start, E: End, T: Transcription, I: Interim, X: Text + # + # S E -> None + # S T E -> X + # S I T E -> X + # S I E T -> X + # S I E I T -> X + # + # The following case would not be supported: + # + # S I E T1 I T2 -> X + # + # and T2 would be dropped. + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if not self._messages: + return + + send_aggregation = False + + if isinstance(frame, self._start_frame): + self._seen_start_frame = True + self._aggregating = True + elif isinstance(frame, self._end_frame): + self._seen_end_frame = True + + # We might have received the end frame but we might still be + # aggregating (i.e. we have seen interim results but not the final + # text). + self._aggregating = self._seen_interim_results + + # Send the aggregation if we are not aggregating anymore (i.e. no + # more interim results received). 
+ send_aggregation = not self._aggregating + elif isinstance(frame, self._accumulator_frame): + if self._aggregating: + self._aggregation += f" {frame.data}" + # We have received a complete sentence, so if we have seen the + # end frame and we were still aggregating, it means we should + # send the aggregation. + send_aggregation = self._seen_end_frame + + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + self._seen_interim_results = True + + if send_aggregation: + await self._push_aggregation() + + async def _push_aggregation(self): + if len(self._aggregation) > 0: + self._messages.append({"role": self._role, "content": self._aggregation}) + frame = LLMMessagesFrame(self._messages) + await self.push_frame(frame) + + # Reset + self._aggregation = "" + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + +class LLMAssistantResponseAggregator(LLMResponseAggregator): + def __init__(self, messages: list[dict]): + super().__init__( + messages=messages, + role="assistant", + start_frame=LLMResponseStartFrame, + end_frame=LLMResponseEndFrame, + accumulator_frame=TextFrame + ) + + +class LLMUserResponseAggregator(LLMResponseAggregator): + def __init__(self, messages: list[dict]): + super().__init__( + messages=messages, + role="user", + start_frame=UserStartedSpeakingFrame, + end_frame=UserStoppedSpeakingFrame, + accumulator_frame=TranscriptionFrame, + interim_accumulator_frame=InterimTranscriptionFrame + ) + + +class LLMFullResponseAggregator(FrameProcessor): + """This class aggregates Text frames until it receives a + LLMResponseEndFrame, then emits the concatenated text as + a single text frame. + + given the following frames: + + TextFrame("Hello,") + TextFrame(" world.") + TextFrame(" I am") + TextFrame(" an LLM.") + LLMResponseEndFrame() + + this processor will yield nothing for the first 4 frames, then + + TextFrame("Hello, world. I am an LLM.") + LLMResponseEndFrame() + + when passed the last frame. + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + ... else: + ... print(frame.__class__.__name__) + + >>> aggregator = LLMFullResponseAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) + >>> asyncio.run(print_frames(aggregator, LLMResponseEndFrame())) + Hello, world. I am an LLM.
+ LLMResponseEndFrame + """ + + def __init__(self): + super().__init__() + self._aggregation = "" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + self._aggregation += frame.data + elif isinstance(frame, LLMResponseEndFrame): + await self.push_frame(TextFrame(self._aggregation)) + self._aggregation = "" + + await super().process_frame(frame, direction) diff --git a/src/dailyai/pipeline/opeanai_llm_aggregator.py b/src/dailyai/aggregators/openai_llm_context.py similarity index 63% rename from src/dailyai/pipeline/opeanai_llm_aggregator.py rename to src/dailyai/aggregators/openai_llm_context.py index b4b254087..eb33d9d23 100644 --- a/src/dailyai/pipeline/opeanai_llm_aggregator.py +++ b/src/dailyai/aggregators/openai_llm_context.py @@ -1,6 +1,12 @@ -from typing import AsyncGenerator, Callable -from dailyai.pipeline.frame_processor import FrameProcessor -from dailyai.pipeline.frames import ( +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import AsyncGenerator, Callable, List + +from dailyai.frames.frames import ( Frame, LLMResponseEndFrame, LLMResponseStartFrame, @@ -9,16 +15,59 @@ UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) -from dailyai.pipeline.openai_frames import OpenAILLMContextFrame -from dailyai.services.openai_llm_context import OpenAILLMContext - -try: - from openai.types.chat import ChatCompletionRole -except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( - "In order to use OpenAI, you need to `pip install dailyai[openai]`. Also, set `OPENAI_API_KEY` environment variable.") - raise Exception(f"Missing module: {e}") +from dailyai.frames.openai_frames import OpenAILLMContextFrame +from dailyai.processors.frame_processor import FrameProcessor + +from openai._types import NOT_GIVEN, NotGiven + +from openai.types.chat import ( + ChatCompletionRole, + ChatCompletionToolParam, + ChatCompletionToolChoiceOptionParam, + ChatCompletionMessageParam +) + + +class OpenAILLMContext: + + def __init__( + self, + messages: List[ChatCompletionMessageParam] | None = None, + tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN + ): + self.messages: List[ChatCompletionMessageParam] = messages if messages else [ + ] + self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice + self.tools: List[ChatCompletionToolParam] | NotGiven = tools + + @ staticmethod + def from_messages(messages: List[dict]) -> "OpenAILLMContext": + context = OpenAILLMContext() + for message in messages: + context.add_message({ + "content": message["content"], + "role": message["role"], + "name": message["name"] if "name" in message else message["role"] + }) + return context + + def add_message(self, message: ChatCompletionMessageParam): + self.messages.append(message) + + def get_messages(self) -> List[ChatCompletionMessageParam]: + return self.messages + + def set_tool_choice( + self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven + ): + self.tool_choice = tool_choice + + def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN): + if tools != NOT_GIVEN and len(tools) == 0: + tools = NOT_GIVEN + + self.tools = tools class OpenAIContextAggregator(FrameProcessor): diff --git a/src/dailyai/aggregators/sentence.py b/src/dailyai/aggregators/sentence.py new file mode 100644 index 000000000..d4c469cea --- /dev/null +++ b/src/dailyai/aggregators/sentence.py @@ 
-0,0 +1,49 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import re + +from typing import List + +from dailyai.frames.frames import EndFrame, Frame, TextFrame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor + + +class SentenceAggregator(FrameProcessor): + """This frame processor aggregates text frames into complete sentences. + + Frame input/output: + TextFrame("Hello,") -> None + TextFrame(" world.") -> TextFrame("Hello world.") + + Doctest: + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... print(frame.text) + + >>> aggregator = SentenceAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) + Hello, world. + """ + + def __init__(self): + super().__init__() + self._aggregation = "" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + m = re.search("(.*[?.!])(.*)", frame.data) + if m: + await self.push_frame(TextFrame(self._aggregation + m.group(1))) + self._aggregation = m.group(2) + else: + self._aggregation += frame.data + elif isinstance(frame, EndFrame): + if self._aggregation: + await self.push_frame(TextFrame(self._aggregation)) + + await super().process_frame(frame, direction) diff --git a/src/dailyai/aggregators/user_response.py b/src/dailyai/aggregators/user_response.py new file mode 100644 index 000000000..ae10c82dd --- /dev/null +++ b/src/dailyai/aggregators/user_response.py @@ -0,0 +1,138 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.frames.frames import ( + Frame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) + + +class ResponseAggregator(FrameProcessor): + """This frame processor aggregates frames between a start and an end frame + into complete text frame sentences. + + For example, frame input/output: + UserStartedSpeakingFrame() -> None + TranscriptionFrame("Hello,") -> None + TranscriptionFrame(" world.") -> None + UserStoppedSpeakingFrame() -> TextFrame("Hello world.") + + Doctest: + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + + >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame, + ... end_frame=UserStoppedSpeakingFrame, + ... accumulator_frame=TranscriptionFrame, + ... pass_through=False) + >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame())) + >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1))) + >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2))) + >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame())) + Hello, world. 
+ + """ + + def __init__( + self, + *, + start_frame, + end_frame, + accumulator_frame, + interim_accumulator_frame=None, + **kwargs + ): + super().__init__(**kwargs) + + self._start_frame = start_frame + self._end_frame = end_frame + self._accumulator_frame = accumulator_frame + self._interim_accumulator_frame = interim_accumulator_frame + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + self._aggregation = "" + self._aggregating = False + + # + # Frame processor + # + + # Use cases implemented: + # + # S: Start, E: End, T: Transcription, I: Interim, X: Text + # + # S E -> None + # S T E -> X + # S I T E -> X + # S I E T -> X + # S I E I T -> X + # + # The following case would not be supported: + # + # S I E T1 I T2 -> X + # + # and T2 would be dropped. + + async def process_frame(self, frame: Frame, direction: FrameDirection): + send_aggregation = False + + if isinstance(frame, self._start_frame): + self._seen_start_frame = True + self._aggregating = True + elif isinstance(frame, self._end_frame): + self._seen_end_frame = True + + # We might have received the end frame but we might still be + # aggregating (i.e. we have seen interim results but not the final + # text). + self._aggregating = self._seen_interim_results + + # Send the aggregation if we are not aggregating anymore (i.e. no + # more interim results received). + send_aggregation = not self._aggregating + elif isinstance(frame, self._accumulator_frame): + if self._aggregating: + self._aggregation += f" {frame.data}" + # We have received a complete sentence, so if we have seen the + # end frame and we were still aggregating, it means we should + # send the aggregation. + send_aggregation = self._seen_end_frame + + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + self._seen_interim_results = True + + if send_aggregation: + await self._push_aggregation() + + await super().process_frame(frame, direction) + + async def _push_aggregation(self): + if len(self._aggregation) > 0: + await self.push_frame(TextFrame(self._aggregation.strip())) + + # Reset + self._aggregation = "" + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + +class UserResponseAggregator(ResponseAggregator): + def __init__(self): + super().__init__( + start_frame=UserStartedSpeakingFrame, + end_frame=UserStoppedSpeakingFrame, + accumulator_frame=TranscriptionFrame, + ) diff --git a/src/dailyai/aggregators/vision_image_frame.py b/src/dailyai/aggregators/vision_image_frame.py new file mode 100644 index 000000000..768476e87 --- /dev/null +++ b/src/dailyai/aggregators/vision_image_frame.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.frames.frames import Frame, ImageRawFrame, TextFrame, VisionImageRawFrame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor + + +class VisionImageFrameAggregator(FrameProcessor): + """This aggregator waits for a consecutive TextFrame and an + ImageFrame. After the ImageFrame arrives it will output a VisionImageFrame. + + >>> from dailyai.pipeline.frames import ImageFrame + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ...
print(frame) + + >>> aggregator = VisionImageFrameAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("What do you see?"))) + >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) + VisionImageFrame, text: What do you see?, image size: 0x0, buffer size: 0 B + + """ + + def __init__(self): + super().__init__() + self._describe_text = None + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + self._describe_text = frame.text + elif isinstance(frame, ImageRawFrame): + if self._describe_text: + frame = VisionImageRawFrame( + self._describe_text, frame.image, frame.size, frame.format) + await self.push_frame(frame) + self._describe_text = None + + await super().process_frame(frame, direction) diff --git a/src/dailyai/frames/__init__.py b/src/dailyai/frames/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/frames/frames.proto similarity index 81% rename from src/dailyai/pipeline/frames.proto rename to src/dailyai/frames/frames.proto index b19fbccbf..ae084d649 100644 --- a/src/dailyai/pipeline/frames.proto +++ b/src/dailyai/frames/frames.proto @@ -1,3 +1,9 @@ +// +// Copyright (c) 2024, Daily +// +// SPDX-License-Identifier: BSD 2-Clause License +// + syntax = "proto3"; package dailyai_proto; diff --git a/src/dailyai/frames/frames.py b/src/dailyai/frames/frames.py new file mode 100644 index 000000000..c857328c6 --- /dev/null +++ b/src/dailyai/frames/frames.py @@ -0,0 +1,406 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Any + +from dailyai.utils.utils import obj_count + + +class Frame: + def __init__(self, data=None): + self.id: int = id(self) + self.data: Any = data + self.metadata = {} + self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" + + def __str__(self): + return self.name + + +class AudioRawFrame(Frame): + def __init__(self, data, sample_rate: int, num_channels: int): + super().__init__(data) + self.metadata["sample_rate"] = sample_rate + self.metadata["num_channels"] = num_channels + self.metadata["num_frames"] = int(len(data) / (num_channels * 2)) + + @property + def num_frames(self) -> int: + return self.metadata["num_frames"] + + @property + def sample_rate(self) -> int: + return self.metadata["sample_rate"] + + @property + def num_channels(self) -> int: + return self.metadata["num_channels"] + + def __str__(self): + return f"{self.name}(frames: {self.num_frames}, sample_rate: {self.sample_rate}, channels: {self.num_channels})" + + +class ImageRawFrame(Frame): + def __init__(self, data, size, format): + super().__init__(data) + self.metadata["size"] = size + self.metadata["format"] = format + + @property + def image(self) -> bytes: + return self.data + + @property + def size(self) -> tuple[int, int]: + return self.metadata["size"] + + @property + def format(self) -> str: + return self.metadata["format"] + + def __str__(self): + return f"{self.name}(size: {self.size}, format: {self.format})" + + +class URLImageRawFrame(ImageRawFrame): + def __init__(self, url, data, size, format): + super().__init__(data, size, format) + self.metadata["url"] = url + + @property + def url(self) -> str: + return self.metadata["url"] + + def __str__(self): + return f"{self.name}(url: {self.url}, size: {self.size}, format: {self.format})" + + +class VisionImageRawFrame(ImageRawFrame): + def __init__(self, text, data, size, format): + super().__init__(data, 
size, format) + self.metadata["text"] = text + + @property + def text(self) -> str: + return self.metadata["text"] + + def __str__(self): + return f"{self.name}(text: {self.text}, size: {self.size}, format: {self.format})" + + +class TextFrame(Frame): + def __init__(self, data): + super().__init__(data) + + @property + def text(self) -> str: + return self.data + + +class TranscriptionFrame(Frame): + def __init__(self, data, user_id, timestamp): + super().__init__(data) + self.metadata["user_id"] = user_id + self.metadata["timestamp"] = timestamp + + @property + def text(self) -> str: + return self.data + + @property + def user_id(self) -> str: + return self.metadata["user_id"] + + @property + def timestamp(self) -> str: + return self.metadata["timestamp"] + + def __str__(self): + return f"{self.name}(user: {self.user_id}, timestamp: {self.timestamp})" + + +class InterimTranscriptionFrame(Frame): + def __init__(self, data, user_id, timestamp): + super().__init__(data) + self.metadata["user_id"] = user_id + self.metadata["timestamp"] = timestamp + + @property + def text(self) -> str: + return self.data + + @property + def user_id(self) -> str: + return self.metadata["user_id"] + + @property + def timestamp(self) -> str: + return self.metadata["timestamp"] + + def __str__(self): + return f"{self.name}(user: {self.user_id}, timestamp: {self.timestamp})" + + +class LLMMessagesFrame(Frame): + """A frame containing a list of LLM messages. Used to signal that an LLM + service should run a chat completion and emit an LLM started response event, + text frames and an LLM stopped response event. + """ + + def __init__(self, messages): + super().__init__(messages) + + +# +# Pipeline Control frames +# + + +class PipelineFrame(Frame): + def __init__(self): + super().__init__() + + +class StartFrame(PipelineFrame): + def __init__(self): + super().__init__() + + +class EndFrame(PipelineFrame): + def __init__(self): + super().__init__() + + +# +# Control frames +# + + +class ControlFrame(Frame): + def __init__(self): + super().__init__() + + +class LLMResponseStartFrame(ControlFrame): + """Used to indicate the beginning of an LLM response. Following TextFrames + are part of the LLM response until an LLMResponseEndFrame""" + + def __init__(self): + super().__init__() + + +class LLMResponseEndFrame(ControlFrame): + """Indicates the end of an LLM response.""" + + def __init__(self): + super().__init__() + + +class UserStartedSpeakingFrame(ControlFrame): + def __init__(self): + super().__init__() + + +class UserStoppedSpeakingFrame(ControlFrame): + def __init__(self): + super().__init__() + + +class TTSStartedFrame(ControlFrame): + def __init__(self): + super().__init__() + + +class TTSStoppedFrame(ControlFrame): + def __init__(self): + super().__init__() + +# class StartFrame(ControlFrame): +# """Used (but not required) to start a pipeline, and is also used to +# indicate that an interruption has ended and the transport should start +# processing frames again.""" +# pass + + +# class EndFrame(ControlFrame): +# """Indicates that a pipeline has ended and frame processors and pipelines +# should be shut down. If the transport receives this frame, it will stop +# sending frames to its output channel(s) and close all its threads.""" +# pass + + +# class EndPipeFrame(ControlFrame): +# """Indicates that a pipeline has ended but that the transport should +# continue processing. 
This frame is used in parallel pipelines and other +# sub-pipelines.""" +# pass + + +# class PipelineStartedFrame(ControlFrame): +# """ +# Used by the transport to indicate that execution of a pipeline is starting +# (or restarting). It should be the first frame your app receives when it +# starts, or when an interruptible pipeline has been interrupted. +# """ + +# pass + + +# @dataclass() +# class URLImageFrame(ImageFrame): +# """An image with an associated URL. Will be shown by the transport if the +# transport's camera is enabled. + +# """ +# url: str | None + +# def __init__(self, url, image, size): +# super().__init__(image, size) +# self.url = url + +# def __str__(self): +# return f"{self.__class__.__name__}, url: {self.url}, image size: +# {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" + + +# @dataclass() +# class VisionImageFrame(ImageFrame): +# """An image with an associated text to ask for a description of it. Will be shown by the +# transport if the transport's camera is enabled. + +# """ +# text: str | None + +# def __init__(self, text, image, size): +# super().__init__(image, size) +# self.text = text + +# def __str__(self): +# return f"{self.__class__.__name__}, text: {self.text}, image size: +# {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" + + +# @dataclass() +# class UserImageFrame(ImageFrame): +# """An image associated to a user. Will be shown by the transport if the transport's camera is +# enabled.""" +# user_id: str + +# def __init__(self, user_id, image, size): +# super().__init__(image, size) +# self.user_id = user_id + +# def __str__(self): +# return f"{self.__class__.__name__}, user: {self.user_id}, image size: +# {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" + + +# @dataclass() +# class UserImageRequestFrame(Frame): +# """A frame user to request an image from the given user.""" +# user_id: str + +# def __str__(self): +# return f"{self.__class__.__name__}, user: {self.user_id}" + + +# @dataclass() +# class SpriteFrame(Frame): +# """An animated sprite. Will be shown by the transport if the transport's +# camera is enabled. Will play at the framerate specified in the transport's +# `fps` constructor parameter.""" +# images: list[bytes] + +# def __str__(self): +# return f"{self.__class__.__name__}, list size: {len(self.images)}" + + +# @dataclass() +# class TextFrame(Frame): +# """A chunk of text. Emitted by LLM services, consumed by TTS services, can +# be used to send text through pipelines.""" +# text: str + +# def __str__(self): +# return f'{self.__class__.__name__}: "{self.text}"' + + +# class TTSStartFrame(ControlFrame): +# """Used to indicate the beginning of a TTS response. Following AudioFrames +# are part of the TTS response until an TTEndFrame. These frames can be used +# for aggregating audio frames in a transport to optimize the size of frames +# sent to the session, without needing to control this in the TTS service.""" +# pass + + +# class TTSEndFrame(ControlFrame): +# """Indicates the end of a TTS response.""" +# pass + + +# @dataclass() +# class LLMMessagesFrame(Frame): +# """A frame containing a list of LLM messages. Used to signal that an LLM +# service should run a chat completion and emit an LLMStartFrames, TextFrames +# and an LLMEndFrame. 
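# Illustrative sketch, not part of the patch: driving the push-style
# VisionImageFrameAggregator shown near the top of this diff. The import path
# follows the new aggregators/ layout; the empty image bytes and "RGB" format
# are placeholders for this example only.
import asyncio

from dailyai.aggregators.vision_image_frame import VisionImageFrameAggregator
from dailyai.frames.frames import ImageRawFrame, TextFrame
from dailyai.processors.frame_processor import FrameDirection


async def main():
    aggregator = VisionImageFrameAggregator()
    await aggregator.process_frame(TextFrame("What do you see?"), FrameDirection.DOWNSTREAM)
    await aggregator.process_frame(ImageRawFrame(bytes(), (0, 0), "RGB"), FrameDirection.DOWNSTREAM)
    # With a downstream processor linked, the second call pushes a
    # VisionImageRawFrame pairing the question with the image.


asyncio.run(main())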
+# Note that the messages property on this class is mutable, and will be +# be updated by various ResponseAggregator frame processors.""" +# messages: List[dict] + + +# @dataclass() +# class ReceivedAppMessageFrame(Frame): +# message: Any +# sender: str + +# def __str__(self): +# return f"ReceivedAppMessageFrame: sender: {self.sender}, message: {self.message}" + + +# @dataclass() +# class SendAppMessageFrame(Frame): +# message: Any +# participant_id: str | None + +# def __str__(self): +# return f"SendAppMessageFrame: participant: {self.participant_id}, message: {self.message}" + + +# class UserStartedSpeakingFrame(Frame): +# """Emitted by VAD to indicate that a participant has started speaking. +# This can be used for interruptions or other times when detecting that +# someone is speaking is more important than knowing what they're saying +# (as you will with a TranscriptionFrame)""" +# pass + + +# class UserStoppedSpeakingFrame(Frame): +# """Emitted by the VAD to indicate that a user stopped speaking.""" +# pass + + +# class BotStartedSpeakingFrame(Frame): +# pass + + +# class BotStoppedSpeakingFrame(Frame): +# pass + + +# @dataclass() +# class LLMFunctionStartFrame(Frame): +# """Emitted when the LLM receives the beginning of a function call +# completion. A frame processor can use this frame to indicate that it should +# start preparing to make a function call, if it can do so in the absence of +# any arguments.""" +# function_name: str + + +# @dataclass() +# class LLMFunctionCallFrame(Frame): +# """Emitted when the LLM has received an entire function call completion.""" +# function_name: str +# arguments: str diff --git a/src/dailyai/frames/openai_frames.py b/src/dailyai/frames/openai_frames.py new file mode 100644 index 000000000..fcc0df4f5 --- /dev/null +++ b/src/dailyai/frames/openai_frames.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.frames.frames import Frame + + +class OpenAILLMContextFrame(Frame): + """Like an LLMMessagesFrame, but with extra context specific to the + OpenAI API.""" + + def __init__(self, data): + super().__init__(data) diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/frames/protobufs/frames_pb2.py similarity index 100% rename from src/dailyai/pipeline/protobufs/frames_pb2.py rename to src/dailyai/frames/protobufs/frames_pb2.py diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py deleted file mode 100644 index 81ea5815c..000000000 --- a/src/dailyai/pipeline/aggregators.py +++ /dev/null @@ -1,549 +0,0 @@ -import asyncio -import re -import time - -from dailyai.pipeline.frame_processor import FrameProcessor - -from dailyai.pipeline.frames import ( - EndFrame, - EndPipeFrame, - Frame, - ImageFrame, - InterimTranscriptionFrame, - LLMMessagesFrame, - LLMResponseEndFrame, - LLMResponseStartFrame, - TextFrame, - TranscriptionFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - VisionImageFrame, -) -from dailyai.pipeline.pipeline import Pipeline -from dailyai.services.ai_services import AIService - -from typing import AsyncGenerator, Coroutine, List - - -class ResponseAggregator(FrameProcessor): - """This frame processor aggregates frames between a start and an end frame - into complete text frame sentences. 
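# Illustrative sketch, not part of the patch: the new OpenAILLMContextFrame
# added above carries an OpenAI-specific context object in frame.data, while
# LLMMessagesFrame keeps carrying a plain message list. The OpenAILLMContext
# import path is an assumption based on the new aggregators/ layout, so that
# part is left commented out.
from dailyai.frames.frames import LLMMessagesFrame
from dailyai.frames.openai_frames import OpenAILLMContextFrame

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]

# Provider-agnostic: any LLM service can consume this frame.
messages_frame = LLMMessagesFrame(messages)

# OpenAI-specific (assumed path and constructor):
# from dailyai.aggregators.openai_llm_context import OpenAILLMContext
# context_frame = OpenAILLMContextFrame(OpenAILLMContext(messages=messages))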
- - For example, frame input/output: - UserStartedSpeakingFrame() -> None - TranscriptionFrame("Hello,") -> None - TranscriptionFrame(" world.") -> None - UserStoppedSpeakingFrame() -> TextFrame("Hello world.") - - Doctest: - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - - >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame, - ... end_frame=UserStoppedSpeakingFrame, - ... accumulator_frame=TranscriptionFrame, - ... pass_through=False) - >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame())) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1))) - >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2))) - >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame())) - Hello, world. - - """ - - def __init__( - self, - *, - start_frame, - end_frame, - accumulator_frame, - pass_through=True, - ): - self.aggregation = "" - self.aggregating = False - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._pass_through = pass_through - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, self._start_frame): - self.aggregating = True - elif isinstance(frame, self._end_frame): - self.aggregating = False - # Sometimes VAD triggers quickly on and off. If we don't get any transcription, - # it creates empty LLM message queue frames - if len(self.aggregation) > 0: - output = self.aggregation - self.aggregation = "" - yield self._end_frame() - yield TextFrame(output.strip()) - elif isinstance(frame, self._accumulator_frame) and self.aggregating: - self.aggregation += f" {frame.text}" - if self._pass_through: - yield frame - else: - yield frame - - -class UserResponseAggregator(ResponseAggregator): - def __init__(self): - super().__init__( - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - pass_through=False, - ) - - -class LLMResponseAggregator(FrameProcessor): - - def __init__( - self, - *, - messages: list[dict] | None, - role: str, - start_frame, - end_frame, - accumulator_frame, - interim_accumulator_frame=None, - pass_through=True, - ): - self.aggregation = "" - self.aggregating = False - self.messages = messages - self._role = role - self._start_frame = start_frame - self._end_frame = end_frame - self._accumulator_frame = accumulator_frame - self._interim_accumulator_frame = interim_accumulator_frame - self._pass_through = pass_through - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - # Use cases implemented: - # - # S: Start, E: End, T: Transcription, I: Interim, X: Text - # - # S E -> None - # S T E -> X - # S I T E -> X - # S I E T -> X - # S I E I T -> X - # - # The following case would not be supported: - # - # S I E T1 I T2 -> X - # - # and T2 would be dropped. - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if not self.messages: - return - - send_aggregation = False - - if isinstance(frame, self._start_frame): - self._seen_start_frame = True - self.aggregating = True - elif isinstance(frame, self._end_frame): - self._seen_end_frame = True - - # We might have received the end frame but we might still be - # aggregating (i.e. we have seen interim results but not the final - # text). 
- self.aggregating = self._seen_interim_results - - # Send the aggregation if we are not aggregating anymore (i.e. no - # more interim results received). - send_aggregation = not self.aggregating - elif isinstance(frame, self._accumulator_frame): - if self.aggregating: - self.aggregation += f" {frame.text}" - # We have receied a complete sentence, so if we have seen the - # end frame and we were still aggregating, it means we should - # send the aggregation. - send_aggregation = self._seen_end_frame - - if self._pass_through: - yield frame - - # We just got our final result, so let's reset interim results. - self._seen_interim_results = False - elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): - self._seen_interim_results = True - else: - yield frame - - if send_aggregation and len(self.aggregation) > 0: - self.messages.append({"role": self._role, "content": self.aggregation}) - yield self._end_frame() - yield LLMMessagesFrame(self.messages) - # Reset - self.aggregation = "" - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - - -class LLMAssistantResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: list[dict]): - super().__init__( - messages=messages, - role="assistant", - start_frame=LLMResponseStartFrame, - end_frame=LLMResponseEndFrame, - accumulator_frame=TextFrame, - ) - - -class LLMUserResponseAggregator(LLMResponseAggregator): - def __init__(self, messages: list[dict]): - super().__init__( - messages=messages, - role="user", - start_frame=UserStartedSpeakingFrame, - end_frame=UserStoppedSpeakingFrame, - accumulator_frame=TranscriptionFrame, - interim_accumulator_frame=InterimTranscriptionFrame, - pass_through=False, - ) - - -class LLMContextAggregator(AIService): - def __init__( - self, - messages: list[dict], - role: str, - bot_participant_id=None, - complete_sentences=True, - pass_through=True, - ): - super().__init__() - self.messages = messages - self.bot_participant_id = bot_participant_id - self.role = role - self.sentence = "" - self.complete_sentences = complete_sentences - self.pass_through = pass_through - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - # We don't do anything with non-text frames, pass it along to next in - # the pipeline. - if not isinstance(frame, TextFrame): - yield frame - return - - # Ignore transcription frames from the bot - if isinstance(frame, TranscriptionFrame): - if frame.participantId == self.bot_participant_id: - return - - # The common case for "pass through" is receiving frames from the LLM that we'll - # use to update the "assistant" LLM messages, but also passing the text frames - # along to a TTS service to be spoken to the user. 
- if self.pass_through: - yield frame - - # TODO: split up transcription by participant - if self.complete_sentences: - # type: ignore -- the linter thinks this isn't a TextFrame, even - # though we check it above - self.sentence += frame.text - if self.sentence.endswith((".", "?", "!")): - self.messages.append( - {"role": self.role, "content": self.sentence}) - self.sentence = "" - yield LLMMessagesFrame(self.messages) - else: - # type: ignore -- the linter thinks this isn't a TextFrame, even - # though we check it above - self.messages.append({"role": self.role, "content": frame.text}) - yield LLMMessagesFrame(self.messages) - - -class LLMUserContextAggregator(LLMContextAggregator): - def __init__( - self, - messages: list[dict], - bot_participant_id=None, - complete_sentences=True): - super().__init__( - messages, - "user", - bot_participant_id, - complete_sentences, - pass_through=False) - - -class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__( - self, - messages: list[dict], - bot_participant_id=None, - complete_sentences=True): - super().__init__( - messages, - "assistant", - bot_participant_id, - complete_sentences, - pass_through=True, - ) - - -class SentenceAggregator(FrameProcessor): - """This frame processor aggregates text frames into complete sentences. - - Frame input/output: - TextFrame("Hello,") -> None - TextFrame(" world.") -> TextFrame("Hello world.") - - Doctest: - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... print(frame.text) - - >>> aggregator = SentenceAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) - Hello, world. - """ - - def __init__(self): - self.aggregation = "" - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, TextFrame): - m = re.search("(.*[?.!])(.*)", frame.text) - if m: - yield TextFrame(self.aggregation + m.group(1)) - self.aggregation = m.group(2) - else: - self.aggregation += frame.text - elif isinstance(frame, EndFrame): - if self.aggregation: - yield TextFrame(self.aggregation) - yield frame - else: - yield frame - - -class LLMFullResponseAggregator(FrameProcessor): - """This class aggregates Text frames until it receives a - LLMResponseEndFrame, then emits the concatenated text as - a single text frame. - - given the following frames: - - TextFrame("Hello,") - TextFrame(" world.") - TextFrame(" I am") - TextFrame(" an LLM.") - LLMResponseEndFrame()] - - this processor will yield nothing for the first 4 frames, then - - TextFrame("Hello, world. I am an LLM.") - LLMResponseEndFrame() - - when passed the last frame. - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - ... else: - ... print(frame.__class__.__name__) - - >>> aggregator = LLMFullResponseAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) - >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) - >>> asyncio.run(print_frames(aggregator, LLMResponseEndFrame())) - Hello, world. I am an LLM. 
- LLMResponseEndFrame - """ - - def __init__(self): - self.aggregation = "" - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, TextFrame): - self.aggregation += frame.text - elif isinstance(frame, LLMResponseEndFrame): - yield TextFrame(self.aggregation) - yield frame - self.aggregation = "" - else: - yield frame - - -class StatelessTextTransformer(FrameProcessor): - """This processor calls the given function on any text in a text frame. - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... print(frame.text) - - >>> aggregator = StatelessTextTransformer(lambda x: x.upper()) - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) - HELLO - """ - - def __init__(self, transform_fn): - self.transform_fn = transform_fn - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, TextFrame): - result = self.transform_fn(frame.text) - if isinstance(result, Coroutine): - result = await result - - yield TextFrame(result) - else: - yield frame - - -class ParallelPipeline(FrameProcessor): - """Run multiple pipelines in parallel. - - This class takes frames from its source queue and sends them to each - sub-pipeline. Each sub-pipeline emits its frames into this class's - sink queue. No guarantees are made about the ordering of frames in - the sink queue (that is, no sub-pipeline has higher priority than - any other, frames are put on the sink in the order they're emitted - by the sub-pipelines). - - After each frame is taken from this class's source queue and placed - in each sub-pipeline's source queue, an EndPipeFrame is put on each - sub-pipeline's source queue. This indicates to the sub-pipe runner - that it should exit. - - Since frame handlers pass through unhandled frames by convention, this - class de-dupes frames in its sink before yielding them. - """ - - def __init__(self, pipeline_definitions: List[List[FrameProcessor]]): - self.sources = [asyncio.Queue() for _ in pipeline_definitions] - self.sink: asyncio.Queue[Frame] = asyncio.Queue() - self.pipelines: list[Pipeline] = [ - Pipeline( - pipeline_definition, - source, - self.sink, - ) - for source, pipeline_definition in zip(self.sources, pipeline_definitions) - ] - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - for source in self.sources: - await source.put(frame) - await source.put(EndPipeFrame()) - - await asyncio.gather(*[pipeline.run_pipeline() for pipeline in self.pipelines]) - - seen_ids = set() - while not self.sink.empty(): - frame = await self.sink.get() - - # de-dup frames. Because the convention is to yield a frame that isn't processed, - # each pipeline will likely yield the same frame, so we will end up with _n_ copies - # of unprocessed frames where _n_ is the number of parallel pipes that don't - # process that frame. - if id(frame) in seen_ids: - continue - seen_ids.add(id(frame)) - - # Skip passing along EndPipeFrame, because we use them - # for our own flow control. - if not isinstance(frame, EndPipeFrame): - yield frame - - -class GatedAggregator(FrameProcessor): - """Accumulate frames, with custom functions to start and stop accumulation. - Yields gate-opening frame before any accumulated frames, then ensuing frames - until and not including the gate-closed frame. - - >>> from dailyai.pipeline.frames import ImageFrame - - >>> async def print_frames(aggregator, frame): - ... 
async for frame in aggregator.process_frame(frame): - ... if isinstance(frame, TextFrame): - ... print(frame.text) - ... else: - ... print(frame.__class__.__name__) - - >>> aggregator = GatedAggregator( - ... gate_close_fn=lambda x: isinstance(x, LLMResponseStartFrame), - ... gate_open_fn=lambda x: isinstance(x, ImageFrame), - ... start_open=False) - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) - >>> asyncio.run(print_frames(aggregator, TextFrame("Hello again."))) - >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) - ImageFrame - Hello - Hello again. - >>> asyncio.run(print_frames(aggregator, TextFrame("Goodbye."))) - Goodbye. - """ - - def __init__(self, gate_open_fn, gate_close_fn, start_open): - self.gate_open_fn = gate_open_fn - self.gate_close_fn = gate_close_fn - self.gate_open = start_open - self.accumulator: List[Frame] = [] - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if self.gate_open: - if self.gate_close_fn(frame): - self.gate_open = False - else: - if self.gate_open_fn(frame): - self.gate_open = True - - if self.gate_open: - yield frame - if self.accumulator: - for frame in self.accumulator: - yield frame - self.accumulator = [] - else: - self.accumulator.append(frame) - - -class VisionImageFrameAggregator(FrameProcessor): - """This aggregator waits for a consecutive TextFrame and an - ImageFrame. After the ImageFrame arrives it will output a VisionImageFrame. - - >>> from dailyai.pipeline.frames import ImageFrame - - >>> async def print_frames(aggregator, frame): - ... async for frame in aggregator.process_frame(frame): - ... print(frame) - - >>> aggregator = VisionImageFrameAggregator() - >>> asyncio.run(print_frames(aggregator, TextFrame("What do you see?"))) - >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) - VisionImageFrame, text: What do you see?, image size: 0x0, buffer size: 0 B - - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._describe_text = None - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, TextFrame): - self._describe_text = frame.text - elif isinstance(frame, ImageFrame): - if self._describe_text: - yield VisionImageFrame(self._describe_text, frame.image, frame.size) - self._describe_text = None - else: - yield frame - else: - yield frame diff --git a/src/dailyai/pipeline/frame_processor.py b/src/dailyai/pipeline/frame_processor.py deleted file mode 100644 index e8c78e3e2..000000000 --- a/src/dailyai/pipeline/frame_processor.py +++ /dev/null @@ -1,34 +0,0 @@ -from abc import abstractmethod -from typing import AsyncGenerator - -from dailyai.pipeline.frames import ControlFrame, Frame - - -class FrameProcessor: - """This is the base class for all frame processors. Frame processors consume a frame - and yield 0 or more frames. Generally frame processors are used as part of a pipeline - where frames come from a source queue, are processed by a series of frame processors, - then placed on a sink queue. - - By convention, FrameProcessors should immediately yield any frames they don't process. - - Stateful FrameProcessors should watch for the EndFrame and finalize their - output, eg. yielding an unfinished sentence if they're aggregating LLM output to full - sentences. EndFrame is also a chance to clean up any services that need to - be closed, del'd, etc. 
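# Illustrative sketch, not part of the patch: the FrameProcessor being deleted
# here consumed one frame and yielded zero or more frames from an async
# generator; its replacement in dailyai.processors.frame_processor receives a
# direction and pushes frames to linked processors instead. UppercaseText is a
# hypothetical processor written only to show the new calling convention.
import asyncio

from dailyai.frames.frames import Frame, TextFrame
from dailyai.processors.frame_processor import FrameDirection, FrameProcessor


class UppercaseText(FrameProcessor):

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        if isinstance(frame, TextFrame):
            await self.push_frame(TextFrame(frame.text.upper()), direction)
        # Pipeline and control frames still flow through the base class.
        await super().process_frame(frame, direction)


async def main():
    processor = UppercaseText()
    await processor.process_frame(TextFrame("hello"), FrameDirection.DOWNSTREAM)
    # With no linked processor, push_frame() has nowhere to deliver the frame;
    # inside a Pipeline the uppercased TextFrame would reach the next stage.


asyncio.run(main())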
- """ - - @abstractmethod - async def process_frame( - self, frame: Frame - ) -> AsyncGenerator[Frame, None]: - """Process a single frame and yield 0 or more frames.""" - yield frame - - @abstractmethod - async def interrupted(self) -> None: - """Handle any cleanup if the pipeline was interrupted.""" - pass - - def __str__(self): - return self.__class__.__name__ diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py deleted file mode 100644 index 28a920dd8..000000000 --- a/src/dailyai/pipeline/frames.py +++ /dev/null @@ -1,253 +0,0 @@ -from dataclasses import dataclass -from typing import Any, List - - -class Frame: - def __str__(self): - return f"{self.__class__.__name__}" - - -class ControlFrame(Frame): - # Control frames should contain no instance data, so - # equality is based solely on the class. - def __eq__(self, other): - return isinstance(other, self.__class__) - - -class StartFrame(ControlFrame): - """Used (but not required) to start a pipeline, and is also used to - indicate that an interruption has ended and the transport should start - processing frames again.""" - pass - - -class EndFrame(ControlFrame): - """Indicates that a pipeline has ended and frame processors and pipelines - should be shut down. If the transport receives this frame, it will stop - sending frames to its output channel(s) and close all its threads.""" - pass - - -class EndPipeFrame(ControlFrame): - """Indicates that a pipeline has ended but that the transport should - continue processing. This frame is used in parallel pipelines and other - sub-pipelines.""" - pass - - -class PipelineStartedFrame(ControlFrame): - """ - Used by the transport to indicate that execution of a pipeline is starting - (or restarting). It should be the first frame your app receives when it - starts, or when an interruptible pipeline has been interrupted. - """ - - pass - - -class LLMResponseStartFrame(ControlFrame): - """Used to indicate the beginning of an LLM response. Following TextFrames - are part of the LLM response until an LLMResponseEndFrame""" - pass - - -class LLMResponseEndFrame(ControlFrame): - """Indicates the end of an LLM response.""" - pass - - -@dataclass() -class AudioFrame(Frame): - """A chunk of audio. Will be played by the transport if the transport's mic - has been enabled.""" - data: bytes - - def __str__(self): - return f"{self.__class__.__name__}, size: {len(self.data)} B" - - -@dataclass() -class ImageFrame(Frame): - """An image. Will be shown by the transport if the transport's camera is - enabled.""" - image: bytes - size: tuple[int, int] - - def __str__(self): - return f"{self.__class__.__name__}, image size: {self.size[0]}x{self.size[1]} buffer size: {len(self.image)} B" - - -@dataclass() -class URLImageFrame(ImageFrame): - """An image with an associated URL. Will be shown by the transport if the - transport's camera is enabled. - - """ - url: str | None - - def __init__(self, url, image, size): - super().__init__(image, size) - self.url = url - - def __str__(self): - return f"{self.__class__.__name__}, url: {self.url}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" - - -@dataclass() -class VisionImageFrame(ImageFrame): - """An image with an associated text to ask for a description of it. Will be shown by the - transport if the transport's camera is enabled. 
- - """ - text: str | None - - def __init__(self, text, image, size): - super().__init__(image, size) - self.text = text - - def __str__(self): - return f"{self.__class__.__name__}, text: {self.text}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" - - -@dataclass() -class UserImageFrame(ImageFrame): - """An image associated to a user. Will be shown by the transport if the transport's camera is - enabled.""" - user_id: str - - def __init__(self, user_id, image, size): - super().__init__(image, size) - self.user_id = user_id - - def __str__(self): - return f"{self.__class__.__name__}, user: {self.user_id}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B" - - -@dataclass() -class UserImageRequestFrame(Frame): - """A frame user to request an image from the given user.""" - user_id: str - - def __str__(self): - return f"{self.__class__.__name__}, user: {self.user_id}" - - -@dataclass() -class SpriteFrame(Frame): - """An animated sprite. Will be shown by the transport if the transport's - camera is enabled. Will play at the framerate specified in the transport's - `fps` constructor parameter.""" - images: list[bytes] - - def __str__(self): - return f"{self.__class__.__name__}, list size: {len(self.images)}" - - -@dataclass() -class TextFrame(Frame): - """A chunk of text. Emitted by LLM services, consumed by TTS services, can - be used to send text through pipelines.""" - text: str - - def __str__(self): - return f'{self.__class__.__name__}: "{self.text}"' - - -@dataclass() -class TranscriptionFrame(TextFrame): - """A text frame with transcription-specific data. Will be placed in the - transport's receive queue when a participant speaks.""" - participantId: str - timestamp: str - - def __str__(self): - return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}" - - -@dataclass() -class InterimTranscriptionFrame(TextFrame): - """A text frame with interim transcription-specific data. Will be placed in - the transport's receive queue when a participant speaks.""" - participantId: str - timestamp: str - - def __str__(self): - return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}" - - -class TTSStartFrame(ControlFrame): - """Used to indicate the beginning of a TTS response. Following AudioFrames - are part of the TTS response until an TTEndFrame. These frames can be used - for aggregating audio frames in a transport to optimize the size of frames - sent to the session, without needing to control this in the TTS service.""" - pass - - -class TTSEndFrame(ControlFrame): - """Indicates the end of a TTS response.""" - pass - - -@dataclass() -class LLMMessagesFrame(Frame): - """A frame containing a list of LLM messages. Used to signal that an LLM - service should run a chat completion and emit an LLMStartFrames, TextFrames - and an LLMEndFrame. 
- Note that the messages property on this class is mutable, and will be - be updated by various ResponseAggregator frame processors.""" - messages: List[dict] - - -@dataclass() -class ReceivedAppMessageFrame(Frame): - message: Any - sender: str - - def __str__(self): - return f"ReceivedAppMessageFrame: sender: {self.sender}, message: {self.message}" - - -@dataclass() -class SendAppMessageFrame(Frame): - message: Any - participant_id: str | None - - def __str__(self): - return f"SendAppMessageFrame: participant: {self.participant_id}, message: {self.message}" - - -class UserStartedSpeakingFrame(Frame): - """Emitted by VAD to indicate that a participant has started speaking. - This can be used for interruptions or other times when detecting that - someone is speaking is more important than knowing what they're saying - (as you will with a TranscriptionFrame)""" - pass - - -class UserStoppedSpeakingFrame(Frame): - """Emitted by the VAD to indicate that a user stopped speaking.""" - pass - - -class BotStartedSpeakingFrame(Frame): - pass - - -class BotStoppedSpeakingFrame(Frame): - pass - - -@dataclass() -class LLMFunctionStartFrame(Frame): - """Emitted when the LLM receives the beginning of a function call - completion. A frame processor can use this frame to indicate that it should - start preparing to make a function call, if it can do so in the absence of - any arguments.""" - function_name: str - - -@dataclass() -class LLMFunctionCallFrame(Frame): - """Emitted when the LLM has received an entire function call completion.""" - function_name: str - arguments: str diff --git a/src/dailyai/pipeline/openai_frames.py b/src/dailyai/pipeline/openai_frames.py deleted file mode 100644 index 2a14c670e..000000000 --- a/src/dailyai/pipeline/openai_frames.py +++ /dev/null @@ -1,12 +0,0 @@ -from dataclasses import dataclass - -from dailyai.pipeline.frames import Frame -from dailyai.services.openai_llm_context import OpenAILLMContext - - -@dataclass() -class OpenAILLMContextFrame(Frame): - """Like an LLMMessagesFrame, but with extra context specific to the - OpenAI API. 
The context in this message is also mutable, and will be - changed by the OpenAIContextAggregator frame processor.""" - context: OpenAILLMContext diff --git a/src/dailyai/pipeline/parallel_pipeline.py b/src/dailyai/pipeline/parallel_pipeline.py new file mode 100644 index 000000000..4477345bb --- /dev/null +++ b/src/dailyai/pipeline/parallel_pipeline.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from typing import List + +from dailyai.pipeline.pipeline import Pipeline +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.frames.frames import ControlFrame, EndFrame, Frame, PipelineFrame, StartFrame + +from loguru import logger + + +class Source(FrameProcessor): + + def __init__(self, upstream_queue: asyncio.Queue, **kwargs): + super().__init__(**kwargs) + self._up_queue = upstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + match direction: + case FrameDirection.UPSTREAM: + await self._up_queue.put(frame) + case FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + + +class Sink(FrameProcessor): + + def __init__(self, downstream_queue: asyncio.Queue, **kwargs): + super().__init__(**kwargs) + self._down_queue = downstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + match direction: + case FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) + case FrameDirection.DOWNSTREAM: + await self._down_queue.put(frame) + + +class ParallelPipeline(FrameProcessor): + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + + if len(args) == 0: + raise Exception(f"ParallelPipeline needs at least one argument") + + self._sources = [] + self._sinks = [] + + self._running = True + self._up_queue = asyncio.Queue() + self._up_task = asyncio.create_task(self._process_up_queue()) + self._down_queue = asyncio.Queue() + self._down_task = asyncio.create_task(self._process_down_queue()) + + self._pipelines = [] + self._tasks = [] + + logger.debug(f"Creating {self} pipelines") + for processors in args: + if not isinstance(processors, list): + raise TypeError(f"ParallelPipeline argument {processors} is not a list") + + # We add a source at the beginning of the pipeline and a sink at the end. + source = Source(self._up_queue) + sink = Sink(self._down_queue) + processors: List[FrameProcessor] = [source] + processors + processors.append(sink) + + # Keep track of sources and sinks + self._sources.append(source) + self._sinks.append(sink) + + # Create pipeline (they will start later) + pipeline = Pipeline(processors, send_pipeline_frames=False) + self._pipelines.append(pipeline) + logger.debug(f"Finished creating {self} pipelines") + + # + # Frame processor + # + + async def process_frame(self, frame: Frame, direction: FrameDirection): + # If we get a StartFrame we start all the pipelines tasks'. + if isinstance(frame, StartFrame): + self._tasks = [asyncio.create_task(p.run()) for p in self._pipelines] + # If we get an EndFrame we stop our queues processing tasks and wait on + # all the pipelines to finish. + elif isinstance(frame, EndFrame): + self._running = False + + if direction == FrameDirection.UPSTREAM: + # If we get an upstream frame we process it in each sink. + asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks]) + elif direction == FrameDirection.DOWNSTREAM: + # We also push the frame to the internal pipelines. 
+ asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources]) + + if not self._running: + await self._up_task + await self._down_task + asyncio.gather(*self._tasks) + + async def _process_up_queue(self): + seen_ids = set() + while self._running: + frame = await self._up_queue.get() + if frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.UPSTREAM) + seen_ids.add(frame.id) + + async def _process_down_queue(self): + seen_ids = set() + while self._running: + frame = await self._down_queue.get() + if frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.DOWNSTREAM) + seen_ids.add(frame.id) diff --git a/src/dailyai/pipeline/pipeline.py b/src/dailyai/pipeline/pipeline.py index e1a6a15f7..5df63bb94 100644 --- a/src/dailyai/pipeline/pipeline.py +++ b/src/dailyai/pipeline/pipeline.py @@ -1,149 +1,106 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio -import logging -from typing import AsyncGenerator, AsyncIterable, Iterable, List -from dailyai.pipeline.frame_processor import FrameProcessor +import signal + +from asyncio import AbstractEventLoop +from typing import AsyncIterable, Iterable, List -from dailyai.pipeline.frames import AudioFrame, EndPipeFrame, EndFrame, Frame +from dailyai.frames.frames import EndFrame, Frame, StartFrame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.utils.utils import obj_count + +from loguru import logger class Pipeline: - """ - This class manages a pipe of FrameProcessors, and runs them in sequence. The "source" - and "sink" queues are managed by the caller. You can use this class stand-alone to - perform specialized processing, or you can use the Transport's run_pipeline method to - instantiate and run a pipeline with the Transport's sink and source queues. - """ def __init__( self, - processors: List[FrameProcessor], - source: asyncio.Queue | None = None, - sink: asyncio.Queue[Frame] | None = None, + processors: List[FrameProcessor] = [], name: str | None = None, + loop: AbstractEventLoop | None = None, + send_pipeline_frames: bool = True, ): - """Create a new pipeline. By default we create the sink and source queues - if they're not provided, but these can be overridden to point to other - queues. If this pipeline is run by a transport, its sink and source queues - will be overridden. - """ - self._processors: List[FrameProcessor] = processors - - self.source: asyncio.Queue[Frame] = source or asyncio.Queue() - self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue() - - self._logger = logging.getLogger("dailyai.pipeline") - self._last_log_line = "" - self._shown_repeated_log = False - self._name = name or str(id(self)) - - def set_source(self, source: asyncio.Queue[Frame]): - """Set the source queue for this pipeline. Frames from this queue - will be processed by each frame_processor in the pipeline, or order - from first to last.""" - self.source = source - - def set_sink(self, sink: asyncio.Queue[Frame]): - """Set the sink queue for this pipeline. 
After the last frame_processor - has processed a frame, its output will be placed on this queue.""" - self.sink = sink + self.id = id(self) + self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" + self._loop: AbstractEventLoop = loop or asyncio.get_event_loop() + self._processors: List[FrameProcessor] = [] + self._send_pipeline_frames = send_pipeline_frames + + self._running = False + self._start_added = False + self._source_queue: asyncio.Queue = asyncio.Queue() + + self._setup_sigint() + self._add_processors(processors) + self._link_processors() def add_processor(self, processor: FrameProcessor): self._processors.append(processor) - async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]: - """Convenience function to get the next frame from the source queue. This - lets us consistently have an AsyncGenerator yield frames, from either the - source queue or a frame_processor.""" + async def stop(self): + if self._send_pipeline_frames: + await self._source_queue.put(EndFrame()) - yield await self.source.get() + async def run(self): + self._running = True + await self._maybe_send_start_frame() + while self._running: + frame = await self._source_queue.get() + if len(self._processors) > 0: + await self._processors[0].process_frame(frame, FrameDirection.DOWNSTREAM) + self._running = not isinstance(frame, EndFrame) - async def queue_frames( - self, - frames: Iterable[Frame] | AsyncIterable[Frame], - ) -> None: - """Insert frames directly into a pipeline. This is typically used inside a transport - participant_joined callback to prompt a bot to start a conversation, for example.""" + await self._cleanup_processors() + async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): + if not self._running: + await self._maybe_send_start_frame() if isinstance(frames, AsyncIterable): async for frame in frames: - await self.source.put(frame) + await self._source_queue.put(frame) elif isinstance(frames, Iterable): for frame in frames: - await self.source.put(frame) + await self._source_queue.put(frame) else: raise Exception("Frames must be an iterable or async iterable") - async def run_pipeline(self): - """Run the pipeline. Take each frame from the source queue, pass it to - the first frame_processor, pass the output of that frame_processor to the - next in the list, etc. until the last frame_processor has processed the - resulting frames, then place those frames in the sink queue. - - The source and sink queues must be set before calling this method. - - This method will exit when an EndFrame is placed on the sink queue. - No more frames will be placed on the sink queue after an EndFrame, even - if it's not the last frame yielded by the last frame_processor in the pipeline.. - """ - - try: - while True: - initial_frame = await self.source.get() - async for frame in self._run_pipeline_recursively( - initial_frame, self._processors - ): - self._log_frame(frame, len(self._processors) + 1) - await self.sink.put(frame) - - if isinstance(initial_frame, EndFrame) or isinstance( - initial_frame, EndPipeFrame - ): - break - except asyncio.CancelledError: - # this means there's been an interruption, do any cleanup necessary - # here. 
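# Illustrative sketch, not part of the patch: driving the reworked push-based
# Pipeline end to end. PrintText is a hypothetical sink written only for this
# example; Pipeline, FrameProcessor, FrameDirection and TextFrame are the
# classes added or updated in this diff.
import asyncio

from dailyai.frames.frames import Frame, TextFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.processors.frame_processor import FrameDirection, FrameProcessor


class PrintText(FrameProcessor):

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        if isinstance(frame, TextFrame):
            print(frame.text)
        await super().process_frame(frame, direction)


async def main():
    pipeline = Pipeline([PrintText()])   # queue_frames() below adds the StartFrame
    await pipeline.queue_frames([TextFrame("hello")])
    await pipeline.stop()                # queues the EndFrame that lets run() return
    await pipeline.run()                 # prints "hello", then cleans up processors


asyncio.run(main())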
- for processor in self._processors: - await processor.interrupted() - - async def _run_pipeline_recursively( - self, initial_frame: Frame, processors: List[FrameProcessor], depth=1 - ) -> AsyncGenerator[Frame, None]: - """Internal function to add frames to the pipeline as they're yielded - by each processor.""" - if processors: - self._log_frame(initial_frame, depth) - async for frame in processors[0].process_frame(initial_frame): - async for final_frame in self._run_pipeline_recursively( - frame, processors[1:], depth + 1 - ): - yield final_frame - else: - yield initial_frame - - def _log_frame(self, frame: Frame, depth: int): - """Log a frame as it moves through the pipeline. This is useful for debugging. - Note that this function inherits the logging level from the "dailyai" logger. - If you want debug output from dailyai in general but not this function (it is - noisy) you can silence this function by doing something like this: - - # enable debug logging for the dailyai package. - logger = logging.getLogger("dailyai") - logger.setLevel(logging.DEBUG) - - # silence the pipeline logging - logger = logging.getLogger("dailyai.pipeline") - logger.setLevel(logging.WARNING) - """ - source = str(self._processors[depth - 2]) if depth > 1 else "source" - dest = str(self._processors[depth - 1]) if depth < (len(self._processors) + 1) else "sink" - prefix = self._name + " " * depth - logline = prefix + " -> ".join([source, frame.__class__.__name__, dest]) - if logline == self._last_log_line: - if self._shown_repeated_log: - return - self._shown_repeated_log = True - self._logger.debug(prefix + "... repeated") - else: - self._shown_repeated_log = False - self._last_log_line = logline - self._logger.debug(logline) + async def _maybe_send_start_frame(self): + if self._send_pipeline_frames and not self._start_added: + self._start_added = True + await self._source_queue.put(StartFrame()) + + def _setup_sigint(self): + self._loop.add_signal_handler( + signal.SIGINT, + lambda *args: asyncio.create_task(self._sigint_handler()) + ) + + async def _sigint_handler(self): + await self.stop() + + async def _cleanup_processors(self): + for p in self._processors: + await p.cleanup() + + def _add_processors(self, processors: List[FrameProcessor]): + for p in processors: + p.set_event_loop(self._loop) + self._processors.append(p) + + def _link_processors(self): + if len(self._processors) <= 1: + return + + prev = self._processors[0] + for curr in self._processors[1:]: + prev.link(curr) + prev = curr + + def __str__(self): + return self.name diff --git a/src/dailyai/processors/__init__.py b/src/dailyai/processors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dailyai/processors/frame_processor.py b/src/dailyai/processors/frame_processor.py new file mode 100644 index 000000000..1996bdd95 --- /dev/null +++ b/src/dailyai/processors/frame_processor.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from asyncio import AbstractEventLoop +from enum import Enum +from typing import List + +from dailyai.frames.frames import ControlFrame, Frame, PipelineFrame +from dailyai.utils.utils import obj_count + +from loguru import logger + + +class FrameDirection(Enum): + DOWNSTREAM = 1 + UPSTREAM = 2 + + +class FrameProcessor: + + def __init__(self): + self.id: int = id(self) + self.name = f"{self.__class__.__name__}#{obj_count(self)}" + self._prev: "FrameProcessor" | None = None + self._next: "FrameProcessor" | None = None + self._loop: 
AbstractEventLoop | None = None + + async def cleanup(self): + pass + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, PipelineFrame) or isinstance(frame, ControlFrame): + await self.push_frame(frame, direction) + + def link(self, processor: 'FrameProcessor'): + self._next = processor + processor._prev = self + logger.debug(f"Linking {self} -> {self._next}") + + def event_loop(self) -> AbstractEventLoop: + return self._loop + + def set_event_loop(self, loop: AbstractEventLoop): + self._loop = loop + + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + if direction == FrameDirection.DOWNSTREAM and self._next: + logger.trace(f"Pushing {frame} from {self} to {self._next}") + await self._next.process_frame(frame, direction) + elif direction == FrameDirection.UPSTREAM and self._prev: + logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}") + await self._prev.process_frame(frame, direction) + + def __str__(self): + return self.name diff --git a/src/dailyai/processors/passthrough.py b/src/dailyai/processors/passthrough.py new file mode 100644 index 000000000..0e48f87f6 --- /dev/null +++ b/src/dailyai/processors/passthrough.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from dailyai.frames.frames import Frame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor + +from loguru import logger + + +class Passthrough(FrameProcessor): + + def __init__(self, input_frames: List[str], ** kwargs): + super().__init__(**kwargs) + + self._input_frames = input_frames + + logger.debug(f"Created {self} with passthrough frames {input_frames}") + + # + # Frame processor + # + + def input_frames(self) -> List[str]: + return self._input_frames + + def output_frames(self) -> List[str]: + return self._input_frames + + async def process_frame(self, frame: Frame, direction: FrameDirection): + # Nothing to do, just push frames. + await self.push_frame(frame, direction) diff --git a/src/dailyai/processors/text_transformer.py b/src/dailyai/processors/text_transformer.py new file mode 100644 index 000000000..4dc6c74c8 --- /dev/null +++ b/src/dailyai/processors/text_transformer.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Coroutine + +from dailyai.frames.frames import Frame, TextFrame +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor + + +class StatelessTextTransformer(FrameProcessor): + """This processor calls the given function on any text in a text frame. + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... 
print(frame.text) + + >>> aggregator = StatelessTextTransformer(lambda x: x.upper()) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) + HELLO + """ + + def __init__(self, transform_fn): + super().__init__() + self._transform_fn = transform_fn + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + result = self._transform_fn(frame.data) + if isinstance(result, Coroutine): + result = await result + await self.push_frame(result) + + await super().process_frame(frame, direction) diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py index 5ba732acd..62d32683d 100644 --- a/src/dailyai/services/ai_services.py +++ b/src/dailyai/services/ai_services.py @@ -1,30 +1,23 @@ -import io -import logging -import time -import wave -from dailyai.pipeline.frame_processor import FrameProcessor - -from dailyai.pipeline.frames import ( - AudioFrame, +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import abstractmethod + +from dailyai.frames.frames import ( EndFrame, - EndPipeFrame, - ImageFrame, Frame, - TTSEndFrame, - TTSStartFrame, TextFrame, - TranscriptionFrame, - URLImageFrame, - VisionImageFrame, + VisionImageRawFrame, ) - -from abc import abstractmethod -from typing import AsyncGenerator, BinaryIO +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor class AIService(FrameProcessor): def __init__(self): - self.logger = logging.getLogger("dailyai") + super().__init__() class LLMService(AIService): @@ -35,51 +28,36 @@ def __init__(self): class TTSService(AIService): - def __init__(self, aggregate_sentences=True): + def __init__(self, aggregate_sentences: bool = True): super().__init__() - self.aggregate_sentences: bool = aggregate_sentences - self.current_sentence: str = "" + self._aggregate_sentences: bool = aggregate_sentences + self._current_sentence: str = "" - # Some TTS services require a specific sample rate. We default to 16k - def get_mic_sample_rate(self): - return 16000 - - # Converts the text to audio. Yields a list of audio frames that can - # be sent to the microphone device + # Converts the text to audio. 
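# Illustrative sketch, not part of the patch: exercising the updated
# StatelessTextTransformer above. Note that, as written in this diff, it
# pushes the transform function's return value directly (a plain string)
# rather than a TextFrame, so the hypothetical PrintAll sink below prints
# whatever it receives.
import asyncio

from dailyai.frames.frames import TextFrame
from dailyai.processors.frame_processor import FrameDirection, FrameProcessor
from dailyai.processors.text_transformer import StatelessTextTransformer


class PrintAll(FrameProcessor):

    async def process_frame(self, frame, direction: FrameDirection):
        print(frame.text if isinstance(frame, TextFrame) else frame)
        await super().process_frame(frame, direction)


async def main():
    transformer = StatelessTextTransformer(lambda text: text.upper())
    transformer.link(PrintAll())
    await transformer.process_frame(TextFrame("hello"), FrameDirection.DOWNSTREAM)  # prints "HELLO"


asyncio.run(main())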
@abstractmethod - async def run_tts(self, text) -> AsyncGenerator[bytes, None]: - # yield empty bytes here, so linting can infer what this method does - yield bytes() - - async def wrap_tts(self, text) -> AsyncGenerator[Frame, None]: - yield TTSStartFrame() - async for audio_chunk in self.run_tts(text): - yield AudioFrame(audio_chunk) - yield TTSEndFrame() - yield TextFrame(text) - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame): - if self.current_sentence: - async for cleanup_frame in self.wrap_tts(self.current_sentence): - yield cleanup_frame - - if not isinstance(frame, TextFrame): - yield frame - return + async def run_tts(self, text: str): + pass + async def _process_text_frame(self, frame: TextFrame): text: str | None = None - if not self.aggregate_sentences: - text = frame.text + if not self._aggregate_sentences: + text = frame.data else: - self.current_sentence += frame.text - if self.current_sentence.strip().endswith((".", "?", "!")): - text = self.current_sentence - self.current_sentence = "" + self._current_sentence += frame.data + if self._current_sentence.strip().endswith((".", "?", "!")): + text = self._current_sentence + self._current_sentence = "" if text: - async for frame in self.wrap_tts(text): - yield frame + await self.run_tts(text) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + await self._process_text_frame(frame) + elif isinstance(frame, EndFrame): + if self._current_sentence: + await self.run_tts(self._current_sentence) + await super().process_frame(frame, direction) class ImageGenService(AIService): @@ -88,16 +66,13 @@ def __init__(self, **kwargs): # Renders the image. Returns an Image object. 
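# Illustrative sketch, not part of the patch: how the reworked TTSService
# batches text into sentences before calling run_tts(). EchoTTS is a
# hypothetical subclass that just prints; real services in this diff (e.g. the
# Deepgram service further down) push AudioRawFrames instead.
import asyncio

from dailyai.frames.frames import TextFrame
from dailyai.processors.frame_processor import FrameDirection
from dailyai.services.ai_services import TTSService


class EchoTTS(TTSService):

    async def run_tts(self, text: str):
        print(f"run_tts: {text!r}")


async def main():
    tts = EchoTTS()  # aggregate_sentences=True by default
    await tts.process_frame(TextFrame("Hello,"), FrameDirection.DOWNSTREAM)
    await tts.process_frame(TextFrame(" world."), FrameDirection.DOWNSTREAM)
    # Only the second frame completes a sentence, so run_tts() is called once
    # with "Hello, world."


asyncio.run(main())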
@abstractmethod - async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]]: + async def run_image_gen(self, prompt: str): pass - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if not isinstance(frame, TextFrame): - yield frame - return - - (url, image_data, image_size) = await self.run_image_gen(frame.text) - yield URLImageFrame(url, image_data, image_size) + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, TextFrame): + await self.run_image_gen(frame.data) + await super().process_frame(frame, direction) class VisionService(AIService): @@ -108,58 +83,42 @@ def __init__(self, **kwargs): self._describe_text = None @abstractmethod - async def run_vision(self, frame: VisionImageFrame) -> str: - pass - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, VisionImageFrame): - description = await self.run_vision(frame) - yield TextFrame(description) - else: - yield frame - - -class STTService(AIService): - """STTService is a base class for speech-to-text services.""" - - _frame_rate: int - - def __init__(self, frame_rate: int = 16000, **kwargs): - super().__init__(**kwargs) - self._frame_rate = frame_rate - - @abstractmethod - async def run_stt(self, audio: BinaryIO) -> str: - """Returns transcript as a string""" + async def run_vision(self, frame: VisionImageRawFrame): pass - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - """Processes a frame of audio data, either buffering or transcribing it.""" - if not isinstance(frame, AudioFrame): - return - - data = frame.data - content = io.BufferedRandom(io.BytesIO()) - ww = wave.open(self._content, "wb") - ww.setnchannels(1) - ww.setsampwidth(2) - ww.setframerate(self._frame_rate) - ww.writeframesraw(data) - ww.close() - content.seek(0) - text = await self.run_stt(content) - yield TranscriptionFrame(text, "", str(time.time())) - - -class FrameLogger(AIService): - def __init__(self, prefix="Frame", **kwargs): - super().__init__(**kwargs) - self.prefix = prefix - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, (AudioFrame, ImageFrame)): - self.logger.info(f"{self.prefix}: {type(frame)}") - else: - print(f"{self.prefix}: {frame}") - - yield frame + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, VisionImageRawFrame): + await self.run_vision(frame) + await super().process_frame(frame, direction) + + +# class STTService(AIService): +# """STTService is a base class for speech-to-text services.""" + +# _frame_rate: int + +# def __init__(self, frame_rate: int = 16000, **kwargs): +# super().__init__(**kwargs) +# self._frame_rate = frame_rate + +# @abstractmethod +# async def run_stt(self, audio: BinaryIO) -> str: +# """Returns transcript as a string""" +# pass + +# async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: +# """Processes a frame of audio data, either buffering or transcribing it.""" +# if not isinstance(frame, AudioFrame): +# return + +# data = frame.data +# content = io.BufferedRandom(io.BytesIO()) +# ww = wave.open(self._content, "wb") +# ww.setnchannels(1) +# ww.setsampwidth(2) +# ww.setframerate(self._frame_rate) +# ww.writeframesraw(data) +# ww.close() +# content.seek(0) +# text = await self.run_stt(content) +# yield TranscriptionFrame(text, "", str(time.time())) diff --git a/src/dailyai/services/anthropic.py b/src/dailyai/services/anthropic.py new 
file mode 100644 index 000000000..69af99615 --- /dev/null +++ b/src/dailyai/services/anthropic.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.frames.frames import Frame, LLMMessagesFrame, TextFrame +from dailyai.processors.frame_processor import FrameDirection +from dailyai.services.ai_services import LLMService + +try: + from anthropic import AsyncAnthropic +except ModuleNotFoundError as e: + print(f"Exception: {e}") + print( + "In order to use Anthropic, you need to `pip install dailyai[anthropic]`. Also, set `ANTHROPIC_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class AnthropicLLMService(LLMService): + + def __init__( + self, + api_key, + model="claude-3-opus-20240229", + max_tokens=1024): + super().__init__() + self.client = AsyncAnthropic(api_key=api_key) + self.model = model + self.max_tokens = max_tokens + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, LLMMessagesFrame): + stream = await self.client.messages.create( + max_tokens=self.max_tokens, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model=self.model, + stream=True, + ) + async for event in stream: + if event.type == "content_block_delta": + await self.push_frame(TextFrame(event.delta.text)) + + await super().process_frame(frame, direction) diff --git a/src/dailyai/services/anthropic_llm_service.py b/src/dailyai/services/anthropic_llm_service.py deleted file mode 100644 index 44c045992..000000000 --- a/src/dailyai/services/anthropic_llm_service.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import AsyncGenerator -from dailyai.pipeline.frames import Frame, LLMMessagesFrame, TextFrame - -from dailyai.services.ai_services import LLMService - -try: - from anthropic import AsyncAnthropic -except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( - "In order to use Anthropic, you need to `pip install dailyai[anthropic]`. 
Also, set `ANTHROPIC_API_KEY` environment variable.") - raise Exception(f"Missing module: {e}") - - -class AnthropicLLMService(LLMService): - - def __init__( - self, - api_key, - model="claude-3-opus-20240229", - max_tokens=1024): - super().__init__() - self.client = AsyncAnthropic(api_key=api_key) - self.model = model - self.max_tokens = max_tokens - - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if not isinstance(frame, LLMMessagesFrame): - yield frame - - stream = await self.client.messages.create( - max_tokens=self.max_tokens, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], - model=self.model, - stream=True, - ) - async for event in stream: - if event.type == "content_block_delta": - yield TextFrame(event.delta.text) diff --git a/src/dailyai/services/azure_ai_services.py b/src/dailyai/services/azure.py similarity index 100% rename from src/dailyai/services/azure_ai_services.py rename to src/dailyai/services/azure.py diff --git a/src/dailyai/services/deepgram_ai_services.py b/src/dailyai/services/deepgram.py similarity index 65% rename from src/dailyai/services/deepgram_ai_services.py rename to src/dailyai/services/deepgram.py index c6aaa55d1..e15b66a9e 100644 --- a/src/dailyai/services/deepgram_ai_services.py +++ b/src/dailyai/services/deepgram.py @@ -1,8 +1,17 @@ -from collections.abc import AsyncGenerator +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.frames.frames import AudioRawFrame from dailyai.services.ai_services import TTSService +from loguru import logger + class DeepgramTTSService(TTSService): + def __init__( self, *, @@ -15,15 +24,13 @@ def __init__( self._api_key = api_key self._aiohttp_session = aiohttp_session - def get_mic_sample_rate(self): - return 24000 - - async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]: - self.logger.info(f"Running deepgram tts for {sentence}") + async def run_tts(self, text: str): + logger.info(f"Running Deepgram TTS for {text}") base_url = "https://api.beta.deepgram.com/v1/speak" request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000" headers = {"authorization": f"token {self._api_key}"} - body = {"text": sentence} + body = {"text": text} async with self._aiohttp_session.post(request_url, headers=headers, json=body) as r: async for data in r.content: - yield data + frame = AudioRawFrame(data, 16000, 1) + await self.push_frame(frame) diff --git a/src/dailyai/services/deepgram_ai_service.py b/src/dailyai/services/deepgram_ai_service.py deleted file mode 100644 index 4b552927e..000000000 --- a/src/dailyai/services/deepgram_ai_service.py +++ /dev/null @@ -1,36 +0,0 @@ -import aiohttp - -from dailyai.services.ai_services import TTSService - - -class DeepgramAIService(TTSService): - def __init__( - self, - *, - aiohttp_session: aiohttp.ClientSession, - api_key, - voice, - sample_rate=16000 - ): - super().__init__() - - self._api_key = api_key - self._voice = voice - self._sample_rate = sample_rate - self._aiohttp_session = aiohttp_session - - async def run_tts(self, sentence): - self.logger.info(f"Running deepgram tts for {sentence}") - base_url = "https://api.beta.deepgram.com/v1/speak" - request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate={self._sample_rate}" - headers = { - "authorization": f"token {self._api_key}", - "Content-Type": "application/json"} - data = {"text": sentence} - - async with self._aiohttp_session.post( - request_url, 
headers=headers, json=data - ) as r: - async for chunk in r.content: - if chunk: - yield chunk diff --git a/src/dailyai/services/elevenlabs.py b/src/dailyai/services/elevenlabs.py new file mode 100644 index 000000000..1302d3529 --- /dev/null +++ b/src/dailyai/services/elevenlabs.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp + +from dailyai.frames.frames import AudioRawFrame, TTSStartedFrame, TTSStoppedFrame +from dailyai.services.ai_services import TTSService + +from loguru import logger + + +class ElevenLabsTTSService(TTSService): + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + api_key: str, + voice_id: str, + model: str = "eleven_turbo_v2", + ): + super().__init__() + + self._api_key = api_key + self._voice_id = voice_id + self._aiohttp_session = aiohttp_session + self._model = model + + async def run_tts(self, text: str): + logger.debug(f"Generating TTS: {text}") + + url = f"https://api.elevenlabs.io/v1/text-to-speech/{self._voice_id}/stream" + + payload = {"text": text, "model_id": self._model} + + querystring = { + "output_format": "pcm_16000", + "optimize_streaming_latency": 2} + + headers = { + "xi-api-key": self._api_key, + "Content-Type": "application/json", + } + + async with self._aiohttp_session.post(url, json=payload, headers=headers, params=querystring) as r: + if r.status != 200: + logger.error(f"Audio fetch status code: {r.status}, error: {await r.text()}") + return + + await self.push_frame(TTSStartedFrame()) + async for chunk in r.content: + if len(chunk) > 0: + frame = AudioRawFrame(chunk, 16000, 1) + await self.push_frame(frame) + await self.push_frame(TTSStoppedFrame()) diff --git a/src/dailyai/services/elevenlabs_ai_service.py b/src/dailyai/services/elevenlabs_ai_service.py deleted file mode 100644 index c31d50bed..000000000 --- a/src/dailyai/services/elevenlabs_ai_service.py +++ /dev/null @@ -1,46 +0,0 @@ -import aiohttp - -from typing import AsyncGenerator - -from dailyai.services.ai_services import TTSService - - -class ElevenLabsTTSService(TTSService): - - def __init__( - self, - *, - aiohttp_session: aiohttp.ClientSession, - api_key, - voice_id, - model="eleven_turbo_v2", - ): - super().__init__() - - self._api_key = api_key - self._voice_id = voice_id - self._aiohttp_session = aiohttp_session - self._model = model - - async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]: - url = f"https://api.elevenlabs.io/v1/text-to-speech/{self._voice_id}/stream" - payload = {"text": sentence, "model_id": self._model} - querystring = { - "output_format": "pcm_16000", - "optimize_streaming_latency": 2} - headers = { - "xi-api-key": self._api_key, - "Content-Type": "application/json", - } - async with self._aiohttp_session.post( - url, json=payload, headers=headers, params=querystring - ) as r: - if r.status != 200: - self.logger.error( - f"audio fetch status code: {r.status}, error: {r.text}" - ) - return - - async for chunk in r.content: - if chunk: - yield chunk diff --git a/src/dailyai/services/fal_ai_services.py b/src/dailyai/services/fal.py similarity index 72% rename from src/dailyai/services/fal_ai_services.py rename to src/dailyai/services/fal.py index a924607d2..2063bb376 100644 --- a/src/dailyai/services/fal_ai_services.py +++ b/src/dailyai/services/fal.py @@ -1,14 +1,22 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import aiohttp -import asyncio import io import os + from PIL import Image from pydantic import 
BaseModel from typing import Optional, Union, Dict - +from dailyai.frames.frames import URLImageRawFrame from dailyai.services.ai_services import ImageGenService +from loguru import logger + try: import fal_client except ModuleNotFoundError as e: @@ -33,8 +41,8 @@ def __init__( *, aiohttp_session: aiohttp.ClientSession, params: InputParams, - model="fal-ai/fast-sdxl", - key=None, + model: str = "fal-ai/fast-sdxl", + key: str | None = None, ): super().__init__() self._model = model @@ -43,7 +51,9 @@ def __init__( if key: os.environ["FAL_KEY"] = key - async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]]: + async def run_image_gen(self, prompt: str): + logger.debug(f"Generating image from prompt: {prompt}") + response = await fal_client.run_async( self._model, arguments={"prompt": prompt, **self._params.dict()} @@ -52,10 +62,13 @@ async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]] image_url = response["images"][0]["url"] if response else None if not image_url: - raise Exception("Image generation failed") + logger.error("Image generation failed") + return # Load the image from the url async with self._aiohttp_session.get(image_url) as response: image_stream = io.BytesIO(await response.content.read()) image = Image.open(image_stream) - return (image_url, image.tobytes(), image.size) + + frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format) + await self.push_frame(frame) diff --git a/src/dailyai/services/fireworks_ai_services.py b/src/dailyai/services/fireworks.py similarity index 50% rename from src/dailyai/services/fireworks_ai_services.py rename to src/dailyai/services/fireworks.py index e5ccbc658..9f995c603 100644 --- a/src/dailyai/services/fireworks_ai_services.py +++ b/src/dailyai/services/fireworks.py @@ -1,7 +1,10 @@ -import os - -from dailyai.services.openai_api_llm_service import BaseOpenAILLMService +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# +from dailyai.services.openai import BaseOpenAILLMService try: from openai import AsyncOpenAI @@ -13,6 +16,7 @@ class FireworksLLMService(BaseOpenAILLMService): - def __init__(self, model="accounts/fireworks/models/firefunction-v1", *args, **kwargs): - kwargs["base_url"] = "https://api.fireworks.ai/inference/v1" - super().__init__(model, *args, **kwargs) + def __init__(self, + model="accounts/fireworks/models/firefunction-v1", + base_url="https://api.fireworks.ai/inference/v1"): + super().__init__(model, base_url) diff --git a/src/dailyai/services/live_stream.py b/src/dailyai/services/live_stream.py new file mode 100644 index 000000000..470ffa235 --- /dev/null +++ b/src/dailyai/services/live_stream.py @@ -0,0 +1,323 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import inspect +import itertools +import queue +import threading +import time +import types + +from functools import partial +from asyncio import AbstractEventLoop +from typing import List + +from dailyai.processors.frame_processor import FrameDirection, FrameProcessor +from dailyai.frames.frames import ( + AudioRawFrame, + StartFrame, + EndFrame, + Frame, + ImageRawFrame, + TranscriptionFrame, + InterimTranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) +from dailyai.transports.live_stream_transport import LiveStreamTransport +from dailyai.vad.vad_analyzer import VADState + +from loguru import logger + + +class 
LiveStream(FrameProcessor): + + def __init__(self, transport: LiveStreamTransport, **kwargs): + super().__init__() + + self._camera_enabled = kwargs.get("camera_enabled") or False + self._mic_enabled = kwargs.get("mic_enabled") or False + self._speaker_enabled = kwargs.get("speaker_enabled") or False + + self._event_handlers: dict = {} + self._transport: LiveStreamTransport = transport + self._other_participant_has_joined = False + self._running = False + self._images = None + + if self._transport.video_capture_enabled: + self._camera_in_queue = queue.Queue() + self._camera_in_thread = threading.Thread(target=self._camera_in_thread_handler) + if self._camera_enabled: + self._camera_out_thread = threading.Thread(target=self._camera_out_thread_handler) + if self._speaker_enabled: + self._speaker_thread = threading.Thread(target=self._speaker_thread_handler) + + self._sink_queue = queue.Queue() + self._sink_thread = threading.Thread(target=self._sink_thread_handler) + + transport.set_event_loop(self.event_loop()) + + @transport.event_handler("on_joined") + def on_joined(transport, participant): + self.on_joined(participant) + + @transport.event_handler("on_participant_joined") + def on_participant_joined(transport, participant): + if not self._other_participant_has_joined: + self._other_participant_has_joined = True + self.on_first_participant_joined(participant) + self.on_participant_joined(participant) + + # + # Frame processor + # + + def set_event_loop(self, loop: AbstractEventLoop): + super().set_event_loop(loop) + self._transport.set_event_loop(loop) + + async def cleanup(self): + await self._transport.cleanup() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, StartFrame): + await self._start() + elif isinstance(frame, EndFrame): + await self._stop() + else: + self._sink_queue.put(frame) + + async def _start(self): + self._running = True + if self._transport.video_capture_enabled: + self._camera_in_thread.start() + + if self._camera_enabled: + self._camera_out_thread.start() + + if self._speaker_enabled: + self._speaker_thread.start() + + self._sink_thread.start() + + await self._transport.join() + + async def _stop(self): + self._running = False + if self._transport.video_capture_enabled: + self._camera_in_thread.join() + + if self._camera_enabled: + self._camera_out_thread.join() + + if self._speaker_enabled: + self._speaker_thread.join() + + self._sink_thread.join() + + await self._transport.leave() + + def _sink_thread_handler(self): + buffer = bytearray() + while self._running: + try: + frame = self._sink_queue.get(timeout=1) + if isinstance(frame, AudioRawFrame): + if self._mic_enabled: + buffer = self._process_mic_frame(buffer, frame) + elif isinstance(frame, ImageRawFrame): + if self._camera_enabled: + self._set_image(frame) + except queue.Empty: + pass + except BaseException as e: + logger.error(f"Error processing sink queue: {e}") + self._send_audio_chunk(buffer, 160) + + # + # Transcription + # + + def capture_participant_transcription(self, participant_id: str): + self._transport.capture_participant_transcription( + participant_id, + self._on_transcription_message + ) + + def _on_transcription_message(self, participant_id, message): + text = message["text"] + timestamp = message["timestamp"] + is_final = message["rawResponse"]["is_final"] + if is_final: + frame = TranscriptionFrame(text, participant_id, timestamp) + else: + frame = InterimTranscriptionFrame(text, participant_id, timestamp) + 
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.event_loop()) + + # + # Camera in/out. + # + + def capture_participant_video( + self, + participant_id: str, + framerate: int = 30, + video_source: str = "camera", + color_format: str = "RGB"): + self._transport.capture_participant_video( + participant_id, + self._on_participant_video_frame, + framerate, + video_source, + color_format + ) + + def _on_participant_video_frame(self, participant_id, buffer, size, format): + frame = ImageRawFrame(buffer, size, format) + frame.metadata["user_id"] = participant_id + self._camera_in_queue.put(frame) + + def _camera_in_thread_handler(self): + while self._running: + try: + frame = self._camera_in_queue.get(timeout=1) + asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.event_loop()) + except queue.Empty: + pass + except BaseException as e: + logger.error(f"Error capturing video: {e}") + + def _set_image(self, image: ImageRawFrame): + self._images = itertools.cycle([image]) + + def _camera_out_thread_handler(self): + while self._running: + try: + if self._images: + image = next(self._images) + self._transport.write_frame_to_camera(image) + time.sleep(1.0 / self._transport.camera_framerate) + except Exception as e: + logger.error(f"Error writing to camera: {e}") + + # + # Microphone + # + + # Subdivide large audio frames to enable interruption + def _maybe_split_audio_frame(self, frame: AudioRawFrame) -> List[AudioRawFrame]: + largest_write_size = 8000 + frames: List[AudioRawFrame] = [] + if len(frame.data) > largest_write_size: + for i in range(0, len(frame.data), largest_write_size): + chunk = frame.data[i: i + largest_write_size] + frames.append(AudioRawFrame(chunk, frame.sample_rate, frame.num_channels)) + else: + frames.append(frame) + return frames + + def _send_audio_chunk(self, buffer: bytearray, write_size: int) -> bytearray: + truncated_length: int = len(buffer) - (len(buffer) % write_size) + if truncated_length: + self._transport.write_raw_audio_frames(bytes(buffer[:truncated_length])) + buffer = buffer[truncated_length:] + return buffer + + def _process_mic_frame(self, buffer: bytearray, frame: AudioRawFrame) -> bytearray: + try: + frames = self._maybe_split_audio_frame(frame) + for frame in frames: + buffer.extend(frame.data) + buffer = self._send_audio_chunk(buffer, 3200) + return buffer + except BaseException as e: + logger.error(f"Error writing to microphone: {e}") + return buffer + + # + # Speaker + # + + def _speaker_thread_handler(self): + sample_rate = self._transport.speaker_sample_rate + num_channels = self._transport.speaker_channels + num_frames = 160 + + vad_state: VADState = VADState.QUIET + while self._running: + try: + audio_frames = self._transport.read_raw_audio_frames(num_frames) + + # Check VAD and push event if necessary. We just care about changes + # from QUIET to SPEAKING and vice versa. + new_vad_state = self._transport.vad_analyze(audio_frames) + if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING: + frame = None + if new_vad_state == VADState.SPEAKING: + frame = UserStartedSpeakingFrame() + elif new_vad_state == VADState.QUIET: + frame = UserStoppedSpeakingFrame() + if frame: + asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.event_loop()) + vad_state = new_vad_state + + # Always push audio downstream if not empty. 
+ if len(audio_frames) > 0: + frame = AudioRawFrame(audio_frames, sample_rate, num_channels) + asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.event_loop()) + except BaseException as e: + logger.error(f"Error reading from speaker: {e}") + + # + # Decorators (event handlers) + # + + def on_joined(self, participant): + pass + + def on_participant_joined(self, participant): + pass + + def on_first_participant_joined(self, participant): + pass + + def event_handler(self, event_name: str): + def decorator(handler): + self._add_event_handler(event_name, handler) + return handler + return decorator + + def _add_event_handler(self, event_name: str, handler): + methods = inspect.getmembers(self, predicate=inspect.ismethod) + if event_name not in [method[0] for method in methods]: + raise Exception(f"Event handler {event_name} not found") + + if event_name not in self._event_handlers: + self._event_handlers[event_name] = [getattr(self, event_name)] + patch_method = types.MethodType(partial(self._patch_method, event_name), self) + setattr(self, event_name, patch_method) + self._event_handlers[event_name].append(types.MethodType(handler, self)) + + def _patch_method(self, event_name, *args, **kwargs): + try: + for handler in self._event_handlers[event_name]: + if inspect.iscoroutinefunction(handler): + if self.event_loop(): + future = asyncio.run_coroutine_threadsafe( + handler(*args[1:], **kwargs), self.event_loop()) + + # wait for the coroutine to finish. This will also + # raise any exceptions raised by the coroutine. + future.result() + else: + raise Exception( + "No event loop to run coroutine. In order to use async event handlers, you must run the DailyTransportService in an asyncio event loop.") + else: + handler(*args[1:], **kwargs) + except Exception as e: + # TODO(aleix) self._logger.error(f"Exception in event handler {event_name}: {e}") + raise e diff --git a/src/dailyai/services/moondream_ai_service.py b/src/dailyai/services/moondream.py similarity index 68% rename from src/dailyai/services/moondream_ai_service.py rename to src/dailyai/services/moondream.py index 704d4c51b..4a72610c5 100644 --- a/src/dailyai/services/moondream_ai_service.py +++ b/src/dailyai/services/moondream.py @@ -1,13 +1,26 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio -from dailyai.pipeline.frames import ImageFrame, VisionImageFrame +from dailyai.frames.frames import VisionImageRawFrame from dailyai.services.ai_services import VisionService from PIL import Image -from transformers import AutoModelForCausalLM, AutoTokenizer +from loguru import logger + +try: + import torch -import torch + from transformers import AutoModelForCausalLM, AutoTokenizer +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Moondream, you need to `pip install dailyai[moondream]`.") + raise Exception(f"Missing module(s): {e}") def detect_device(): @@ -44,9 +57,9 @@ def __init__( ).to(device=device, dtype=dtype) self._model.eval() - async def run_vision(self, frame: VisionImageFrame) -> str: - def get_image_description(frame: VisionImageFrame): - image = Image.frombytes("RGB", (frame.size[0], frame.size[1]), frame.image) + async def run_vision(self, frame: VisionImageRawFrame): + def get_image_description(frame: VisionImageRawFrame): + image = Image.frombytes(frame.format, (frame.size[0], frame.size[1]), frame.data) image_embeds = self._model.encode_image(image) description = self._model.answer_question( 
image_embeds=image_embeds, diff --git a/src/dailyai/services/ollama_ai_services.py b/src/dailyai/services/ollama.py similarity index 59% rename from src/dailyai/services/ollama_ai_services.py rename to src/dailyai/services/ollama.py index adb69c7d6..876c547b7 100644 --- a/src/dailyai/services/ollama_ai_services.py +++ b/src/dailyai/services/ollama.py @@ -1,4 +1,10 @@ -from dailyai.services.openai_api_llm_service import BaseOpenAILLMService +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dailyai.services.openai import BaseOpenAILLMService class OLLamaLLMService(BaseOpenAILLMService): diff --git a/src/dailyai/services/open_ai_services.py b/src/dailyai/services/open_ai_services.py deleted file mode 100644 index 9eaec5218..000000000 --- a/src/dailyai/services/open_ai_services.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Literal -import aiohttp -from PIL import Image -import io - -from dailyai.services.ai_services import ImageGenService -from dailyai.services.openai_api_llm_service import BaseOpenAILLMService - - -try: - from openai import AsyncOpenAI -except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( - "In order to use OpenAI, you need to `pip install dailyai[openai]`. Also, set `OPENAI_API_KEY` environment variable.") - raise Exception(f"Missing module: {e}") - - -class OpenAILLMService(BaseOpenAILLMService): - - def __init__(self, model="gpt-4", * args, **kwargs): - super().__init__(model, *args, **kwargs) - - -class OpenAIImageGenService(ImageGenService): - - def __init__( - self, - *, - image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], - aiohttp_session: aiohttp.ClientSession, - api_key, - model="dall-e-3", - ): - super().__init__() - self._model = model - self._image_size = image_size - self._client = AsyncOpenAI(api_key=api_key) - self._aiohttp_session = aiohttp_session - - async def run_image_gen(self, prompt: str) -> tuple[str, bytes, tuple[int, int]]: - self.logger.info("Generating OpenAI image", prompt) - - image = await self._client.images.generate( - prompt=prompt, - model=self._model, - n=1, - size=self._image_size - ) - image_url = image.data[0].url - if not image_url: - raise Exception("No image provided in response", image) - - # Load the image from the url - async with self._aiohttp_session.get(image_url) as response: - image_stream = io.BytesIO(await response.content.read()) - image = Image.open(image_stream) - return (image_url, image.tobytes(), image.size) diff --git a/src/dailyai/services/openai_api_llm_service.py b/src/dailyai/services/openai.py similarity index 60% rename from src/dailyai/services/openai_api_llm_service.py rename to src/dailyai/services/openai.py index 2ddfc7796..10b3a0e06 100644 --- a/src/dailyai/services/openai_api_llm_service.py +++ b/src/dailyai/services/openai.py @@ -1,18 +1,25 @@ +import io import json import time -from typing import AsyncGenerator, List -from dailyai.pipeline.frames import ( +import aiohttp +from PIL import Image + +from typing import List, Literal + +from dailyai.frames.frames import ( Frame, - LLMFunctionCallFrame, - LLMFunctionStartFrame, LLMMessagesFrame, LLMResponseEndFrame, LLMResponseStartFrame, TextFrame, + URLImageRawFrame ) -from dailyai.services.ai_services import LLMService -from dailyai.pipeline.openai_frames import OpenAILLMContextFrame -from dailyai.services.openai_llm_context import OpenAILLMContext +from dailyai.frames.openai_frames import OpenAILLMContextFrame +from dailyai.processors.frame_processor import 
FrameDirection +from dailyai.services.ai_services import LLMService, ImageGenService +from dailyai.aggregators.openai_llm_context import OpenAILLMContext + +from loguru import logger try: from openai import AsyncOpenAI, AsyncStream @@ -23,8 +30,8 @@ ChatCompletionMessageParam, ) except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( + logger.error(f"Exception: {e}") + logger.error( "In order to use OpenAI, you need to `pip install dailyai[openai]`. Also, set `OPENAI_API_KEY` environment variable.") raise Exception(f"Missing module: {e}") @@ -52,7 +59,7 @@ async def _stream_chat_completions( ) -> AsyncStream[ChatCompletionChunk]: messages: List[ChatCompletionMessageParam] = context.get_messages() messages_for_log = json.dumps(messages) - self.logger.debug(f"Generating chat via openai: {messages_for_log}") + logger.debug(f"Generating chat: {messages_for_log}") start_time = time.time() chunks: AsyncStream[ChatCompletionChunk] = ( @@ -64,12 +71,15 @@ async def _stream_chat_completions( tool_choice=context.tool_choice, ) ) - self.logger.info(f"=== OpenAI LLM TTFB: {time.time() - start_time}") + + logger.debug(f"OpenAI LLM TTFB: {time.time() - start_time}") + return chunks async def _chat_completions(self, messages) -> str | None: messages_for_log = json.dumps(messages) - self.logger.debug(f"Generating chat via openai: {messages_for_log}") + + logger.debug(f"Generating chat: {messages_for_log}") response: ChatCompletion = await self._client.chat.completions.create( model=self._model, stream=False, messages=messages @@ -79,22 +89,16 @@ async def _chat_completions(self, messages) -> str | None: else: return None - async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: - if isinstance(frame, OpenAILLMContextFrame): - context: OpenAILLMContext = frame.context - elif isinstance(frame, LLMMessagesFrame): - context = OpenAILLMContext.from_messages(frame.messages) - else: - yield frame - return - + async def _process_context(self, context: OpenAILLMContext): function_name = "" arguments = "" - yield LLMResponseStartFrame() + await self.push_frame(LLMResponseStartFrame()) + chunk_stream: AsyncStream[ChatCompletionChunk] = ( await self._stream_chat_completions(context) ) + async for chunk in chunk_stream: if len(chunk.choices) == 0: continue @@ -114,18 +118,75 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: tool_call = chunk.choices[0].delta.tool_calls[0] if tool_call.function and tool_call.function.name: function_name += tool_call.function.name - yield LLMFunctionStartFrame(function_name=tool_call.function.name) + # yield LLMFunctionStartFrame(function_name=tool_call.function.name) if tool_call.function and tool_call.function.arguments: # Keep iterating through the response to collect all the argument fragments and # yield a complete LLMFunctionCallFrame after run_llm_async # completes arguments += tool_call.function.arguments elif chunk.choices[0].delta.content: - yield TextFrame(chunk.choices[0].delta.content) + await self.push_frame(TextFrame(chunk.choices[0].delta.content)) # if we got a function name and arguments, yield the frame with all the info so # frame consumers can take action based on the function call. 
- if function_name and arguments: - yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments) + # if function_name and arguments: + # yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments) + + await self.push_frame(LLMResponseEndFrame()) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + context = None + if isinstance(frame, OpenAILLMContextFrame): + context: OpenAILLMContext = frame.data + elif isinstance(frame, LLMMessagesFrame): + context = OpenAILLMContext.from_messages(frame.data) + + if context: + await self._process_context(context) + + await super().process_frame(frame, direction) + + +class OpenAILLMService(BaseOpenAILLMService): + + def __init__(self, model="gpt-4", * args, **kwargs): + super().__init__(model, *args, **kwargs) + + +class OpenAIImageGenService(ImageGenService): + + def __init__( + self, + *, + image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + aiohttp_session: aiohttp.ClientSession, + api_key: str, + model: str = "dall-e-3", + ): + super().__init__() + self._model = model + self._image_size = image_size + self._client = AsyncOpenAI(api_key=api_key) + self._aiohttp_session = aiohttp_session + + async def run_image_gen(self, prompt: str): + logger.debug(f"Generating image from prompt: {prompt}") + + image = await self._client.images.generate( + prompt=prompt, + model=self._model, + n=1, + size=self._image_size + ) + + image_url = image.data[0].url + + if not image_url: + logger.error(f"no image provided in response: {image}") - yield LLMResponseEndFrame() + # Load the image from the url + async with self._aiohttp_session.get(image_url) as response: + image_stream = io.BytesIO(await response.content.read()) + image = Image.open(image_stream) + frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format) + await self.push_frame(frame) diff --git a/src/dailyai/services/openai_llm_context.py b/src/dailyai/services/openai_llm_context.py deleted file mode 100644 index 2d16c3cb6..000000000 --- a/src/dailyai/services/openai_llm_context.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List - -try: - from openai._types import NOT_GIVEN, NotGiven - - from openai.types.chat import ( - ChatCompletionToolParam, - ChatCompletionToolChoiceOptionParam, - ChatCompletionMessageParam, - ) -except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( - "In order to use OpenAI, you need to `pip install dailyai[openai]`. 
Also, set `OPENAI_API_KEY` environment variable.") - raise Exception(f"Missing module: {e}") - - -class OpenAILLMContext: - - def __init__( - self, - messages: List[ChatCompletionMessageParam] | None = None, - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN - ): - self.messages: List[ChatCompletionMessageParam] = messages if messages else [ - ] - self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice - self.tools: List[ChatCompletionToolParam] | NotGiven = tools - - @staticmethod - def from_messages(messages: List[dict]) -> "OpenAILLMContext": - context = OpenAILLMContext() - for message in messages: - context.add_message({ - "content": message["content"], - "role": message["role"], - "name": message["name"] if "name" in message else message["role"] - }) - return context - - # def __deepcopy__(self, memo): - - def add_message(self, message: ChatCompletionMessageParam): - self.messages.append(message) - - def get_messages(self) -> List[ChatCompletionMessageParam]: - return self.messages - - def set_tool_choice( - self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven - ): - self.tool_choice = tool_choice - - def set_tools( - self, - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN): - if tools != NOT_GIVEN and len(tools) == 0: - tools = NOT_GIVEN - - self.tools = tools diff --git a/src/dailyai/services/playht_ai_service.py b/src/dailyai/services/playht.py similarity index 65% rename from src/dailyai/services/playht_ai_service.py rename to src/dailyai/services/playht.py index 291855264..fea2e6e2a 100644 --- a/src/dailyai/services/playht_ai_service.py +++ b/src/dailyai/services/playht.py @@ -1,8 +1,17 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import io import struct +from dailyai.frames.frames import AudioRawFrame from dailyai.services.ai_services import TTSService +from loguru import logger + try: from pyht import Client from pyht.client import TTSOptions @@ -16,35 +25,29 @@ class PlayHTAIService(TTSService): - def __init__( - self, - *, - api_key, - user_id, - voice_url - ): + def __init__(self, *, api_key, user_id, voice_url): super().__init__() - self.speech_key = api_key - self.user_id = user_id + self._user_id = user_id + self._speech_key = api_key - self.client = Client( - user_id=self.user_id, - api_key=self.speech_key, + self._client = Client( + user_id=self._user_id, + api_key=self._speech_key, ) - self.options = TTSOptions( + self._options = TTSOptions( voice=voice_url, sample_rate=16000, quality="higher", format=Format.FORMAT_WAV) def __del__(self): - self.client.close() + self._client.close() - async def run_tts(self, sentence): + async def run_tts(self, text: str): b = bytearray() in_header = True - for chunk in self.client.tts(sentence, self.options): + for chunk in self._client.tts(text, self._options): # skip the RIFF header. 
if in_header: b.extend(chunk) @@ -54,15 +57,16 @@ async def run_tts(self, sentence): fh = io.BytesIO(b) fh.seek(36) (data, size) = struct.unpack('<4sI', fh.read(8)) - self.logger.info( + logger.debug( f"first attempt: data: {data}, size: {hex(size)}, position: {fh.tell()}") while data != b'data': fh.read(size) (data, size) = struct.unpack('<4sI', fh.read(8)) - self.logger.info( + logger.debug( f"subsequent data: {data}, size: {hex(size)}, position: {fh.tell()}, data != data: {data != b'data'}") - self.logger.info("position: ", fh.tell()) + logger.debug("position: ", fh.tell()) in_header = False else: if len(chunk): - yield chunk + frame = AudioRawFrame(chunk, 16000, 1) + await self.push_frame(frame) diff --git a/src/dailyai/services/whisper_ai_services.py b/src/dailyai/services/whisper.py similarity index 100% rename from src/dailyai/services/whisper_ai_services.py rename to src/dailyai/services/whisper.py diff --git a/src/dailyai/transports/abstract_transport.py b/src/dailyai/transports/abstract_transport.py deleted file mode 100644 index 1a30c9063..000000000 --- a/src/dailyai/transports/abstract_transport.py +++ /dev/null @@ -1,42 +0,0 @@ -from abc import abstractmethod -import asyncio -import logging -import time - -from dailyai.pipeline.frame_processor import FrameProcessor -from dailyai.pipeline.pipeline import Pipeline - - -class AbstractTransport: - def __init__(self, **kwargs): - self.send_queue = asyncio.Queue() - self.receive_queue = asyncio.Queue() - self.completed_queue = asyncio.Queue() - - duration_minutes = kwargs.get("duration_minutes") or 10 - self._expiration = time.time() + duration_minutes * 60 - - self._mic_enabled = kwargs.get("mic_enabled") or False - self._mic_sample_rate = kwargs.get("mic_sample_rate") or 16000 - self._camera_enabled = kwargs.get("camera_enabled") or False - self._camera_width = kwargs.get("camera_width") or 1024 - self._camera_height = kwargs.get("camera_height") or 768 - self._camera_bitrate = kwargs.get("camera_bitrate") or 250000 - self._camera_framerate = kwargs.get("camera_framerate") or 10 - self._speaker_enabled = kwargs.get("speaker_enabled") or False - self._speaker_sample_rate = kwargs.get("speaker_sample_rate") or 16000 - - self._logger: logging.Logger = logging.getLogger("dailyai.transport") - - @abstractmethod - async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True): - pass - - @abstractmethod - async def run_interruptible_pipeline( - self, - pipeline: Pipeline, - pre_processor: FrameProcessor | None = None, - post_processor: FrameProcessor | None = None, - ): - pass diff --git a/src/dailyai/transports/daily_transport.py b/src/dailyai/transports/daily_transport.py index 86c0ee1cf..4e3064f18 100644 --- a/src/dailyai/transports/daily_transport.py +++ b/src/dailyai/transports/daily_transport.py @@ -1,89 +1,69 @@ -import asyncio -import inspect -import logging -import signal +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import time -import threading -import types -from functools import partial -from typing import Any +from typing import Any, Callable, Mapping + +from daily import VirtualCameraDevice, VirtualSpeakerDevice, VirtualMicrophoneDevice -from dailyai.pipeline.frames import ( - InterimTranscriptionFrame, - ReceivedAppMessageFrame, - TranscriptionFrame, - UserImageFrame, -) +from dailyai.frames.frames import ImageRawFrame +from dailyai.transports.live_stream_transport import LiveStreamTransport +from dailyai.vad.vad_analyzer import VADAnalyzer -from threading import 
Event +from loguru import logger try: - from daily import ( - EventHandler, - CallClient, - Daily, - VirtualCameraDevice, - VirtualMicrophoneDevice, - VirtualSpeakerDevice, - ) + from daily import (EventHandler, CallClient, Daily) except ModuleNotFoundError as e: - print(f"Exception: {e}") - print( - "In order to use the Daily transport, you need to `pip install dailyai[daily]`.") + logger.error(f"Exception: {e}") + logger.error("In order to use the Daily transport, you need to `pip install dailyai[daily]`.") raise Exception(f"Missing module: {e}") +VAD_RESET_PERIOD_MS = 2000 -from dailyai.transports.threaded_transport import ThreadedTransport -NUM_CHANNELS = 1 +class WebRTCVADAnalyzer(VADAnalyzer): -SPEECH_THRESHOLD = 0.90 -VAD_RESET_PERIOD_MS = 2000 + def __init__(self, sample_rate=16000, num_channels=1, **kwargs): + super().__init__(sample_rate, num_channels, **kwargs) + self._webrtc_vad = Daily.create_native_vad( + reset_period_ms=VAD_RESET_PERIOD_MS, + sample_rate=sample_rate, + channels=num_channels + ) + logger.info("Using WebRTC VAD") -class DailyTransport(ThreadedTransport, EventHandler): - _daily_initialized = False - _lock = threading.Lock() + def num_frames_required(self) -> int: + return int(self.sample_rate / 100.0) - _speaker_enabled: bool - _speaker_sample_rate: int - _vad_enabled: bool + def voice_confidence(self, buffer) -> float: + confidence = 0 + if len(buffer) > 0: + confidence = self._webrtc_vad.analyze_frames(buffer) + return confidence + + +class DailyTransport(LiveStreamTransport, EventHandler): + + _daily_initialized: bool = False # This is necessary to override EventHandler's __new__ method. def __new__(cls, *args, **kwargs): return super().__new__(cls) - def __init__( - self, - room_url: str, - token: str | None, - bot_name: str, - min_others_count: int = 1, - start_transcription: bool = False, - video_rendering_enabled: bool = False, - **kwargs, - ): - kwargs['has_webrtc_vad'] = True - # This will call ThreadedTransport.__init__ method, not EventHandler + def __init__(self, room_url: str, token: str | None, bot_name: str, **kwargs): super().__init__(**kwargs) - self._room_url: str = room_url - self._bot_name: str = bot_name - self._token: str | None = token - self._min_others_count = min_others_count - self._start_transcription = start_transcription - self._video_rendering_enabled = video_rendering_enabled - - self._is_interrupted = Event() - self._stop_threads = Event() - - self._other_participant_has_joined = False - self._my_participant_id = None - - self._video_renderers = {} + if not self._daily_initialized: + self._daily_initialized = True + Daily.init() - self.transcription_settings = { + self._transcription_settings = kwargs.get("transcription_settings") or { "language": "en", "tier": "nova", "model": "2-conversationalai", @@ -97,122 +77,69 @@ def __init__( }, } - self._logger: logging.Logger = logging.getLogger("dailyai") + self._room_url: str = room_url + self._token: str | None = token + self._bot_name: str = bot_name - self._event_handlers = {} + self._participant_id: str = "" + self._video_renderers = {} + self._transcription_renderers = {} - self.webrtc_vad = Daily.create_native_vad( - reset_period_ms=VAD_RESET_PERIOD_MS, - sample_rate=self._speaker_sample_rate, - channels=NUM_CHANNELS - ) + self._client: CallClient = CallClient(event_handler=self) - def _patch_method(self, event_name, *args, **kwargs): - try: - for handler in self._event_handlers[event_name]: - if inspect.iscoroutinefunction(handler): - if self._loop: - future = 
asyncio.run_coroutine_threadsafe( - handler(*args, **kwargs), self._loop) - - # wait for the coroutine to finish. This will also - # raise any exceptions raised by the coroutine. - future.result() - else: - raise Exception( - "No event loop to run coroutine. In order to use async event handlers, you must run the DailyTransportService in an asyncio event loop.") - else: - handler(*args, **kwargs) - except Exception as e: - self._logger.error(f"Exception in event handler {event_name}: {e}") - raise e - - def _webrtc_vad_analyze(self): - buffer = self.read_audio_frames(int(self._vad_samples)) - if len(buffer) > 0: - confidence = self.webrtc_vad.analyze_frames(buffer) - # yeses = int(confidence * 20.0) - # nos = 20 - yeses - # out = "!" * yeses + "." * nos - # print(f"!!! confidence: {out} {confidence}") - talking = confidence > SPEECH_THRESHOLD - return talking - - def add_event_handler(self, event_name: str, handler): - if not event_name.startswith("on_"): - raise Exception( - f"Event handler {event_name} must start with 'on_'") - - methods = inspect.getmembers(self, predicate=inspect.ismethod) - if event_name not in [method[0] for method in methods]: - raise Exception(f"Event handler {event_name} not found") - - if event_name not in self._event_handlers: - self._event_handlers[event_name] = [ - getattr( - self, event_name), types.MethodType( - handler, self)] - setattr(self, event_name, partial(self._patch_method, event_name)) - else: - self._event_handlers[event_name].append( - types.MethodType(handler, self)) - - def event_handler(self, event_name: str): - def decorator(handler): - self.add_event_handler(event_name, handler) - return handler - - return decorator - - def write_frame_to_camera(self, frame: bytes): - if self._camera_enabled: - self.camera.write_frame(frame) - - def write_frame_to_mic(self, frame: bytes): - if self._mic_enabled: - self.mic.write_frames(frame) - - def request_participant_image(self, participant_id: str): - if participant_id in self._video_renderers: - self._video_renderers[participant_id]["render_next_frame"] = True - - def send_app_message(self, message: Any, participant_id: str | None): - self.client.send_app_message(message, participant_id) - - def read_audio_frames(self, desired_frame_count): - bytes = b"" - if self._speaker_enabled or self._vad_enabled: - bytes = self._speaker.read_frames(desired_frame_count) - return bytes - - def _prerun(self): - # Only initialize Daily once - if not DailyTransport._daily_initialized: - with DailyTransport._lock: - Daily.init() - DailyTransport._daily_initialized = True - self.client = CallClient(event_handler=self) - - if self._mic_enabled: - self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device( - "mic", sample_rate=self._mic_sample_rate, channels=1 - ) + if self.camera_enabled: + self._camera: VirtualCameraDevice = Daily.create_camera_device( + "camera", width=self.camera_width, height=self.camera_height, color_format="RGB") - if self._camera_enabled: - self.camera: VirtualCameraDevice = Daily.create_camera_device( - "camera", width=self._camera_width, height=self._camera_height, color_format="RGB") + if self.mic_enabled: + self._mic: VirtualMicrophoneDevice = Daily.create_microphone_device( + "mic", sample_rate=self.mic_sample_rate, channels=self.mic_channels + ) - if self._speaker_enabled or self._vad_enabled: + if self.speaker_enabled: self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device( - "speaker", sample_rate=self._speaker_sample_rate, channels=1 - ) + "speaker", 
sample_rate=self.speaker_sample_rate, channels=self.speaker_channels) Daily.select_speaker_device("speaker") - self.client.set_user_name(self._bot_name) - self.client.join( + if self.vad_enabled and not self.vad_analyzer: + self.vad_analyzer = WebRTCVADAnalyzer( + sample_rate=self.speaker_sample_rate, + num_channels=self.speaker_channels) + + self._joined = False + + # + # LiveStreamTransport + # + + @property + def participant_id(self) -> str: + return self._participant_id + + async def join(self): + # Transport already joined, ignore. + if self._joined: + return + + logger.info(f"Joining {self._room_url}") + + self._joined = True + + # For performance reasons, never subscribe to video streams (unless a + # video renderer is registered). + self._client.update_subscription_profiles({ + "base": { + "camera": "unsubscribed", + "screenVideo": "unsubscribed" + } + }) + + self._client.set_user_name(self._bot_name) + + self._client.join( self._room_url, self._token, - completion=self.call_joined, + completion=self._call_joined, client_settings={ "inputs": { "camera": { @@ -239,66 +166,54 @@ def _prerun(self): "maxQuality": "low", "encodings": { "low": { - "maxBitrate": self._camera_bitrate, + "maxBitrate": self.camera_bitrate, "scaleResolutionDownBy": 1.333, - "maxFramerate": self._camera_framerate, + "maxFramerate": self.camera_framerate, } }, } } }, - }, - ) - self._my_participant_id = self.client.participants()["local"]["id"] - - # For performance reasons, never subscribe to video streams (unless a - # video renderer is registered). - self.client.update_subscription_profiles({ - "base": { - "camera": "unsubscribed", - "screenVideo": "unsubscribed" - } - }) - - if self._token and self._start_transcription: - self.client.start_transcription(self.transcription_settings) + }) - self.original_sigint_handler = signal.getsignal(signal.SIGINT) - signal.signal(signal.SIGINT, self.process_interrupt_handler) + async def leave(self): + # Transport not joined, ignore. 
+ if not self._joined: + return - def process_interrupt_handler(self, signum, frame): - self._post_run() - if callable(self.original_sigint_handler): - self.original_sigint_handler(signum, frame) + self._joined = False - def _post_run(self): - self.client.leave() - self.client.release() + if self._transcription_enabled: + self._client.stop_transcription() - def on_first_other_participant_joined(self, participant): - pass + logger.info(f"Leaving {self._room_url}") + self._client.leave(completion=self._call_left) - def call_joined(self, join_data, client_error): - # self._logger.info(f"Call_joined: {join_data}, {client_error}") - pass + async def cleanup(self): + if self._client: + self._client.release() + self._client = None - def dialout(self, number): - self.client.start_dialout({"phoneNumber": number}) + def capture_participant_transcription(self, participant_id: str, callback: Callable): + if not self.transcription_enabled: + return - def start_recording(self): - self.client.start_recording() + self._transcription_renderers[participant_id] = { + "callback": callback + } - def render_participant_video(self, - participant_id, - framerate=10, - video_source="camera", - color_format="RGB") -> None: - if not self._video_rendering_enabled: - self._logger.warn("Video rendering is not enabled") + def capture_participant_video( + self, + participant_id: str, + callback: Callable, + framerate: int = 30, + video_source: str = "camera", + color_format: str = "RGB"): + if not self.video_capture_enabled: return # Only enable camera subscription on this participant - self.client.update_subscriptions(participant_settings={ + self._client.update_subscriptions(participant_settings={ participant_id: { "media": { video_source: "subscribed" @@ -310,81 +225,452 @@ def render_participant_video(self, "framerate": framerate, "timestamp": 0, "render_next_frame": False, + "callback": callback } - self.client.set_video_renderer( + + self._client.set_video_renderer( participant_id, - self.on_participant_video_frame, + self._video_frame_received, video_source=video_source, color_format=color_format) - def on_participant_video_frame(self, participant_id, video_frame): - if not self._loop: + def read_raw_audio_frames(self, frame_count: int) -> bytes: + result = b"" + if self.speaker_enabled: + result = self._speaker.read_frames(frame_count) + return result + + def write_raw_audio_frames(self, frames: bytes) -> int: + written = 0 + if self.mic_enabled: + # TODO(aleix). Find a better way. It might be the microhpone is not + # ready yet, so wait until something is written. + written = self._mic.write_frames(frames) + while written == 0: + written = self._mic.write_frames(frames) + return written + + def write_frame_to_camera(self, frame: ImageRawFrame): + if self.camera_enabled: + self._camera.write_frame(frame.data) + + # + # Daily (EventHandler) + # + + def on_participant_joined(self, participant): + id = participant["id"] + logger.info(f"Participant joined {id}") + # NOTE(aleix): It's a bit confusing but other event handlers will be + # called next if registered. 
+ + def on_transcription_message(self, message: Mapping[str, Any]): + participant_id = "" + if "participantId" in message: + participant_id = message["participantId"] + + if participant_id in self._transcription_renderers: + callback = self._transcription_renderers[participant_id]["callback"] + callback(participant_id, message) + + def on_transcription_error(self, message): + logger.error(f"Transcription error: {message}") + + def on_transcription_started(self, status): + logger.info(f"Transcription started: {status}") + + def on_transcription_stopped(self, stopped_by, stopped_by_error): + logger.info("Transcription stopped") + + # + # Daily (callbacks) + # + + def _call_joined(self, data, error): + if error: + logger.error(f"Error joining {self._room_url}: {error}") return + logger.info(f"Joined {self._room_url}") + + if self._token and self._transcription_enabled: + self._client.start_transcription(self._transcription_settings) + + self.on_joined(data["participants"]["local"]) + + def _call_left(self, error): + if error: + logger.error(f"Error leaving {self._room_url}: {error}") + return + + logger.info(f"Left {self._room_url}") + + def _video_frame_received(self, participant_id, video_frame): render_frame = False curr_time = time.time() + prev_time = self._video_renderers[participant_id]["timestamp"] or curr_time framerate = self._video_renderers[participant_id]["framerate"] if framerate > 0: - prev_time = self._video_renderers[participant_id]["timestamp"] next_time = prev_time + 1 / framerate - render_frame = curr_time > next_time + render_frame = (curr_time - next_time) < 0.1 elif self._video_renderers[participant_id]["render_next_frame"]: self._video_renderers[participant_id]["render_next_frame"] = False render_frame = True if render_frame: - frame = UserImageFrame(participant_id, video_frame.buffer, - (video_frame.width, video_frame.height)) - asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self._loop) + callback = self._video_renderers[participant_id]["callback"] + callback(participant_id, + video_frame.buffer, + (video_frame.width, video_frame.height), + video_frame.color_format) self._video_renderers[participant_id]["timestamp"] = curr_time - def on_error(self, error): - self._logger.error(f"on_error: {error}") - - def on_call_state_updated(self, state): - pass - - def on_participant_joined(self, participant): - if not self._other_participant_has_joined and participant["id"] != self._my_participant_id: - self._other_participant_has_joined = True - self.on_first_other_participant_joined(participant) - - def on_participant_left(self, participant, reason): - if len(self.client.participants()) < self._min_others_count + 1: - self._stop_threads.set() - - def on_app_message(self, message: Any, sender: str): - if self._loop: - frame = ReceivedAppMessageFrame(message, sender) - asyncio.run_coroutine_threadsafe( - self.receive_queue.put(frame), self._loop - ) - - def on_transcription_message(self, message: dict): - if self._loop: - participantId = "" - if "participantId" in message: - participantId = message["participantId"] - elif "session_id" in message: - participantId = message["session_id"] - if self._my_participant_id and participantId != self._my_participant_id: - is_final = message["rawResponse"]["is_final"] - if is_final: - frame = TranscriptionFrame(message["text"], participantId, message["timestamp"]) - else: - frame = InterimTranscriptionFrame( - message["text"], participantId, message["timestamp"]) - asyncio.run_coroutine_threadsafe( - 
self.receive_queue.put(frame), self._loop) - - def on_transcription_error(self, message): - self._logger.error(f"Transcription error: {message}") - - def on_transcription_started(self, status): - pass - - def on_transcription_stopped(self, stopped_by, stopped_by_error): - pass + # class DailyTransport(ThreadedTransport, EventHandler): + # _daily_initialized = False + # _lock = threading.Lock() + + # _speaker_enabled: bool + # _speaker_sample_rate: int + # _vad_enabled: bool + + # # This is necessary to override EventHandler's __new__ method. + # def __new__(cls, *args, **kwargs): + # return super().__new__(cls) + + # def __init__( + # self, + # room_url: str, + # token: str | None, + # bot_name: str, + # min_others_count: int = 1, + # start_transcription: bool = False, + # video_rendering_enabled: bool = False, + # **kwargs, + # ): + # kwargs['has_webrtc_vad'] = True + # # This will call ThreadedTransport.__init__ method, not EventHandler + # super().__init__(**kwargs) + + # self._room_url: str = room_url + # self._bot_name: str = bot_name + # self._token: str | None = token + # self._min_others_count = min_others_count + # self._start_transcription = start_transcription + # self._video_rendering_enabled = video_rendering_enabled + + # self._is_interrupted = Event() + # self._stop_threads = Event() + + # self._other_participant_has_joined = False + # self._my_participant_id = None + + # self._video_renderers = {} + + # self.transcription_settings = { + # "language": "en", + # "tier": "nova", + # "model": "2-conversationalai", + # "profanity_filter": True, + # "redact": False, + # "endpointing": True, + # "punctuate": True, + # "includeRawResponse": True, + # "extra": { + # "interim_results": True, + # }, + # } + + # self._logger: logging.Logger = logging.getLogger("dailyai") + + # self._event_handlers = {} + + # self.webrtc_vad = Daily.create_native_vad( + # reset_period_ms=VAD_RESET_PERIOD_MS, + # sample_rate=self._speaker_sample_rate, + # channels=NUM_CHANNELS + # ) + + # def _patch_method(self, event_name, *args, **kwargs): + # try: + # for handler in self._event_handlers[event_name]: + # if inspect.iscoroutinefunction(handler): + # if self._loop: + # future = asyncio.run_coroutine_threadsafe( + # handler(*args, **kwargs), self._loop) + + # # wait for the coroutine to finish. This will also + # # raise any exceptions raised by the coroutine. + # future.result() + # else: + # raise Exception( + # "No event loop to run coroutine. In order to use async event handlers, you must run the DailyTransportService in an asyncio event loop.") + # else: + # handler(*args, **kwargs) + # except Exception as e: + # self._logger.error(f"Exception in event handler {event_name}: {e}") + # raise e + + # def _webrtc_vad_analyze(self): + # buffer = self.read_audio_frames(int(self._vad_samples)) + # if len(buffer) > 0: + # confidence = self.webrtc_vad.analyze_frames(buffer) + # # yeses = int(confidence * 20.0) + # # nos = 20 - yeses + # # out = "!" * yeses + "." * nos + # # print(f"!!! 
confidence: {out} {confidence}") + # talking = confidence > SPEECH_THRESHOLD + # return talking + + # def add_event_handler(self, event_name: str, handler): + # if not event_name.startswith("on_"): + # raise Exception( + # f"Event handler {event_name} must start with 'on_'") + + # methods = inspect.getmembers(self, predicate=inspect.ismethod) + # if event_name not in [method[0] for method in methods]: + # raise Exception(f"Event handler {event_name} not found") + + # if event_name not in self._event_handlers: + # self._event_handlers[event_name] = [ + # getattr( + # self, event_name), types.MethodType( + # handler, self)] + # setattr(self, event_name, partial(self._patch_method, event_name)) + # else: + # self._event_handlers[event_name].append( + # types.MethodType(handler, self)) + + # def event_handler(self, event_name: str): + # def decorator(handler): + # self.add_event_handler(event_name, handler) + # return handler + + # return decorator + + # def write_frame_to_camera(self, frame: bytes): + # if self._camera_enabled: + # self.camera.write_frame(frame) + + # def write_frame_to_mic(self, frame: bytes): + # if self._mic_enabled: + # self.mic.write_frames(frame) + + # def request_participant_image(self, participant_id: str): + # if participant_id in self._video_renderers: + # self._video_renderers[participant_id]["render_next_frame"] = True + + # def send_app_message(self, message: Any, participant_id: str | None): + # self.client.send_app_message(message, participant_id) + + # def read_audio_frames(self, desired_frame_count): + # bytes = b"" + # if self._speaker_enabled or self._vad_enabled: + # bytes = self._speaker.read_frames(desired_frame_count) + # return bytes + + # def _prerun(self): + # # Only initialize Daily once + # if not DailyTransport._daily_initialized: + # with DailyTransport._lock: + # Daily.init() + # DailyTransport._daily_initialized = True + # self.client = CallClient(event_handler=self) + + # if self._mic_enabled: + # self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device( + # "mic", sample_rate=self._mic_sample_rate, channels=1 + # ) + + # if self._camera_enabled: + # self.camera: VirtualCameraDevice = Daily.create_camera_device( + # "camera", width=self._camera_width, height=self._camera_height, color_format="RGB") + + # if self._speaker_enabled or self._vad_enabled: + # self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device( + # "speaker", sample_rate=self._speaker_sample_rate, channels=1 + # ) + # Daily.select_speaker_device("speaker") + + # self.client.set_user_name(self._bot_name) + # self.client.join( + # self._room_url, + # self._token, + # completion=self.call_joined, + # client_settings={ + # "inputs": { + # "camera": { + # "isEnabled": True, + # "settings": { + # "deviceId": "camera", + # }, + # }, + # "microphone": { + # "isEnabled": True, + # "settings": { + # "deviceId": "mic", + # "customConstraints": { + # "autoGainControl": {"exact": False}, + # "echoCancellation": {"exact": False}, + # "noiseSuppression": {"exact": False}, + # }, + # }, + # }, + # }, + # "publishing": { + # "camera": { + # "sendSettings": { + # "maxQuality": "low", + # "encodings": { + # "low": { + # "maxBitrate": self._camera_bitrate, + # "scaleResolutionDownBy": 1.333, + # "maxFramerate": self._camera_framerate, + # } + # }, + # } + # } + # }, + # }, + # ) + # self._my_participant_id = self.client.participants()["local"]["id"] + + # # For performance reasons, never subscribe to video streams (unless a + # # video renderer is registered). 
+ # self.client.update_subscription_profiles({ + # "base": { + # "camera": "unsubscribed", + # "screenVideo": "unsubscribed" + # } + # }) + + # if self._token and self._start_transcription: + # self.client.start_transcription(self.transcription_settings) + + # self.original_sigint_handler = signal.getsignal(signal.SIGINT) + # signal.signal(signal.SIGINT, self.process_interrupt_handler) + + # def process_interrupt_handler(self, signum, frame): + # self._post_run() + # if callable(self.original_sigint_handler): + # self.original_sigint_handler(signum, frame) + + # def _post_run(self): + # self.client.leave() + # self.client.release() + + # def on_first_other_participant_joined(self, participant): + # pass + + # def call_joined(self, join_data, client_error): + # # self._logger.info(f"Call_joined: {join_data}, {client_error}") + # pass + + # def dialout(self, number): + # self.client.start_dialout({"phoneNumber": number}) + + # def start_recording(self): + # self.client.start_recording() + + # def render_participant_video(self, + # participant_id, + # framerate=10, + # video_source="camera", + # color_format="RGB") -> None: + # if not self._video_rendering_enabled: + # self._logger.warn("Video rendering is not enabled") + # return + + # # Only enable camera subscription on this participant + # self.client.update_subscriptions(participant_settings={ + # participant_id: { + # "media": { + # video_source: "subscribed" + # } + # } + # }) + + # self._video_renderers[participant_id] = { + # "framerate": framerate, + # "timestamp": 0, + # "render_next_frame": False, + # } + # self.client.set_video_renderer( + # participant_id, + # self.on_participant_video_frame, + # video_source=video_source, + # color_format=color_format) + + # def on_participant_video_frame(self, participant_id, video_frame): + # if not self._loop: + # return + + # render_frame = False + + # curr_time = time.time() + # framerate = self._video_renderers[participant_id]["framerate"] + + # if framerate > 0: + # prev_time = self._video_renderers[participant_id]["timestamp"] + # next_time = prev_time + 1 / framerate + # render_frame = curr_time > next_time + # elif self._video_renderers[participant_id]["render_next_frame"]: + # self._video_renderers[participant_id]["render_next_frame"] = False + # render_frame = True + + # if render_frame: + # frame = UserImageFrame(participant_id, video_frame.buffer, + # (video_frame.width, video_frame.height)) + # asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self._loop) + + # self._video_renderers[participant_id]["timestamp"] = curr_time + + # def on_error(self, error): + # self._logger.error(f"on_error: {error}") + + # def on_call_state_updated(self, state): + # pass + + # def on_participant_joined(self, participant): + # if not self._other_participant_has_joined and participant["id"] != self._my_participant_id: + # self._other_participant_has_joined = True + # self.on_first_other_participant_joined(participant) + + # def on_participant_left(self, participant, reason): + # if len(self.client.participants()) < self._min_others_count + 1: + # self._stop_threads.set() + + # def on_app_message(self, message: Any, sender: str): + # if self._loop: + # frame = ReceivedAppMessageFrame(message, sender) + # asyncio.run_coroutine_threadsafe( + # self.receive_queue.put(frame), self._loop + # ) + + # def on_transcription_message(self, message: dict): + # if self._loop: + # participantId = "" + # if "participantId" in message: + # participantId = message["participantId"] + # elif "session_id" 
in message: + # participantId = message["session_id"] + # if self._my_participant_id and participantId != self._my_participant_id: + # is_final = message["rawResponse"]["is_final"] + # if is_final: + # frame = TranscriptionFrame(message["text"], participantId, message["timestamp"]) + # else: + # frame = InterimTranscriptionFrame( + # message["text"], participantId, message["timestamp"]) + # asyncio.run_coroutine_threadsafe( + # self.receive_queue.put(frame), self._loop) + + # def on_transcription_error(self, message): + # self._logger.error(f"Transcription error: {message}") + + # def on_transcription_started(self, status): + # pass + + # def on_transcription_stopped(self, stopped_by, stopped_by_error): + # pass diff --git a/src/dailyai/transports/live_stream_transport.py b/src/dailyai/transports/live_stream_transport.py new file mode 100644 index 000000000..b3736e58b --- /dev/null +++ b/src/dailyai/transports/live_stream_transport.py @@ -0,0 +1,196 @@ +import asyncio +import inspect +import types + +from abc import abstractmethod +from functools import partial +from asyncio import AbstractEventLoop +from typing import Callable + +from dailyai.frames.frames import ImageRawFrame +from dailyai.vad.vad_analyzer import VADState + + +class LiveStreamTransport: + + def __init__(self, **kwargs): + self._event_handlers: dict = {} + self._loop: asyncio.AbstractEventLoop | None = None + + self._camera_enabled = kwargs.get("camera_enabled") or False + self._camera_width = kwargs.get("camera_width") or 1024 + self._camera_height = kwargs.get("camera_height") or 768 + self._camera_bitrate = kwargs.get("camera_bitrate") or 680000 + self._camera_framerate = kwargs.get("camera_framerate") or 30 + self._mic_enabled = kwargs.get("mic_enabled") or False + self._mic_sample_rate = kwargs.get("mic_sample_rate") or 16000 + self._mic_channels = kwargs.get("mic_channels") or 1 + self._speaker_enabled = kwargs.get("speaker_enabled") or False + self._speaker_sample_rate = kwargs.get("speaker_sample_rate") or 16000 + self._speaker_channels = kwargs.get("speaker_channels") or 1 + self._transcription_enabled = kwargs.get("transcription_enabled") or False + self._video_capture_enabled = kwargs.get("video_capture_enabled") or False + self._vad_enabled = kwargs.get("vad_enabled") or False + self.vad_analyzer = kwargs.get("vad_analyzer") or None + + @property + def mic_enabled(self): + return self._mic_enabled + + @property + def mic_sample_rate(self): + return self._mic_sample_rate + + @property + def mic_channels(self): + return self._mic_channels + + @property + def speaker_enabled(self): + return self._speaker_enabled + + @property + def speaker_sample_rate(self): + return self._speaker_sample_rate + + @property + def speaker_channels(self): + return self._speaker_channels + + @property + def camera_enabled(self): + return self._camera_enabled + + @property + def camera_width(self): + return self._camera_width + + @property + def camera_height(self): + return self._camera_height + + @property + def camera_bitrate(self): + return self._camera_bitrate + + @property + def camera_framerate(self): + return self._camera_framerate + + @property + def transcription_enabled(self): + return self._transcription_enabled + + @property + def video_capture_enabled(self): + return self._video_capture_enabled + + @property + def vad_enabled(self): + return self._vad_enabled + + @property + @abstractmethod + def participant_id(self) -> str: + pass + + @abstractmethod + async def join(self): + pass + + @abstractmethod + async def 
leave(self): + pass + + @abstractmethod + async def cleanup(self): + pass + + @abstractmethod + def read_raw_audio_frames(self, frame_count: int) -> bytes: + pass + + @abstractmethod + def write_raw_audio_frames(self, frames: bytes) -> int: + pass + + @abstractmethod + def write_frame_to_camera(self, frame: ImageRawFrame): + pass + + @abstractmethod + def capture_participant_transcription(self, participant_id: str, callback: Callable): + pass + + @abstractmethod + def capture_participant_video( + self, + participant_id: str, + callback: Callable, + framerate: int, + video_source: str, + color_format: str): + pass + + def vad_analyze(self, buffer: bytes) -> VADState: + result = VADState.QUIET + if self.vad_analyzer: + result = self.vad_analyzer.analyze_audio(buffer) + return result + + # + # Frame processor + # + + def event_loop(self) -> AbstractEventLoop: + return self._loop + + def set_event_loop(self, loop: AbstractEventLoop): + self._loop = loop + + # + # Decorators (event handlers) + # + + def on_joined(self, participant): + pass + + def on_participant_joined(self, participant): + pass + + def event_handler(self, event_name: str, obj=None): + def decorator(handler): + self._add_event_handler(event_name, handler, obj) + return handler + return decorator + + def _add_event_handler(self, event_name: str, handler, obj): + methods = inspect.getmembers(self, predicate=inspect.ismethod) + if event_name not in [method[0] for method in methods]: + raise Exception(f"Event handler {event_name} not found") + + if event_name not in self._event_handlers: + self._event_handlers[event_name] = [getattr(self, event_name)] + patch_method = types.MethodType(partial(self._patch_method, event_name), self) + setattr(self, event_name, patch_method) + self._event_handlers[event_name].append(types.MethodType(handler, self)) + + def _patch_method(self, event_name, *args, **kwargs): + try: + for handler in self._event_handlers[event_name]: + if inspect.iscoroutinefunction(handler): + if self.event_loop(): + future = asyncio.run_coroutine_threadsafe( + handler(*args[1:], **kwargs), self.event_loop()) + + # wait for the coroutine to finish. This will also + # raise any exceptions raised by the coroutine. + future.result() + else: + raise Exception( + "No event loop to run coroutine. 
In order to use async event handlers, you must run the DailyTransportService in an asyncio event loop.") + else: + handler(*args[1:], **kwargs) + except Exception as e: + # TODO(aleix) self._logger.error(f"Exception in event handler {event_name}: {e}") + raise e diff --git a/src/dailyai/utils/__init__.py b/src/dailyai/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dailyai/utils/utils.py b/src/dailyai/utils/utils.py new file mode 100644 index 000000000..a752ad705 --- /dev/null +++ b/src/dailyai/utils/utils.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from threading import Lock + +_IDS = {} + +_IDS_MUTEX = Lock() + + +def obj_count(obj) -> int: + name = obj.__class__.__name__ + with _IDS_MUTEX: + if name not in _IDS: + _IDS[name] = 0 + else: + _IDS[name] += 1 + return _IDS[name] diff --git a/src/dailyai/vad/__init__.py b/src/dailyai/vad/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/dailyai/vad/silero_vad.py b/src/dailyai/vad/silero_vad.py new file mode 100644 index 000000000..9ab85ed12 --- /dev/null +++ b/src/dailyai/vad/silero_vad.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import numpy as np + +from dailyai.vad.vad_analyzer import VADAnalyzer + +from loguru import logger + +try: + import torch + # We don't use torchaudio here, but we need to try importing it because + # Silero uses it. + import torchaudio + + torch.set_num_threads(1) + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Silero VAD, you need to `pip install dailyai[silero]`.") + raise Exception(f"Missing module(s): {e}") + + +# Provided by Alexander Veysov +def int2float(sound): + try: + abs_max = np.abs(sound).max() + sound = sound.astype("float32") + if abs_max > 0: + sound *= 1 / 32768 + sound = sound.squeeze() # depends on the use case + return sound + except ValueError: + return sound + + +class SileroVADAnalyzer(VADAnalyzer): + + def __init__(self, sample_rate=16000, **kwargs): + super().__init__(sample_rate=sample_rate, num_channels=1, **kwargs) + + logger.info("Loading Silero VAD") + + (self._model, self._utils) = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False + ) + + logger.info("Using Silero VAD") + + def num_frames_required(self) -> int: + return int(self.sample_rate / 100) * 4 # 40ms + + def voice_confidence(self, buffer) -> float: + try: + audio_int16 = np.frombuffer(buffer, np.int16) + audio_float32 = int2float(audio_int16) + new_confidence = self._model(torch.from_numpy(audio_float32), self.sample_rate).item() + return new_confidence + except BaseException as e: + # This comes from an empty audio array + logger.error(f"Error analyzing audio with Silero VAD: {e}") + return 0 diff --git a/src/dailyai/vad/vad_analyzer.py b/src/dailyai/vad/vad_analyzer.py new file mode 100644 index 000000000..a506292e2 --- /dev/null +++ b/src/dailyai/vad/vad_analyzer.py @@ -0,0 +1,104 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import abstractmethod +from enum import Enum + + +class VADState(Enum): + QUIET = 1 + STARTING = 2 + SPEAKING = 3 + STOPPING = 4 + + +class VADAnalyzer: + + def __init__( + self, + sample_rate, + num_channels, + vad_confidence=0.5, + vad_start_s=0.2, + vad_stop_s=0.8): + self._sample_rate = sample_rate + self._vad_confidence = vad_confidence + self._vad_start_s = 
vad_start_s + self._vad_stop_s = vad_stop_s + self._vad_frames = self.num_frames_required() + self._vad_frames_num_bytes = self._vad_frames * num_channels * 2 + + vad_frame_s = self._vad_frames / self._sample_rate + + self._vad_start_frames = round(self._vad_start_s / vad_frame_s) + self._vad_stop_frames = round(self._vad_stop_s / vad_frame_s) + self._vad_starting_count = 0 + self._vad_stopping_count = 0 + self._vad_state: VADState = VADState.QUIET + + self._vad_buffer = b"" + + @property + def sample_rate(self): + return self._sample_rate + + @abstractmethod + def num_frames_required(self) -> int: + pass + + @abstractmethod + def voice_confidence(self, buffer) -> float: + pass + + def analyze_audio(self, buffer) -> VADState: + self._vad_buffer += buffer + + num_required_bytes = self._vad_frames_num_bytes + if len(self._vad_buffer) < num_required_bytes: + return self._vad_state + + audio_frames = self._vad_buffer[:num_required_bytes] + self._vad_buffer = self._vad_buffer[num_required_bytes:] + + confidence = self.voice_confidence(audio_frames) + speaking = confidence >= self._vad_confidence + + if speaking: + match self._vad_state: + case VADState.QUIET: + self._vad_state = VADState.STARTING + self._vad_starting_count = 1 + case VADState.STARTING: + self._vad_starting_count += 1 + case VADState.STOPPING: + self._vad_state = VADState.SPEAKING + self._vad_stopping_count = 0 + else: + match self._vad_state: + case VADState.STARTING: + self._vad_state = VADState.QUIET + self._vad_starting_count = 0 + case VADState.SPEAKING: + self._vad_state = VADState.STOPPING + self._vad_stopping_count = 1 + case VADState.STOPPING: + self._vad_stopping_count += 1 + + if ( + self._vad_state == VADState.STARTING + and self._vad_starting_count >= self._vad_start_frames + ): + self._vad_state = VADState.SPEAKING + self._vad_starting_count = 0 + + if ( + self._vad_state == VADState.STOPPING + and self._vad_stopping_count >= self._vad_stop_frames + ): + self._vad_state = VADState.QUIET + self._vad_stopping_count = 0 + + return self._vad_state
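
Note on the new VAD analyzer (not part of the patch itself): VADAnalyzer buffers raw 16-bit PCM, asks the concrete subclass for a per-window voice confidence, and debounces that score through the QUIET -> STARTING -> SPEAKING -> STOPPING state machine using vad_start_s / vad_stop_s. The sketch below shows how a subclass is expected to plug in. EnergyVADAnalyzer and its RMS-based confidence mapping are hypothetical stand-ins for illustration only; the real implementation added by this patch is SileroVADAnalyzer.

    import numpy as np

    from dailyai.vad.vad_analyzer import VADAnalyzer, VADState


    class EnergyVADAnalyzer(VADAnalyzer):
        # Hypothetical analyzer: normalized RMS energy stands in for a model score.

        def __init__(self, sample_rate=16000):
            super().__init__(sample_rate=sample_rate, num_channels=1)

        def num_frames_required(self) -> int:
            # 40ms windows, mirroring SileroVADAnalyzer.
            return int(self.sample_rate / 100) * 4

        def voice_confidence(self, buffer) -> float:
            samples = np.frombuffer(buffer, np.int16).astype(np.float32) / 32768.0
            if len(samples) == 0:
                return 0.0
            rms = float(np.sqrt(np.mean(samples * samples)))
            return min(rms * 10.0, 1.0)  # crude mapping of RMS into [0, 1]


    analyzer = EnergyVADAnalyzer()
    window_bytes = analyzer.num_frames_required() * 2  # 16-bit mono: 2 bytes per sample

    silence = b"\x00" * window_bytes
    tone = (np.sin(np.linspace(0, 200 * np.pi, window_bytes // 2)) * 20000).astype(np.int16).tobytes()

    # Speech long enough to cross vad_start_s, then silence long enough to cross vad_stop_s.
    states = [analyzer.analyze_audio(chunk) for chunk in [tone] * 10 + [silence] * 30]
    assert VADState.SPEAKING in states and states[-1] == VADState.QUIET

With the defaults assumed above (16 kHz, 40 ms analysis windows, vad_start_s=0.2, vad_stop_s=0.8), roughly five consecutive speaking windows are needed to enter SPEAKING and twenty quiet windows to fall back to QUIET, which is what the loop in the sketch exercises. A transport built on this (e.g. one implementing LiveStreamTransport with vad_enabled) would instead feed the frames returned by read_raw_audio_frames() into vad_analyze() and react to the returned VADState.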