From 3f3a853d715f5d602a69b9028a1c09d79da70e02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 12 Dec 2024 14:45:20 -0800 Subject: [PATCH 1/2] no longer necessary to call AIService super().start/stop/cancel(frame) --- CHANGELOG.md | 4 ++++ src/pipecat/services/ai_services.py | 18 ++++++++++++------ src/pipecat/services/assemblyai.py | 3 --- src/pipecat/services/azure.py | 3 --- src/pipecat/services/canonical.py | 2 -- src/pipecat/services/deepgram.py | 3 --- .../services/gemini_multimodal_live/gemini.py | 3 --- src/pipecat/services/gladia.py | 3 --- .../services/openai_realtime_beta/openai.py | 3 --- src/pipecat/services/riva.py | 3 --- 10 files changed, 16 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4a005280..fdc45c33f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- It's no longer necessary to call `super().start/stop/cancel(frame)` if you + subclass and implement `AIService.start/stop/cancel()`. This is all now done + internally and will avoid possible issues if you forget to add it. + - It's no longer necessary to call `super().process_frame(frame, direction)` if you subclass and implement `FrameProcessor.process_frame()`. This is all now done internally and will avoid possible issues if you forget to add it. diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index e324d413c..5f867246f 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -111,11 +111,11 @@ async def _update_settings(self, settings: Dict[str, Any]): async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, StartFrame): - await self.start(frame) + await self._start(frame) elif isinstance(frame, CancelFrame): - await self.cancel(frame) + await self._cancel(frame) elif isinstance(frame, EndFrame): - await self.stop(frame) + await self._stop(frame) async def process_generator(self, generator: AsyncGenerator[Frame | None, None]): async for f in generator: @@ -125,6 +125,15 @@ async def process_generator(self, generator: AsyncGenerator[Frame | None, None]) else: await self.push_frame(f) + async def _start(self, frame: StartFrame): + await self.start(frame) + + async def _stop(self, frame: EndFrame): + await self.stop(frame) + + async def _cancel(self, frame: CancelFrame): + await self.cancel(frame) + class LLMService(AIService): """This class is a no-op but serves as a base class for LLM services.""" @@ -248,19 +257,16 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: pass async def start(self, frame: StartFrame): - await super().start(frame) if self._push_stop_frames: self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler()) async def stop(self, frame: EndFrame): - await super().stop(frame) if self._stop_frame_task: self._stop_frame_task.cancel() await self._stop_frame_task self._stop_frame_task = None async def cancel(self, frame: CancelFrame): - await super().cancel(frame) if self._stop_frame_task: self._stop_frame_task.cancel() await self._stop_frame_task diff --git a/src/pipecat/services/assemblyai.py b/src/pipecat/services/assemblyai.py index 36a7e92ff..577df7d76 100644 --- a/src/pipecat/services/assemblyai.py +++ b/src/pipecat/services/assemblyai.py @@ -61,15 +61,12 @@ async def set_language(self, language: Language): self._settings["language"] = language async def start(self, 
frame: StartFrame): - await super().start(frame) await self._connect() async def stop(self, frame: EndFrame): - await super().stop(frame) await self._disconnect() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._disconnect() async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index a95ff7d3c..01ab22940 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -676,16 +676,13 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: yield None async def start(self, frame: StartFrame): - await super().start(frame) self._speech_recognizer.start_continuous_recognition_async() async def stop(self, frame: EndFrame): - await super().stop(frame) self._speech_recognizer.stop_continuous_recognition_async() self._audio_stream.close() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) self._speech_recognizer.stop_continuous_recognition_async() self._audio_stream.close() diff --git a/src/pipecat/services/canonical.py b/src/pipecat/services/canonical.py index 265cc1b1b..2986d8b94 100644 --- a/src/pipecat/services/canonical.py +++ b/src/pipecat/services/canonical.py @@ -84,11 +84,9 @@ def __init__( self._output_dir = output_dir async def stop(self, frame: EndFrame): - await super().stop(frame) await self._process_audio() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._process_audio() async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index 6578f2873..89064dab8 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -176,15 +176,12 @@ async def set_language(self, language: Language): await self._connect() async def start(self, frame: StartFrame): - await super().start(frame) await self._connect() async def stop(self, frame: EndFrame): - await super().stop(frame) await self._disconnect() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._disconnect() async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index bf433054a..efdeea329 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -229,15 +229,12 @@ async def set_context(self, context: OpenAILLMContext): # async def start(self, frame: StartFrame): - await super().start(frame) await self._connect() async def stop(self, frame: EndFrame): - await super().stop(frame) await self._disconnect() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._disconnect() # diff --git a/src/pipecat/services/gladia.py b/src/pipecat/services/gladia.py index 8909c4bb2..adc261222 100644 --- a/src/pipecat/services/gladia.py +++ b/src/pipecat/services/gladia.py @@ -177,18 +177,15 @@ def language_to_service_language(self, language: Language) -> str | None: return language_to_gladia_language(language) async def start(self, frame: StartFrame): - await super().start(frame) response = await self._setup_gladia() self._websocket = await websockets.connect(response["url"]) self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) async def stop(self, frame: EndFrame): - await super().stop(frame) await 
self._send_stop_recording() await self._websocket.close() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._websocket.close() async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index ac492a205..e8e2e50dd 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -112,15 +112,12 @@ def set_audio_input_paused(self, paused: bool): # async def start(self, frame: StartFrame): - await super().start(frame) await self._connect() async def stop(self, frame: EndFrame): - await super().stop(frame) await self._disconnect() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._disconnect() # diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva.py index 6be722d49..7d14cb583 100644 --- a/src/pipecat/services/riva.py +++ b/src/pipecat/services/riva.py @@ -187,17 +187,14 @@ def can_generate_metrics(self) -> bool: return False async def start(self, frame: StartFrame): - await super().start(frame) self._thread_task = self.get_event_loop().create_task(self._thread_task_handler()) self._response_task = self.get_event_loop().create_task(self._response_task_handler()) self._response_queue = asyncio.Queue() async def stop(self, frame: EndFrame): - await super().stop(frame) await self._stop_tasks() async def cancel(self, frame: CancelFrame): - await super().cancel(frame) await self._stop_tasks() async def _stop_tasks(self): From 06043ce9b1b52df92320647f1751d64307862e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Thu, 12 Dec 2024 14:53:56 -0800 Subject: [PATCH 2/2] missing no longer necessary to call super().process_frame(frame, direction) --- src/pipecat/services/ai_services.py | 10 ---------- src/pipecat/services/anthropic.py | 4 ---- src/pipecat/services/canonical.py | 1 - src/pipecat/services/cartesia.py | 2 -- src/pipecat/services/elevenlabs.py | 2 -- .../services/gemini_multimodal_live/gemini.py | 3 --- src/pipecat/services/google.py | 2 -- src/pipecat/services/openai.py | 4 ---- .../services/openai_realtime_beta/context.py | 1 - src/pipecat/services/openai_realtime_beta/openai.py | 2 -- src/pipecat/services/playht.py | 2 -- src/pipecat/services/tavus.py | 1 - src/pipecat/transports/network/fastapi_websocket.py | 2 -- src/pipecat/transports/network/websocket_server.py | 2 -- src/pipecat/transports/services/daily.py | 2 -- src/pipecat/transports/services/livekit.py | 13 ------------- 16 files changed, 53 deletions(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 5f867246f..5aa2ec731 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -292,8 +292,6 @@ async def say(self, text: str): await self.queue_frame(TTSSpeakFrame(text)) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, TextFrame): await self._process_text_frame(frame) elif isinstance(frame, StartInterruptionFrame): @@ -410,8 +408,6 @@ async def cancel(self, frame: CancelFrame): await self._stop_words_task() async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)): await self.flush_audio() @@ -498,8 +494,6 @@ async def 
process_audio_frame(self, frame: AudioRawFrame): async def process_frame(self, frame: Frame, direction: FrameDirection): """Processes a frame of audio data, either buffering or transcribing it.""" - await super().process_frame(frame, direction) - if isinstance(frame, AudioRawFrame): # In this service we accumulate audio internally and at the end we # push a TextFrame. We also push audio downstream in case someone @@ -597,8 +591,6 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: pass async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, TextFrame): await self.push_frame(frame, direction) await self.start_processing_metrics() @@ -620,8 +612,6 @@ async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, pass async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, VisionImageRawFrame): await self.start_processing_metrics() await self.process_generator(self.run_vision(frame)) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index f0c033375..320f9b828 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -270,8 +270,6 @@ async def _process_context(self, context: OpenAILLMContext): ) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - context = None if isinstance(frame, OpenAILLMContextFrame): context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context) @@ -611,7 +609,6 @@ def __init__(self, context: OpenAILLMContext | AnthropicLLMContext): self._context = AnthropicLLMContext.from_openai_context(context) async def process_frame(self, frame, direction): - await super().process_frame(frame, direction) # Our parent method has already called push_frame(). So we can't interrupt the # flow here and we don't need to call push_frame() ourselves. Possibly something # to talk through (tagging @aleix). At some point we might need to refactor these @@ -664,7 +661,6 @@ def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kw self._pending_image_frame_message = None async def process_frame(self, frame, direction): - await super().process_frame(frame, direction) # See note above about not calling push_frame() here. 
if isinstance(frame, StartInterruptionFrame): self._function_call_in_progress = None diff --git a/src/pipecat/services/canonical.py b/src/pipecat/services/canonical.py index 2986d8b94..1d6c0ca9c 100644 --- a/src/pipecat/services/canonical.py +++ b/src/pipecat/services/canonical.py @@ -90,7 +90,6 @@ async def cancel(self, frame: CancelFrame): await self._process_audio() async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) await self.push_frame(frame, direction) async def _process_audio(self): diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 8683fd29a..1cb7281ae 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -287,8 +287,6 @@ async def _receive_task_handler(self): await self._connect_websocket() async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - # If we received a TTSSpeakFrame and the LLM response included text (it # might be that it's only a function calling response) we pause # processing more frames until we receive a BotStoppedSpeakingFrame. diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index b829f4945..ec278b33c 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -272,8 +272,6 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)]) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - # If we received a TTSSpeakFrame and the LLM response included text (it # might be that it's only a function calling response) we pause # processing more frames until we receive a BotStoppedSpeakingFrame. 
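For illustration, here is what a subclass looks like once patch 2 is applied: process_frame() is implemented directly, without the old super().process_frame(frame, direction) call, because the base FrameProcessor now performs that bookkeeping internally. A minimal sketch, assuming pipecat is installed; MyLoggingService is a hypothetical service, not part of this patch:

    from pipecat.frames.frames import Frame, TextFrame
    from pipecat.processors.frame_processor import FrameDirection
    from pipecat.services.ai_services import AIService


    class MyLoggingService(AIService):
        # Hypothetical subclass, shown only to illustrate the new contract.
        async def process_frame(self, frame: Frame, direction: FrameDirection):
            # No super().process_frame(frame, direction) needed anymore; the
            # framework invokes the base-class handling internally.
            if isinstance(frame, TextFrame):
                print(f"text frame: {frame.text}")
            # Subclasses still forward frames they don't consume.
            await self.push_frame(frame, direction)
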
diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index efdeea329..0cf5dfdee 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -107,7 +107,6 @@ def get_messages_for_initializing_history(self): class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator): async def process_frame(self, frame, direction): - await super().process_frame(frame, direction) # kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now if isinstance(frame, LLMMessagesAppendFrame): await self.push_frame(frame, direction) @@ -305,8 +304,6 @@ async def _transcribe_audio(self, audio, context): # async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - # logger.debug(f"Processing frame: {frame}") if isinstance(frame, TranscriptionFrame): diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index 5442ee91c..ece4e3b33 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -652,8 +652,6 @@ async def _process_context(self, context: OpenAILLMContext): await self.push_frame(LLMFullResponseEndFrame()) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - context = None if isinstance(frame, OpenAILLMContextFrame): diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 43ad16536..d50c827a1 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -286,8 +286,6 @@ async def _process_context(self, context: OpenAILLMContext): ) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - context = None if isinstance(frame, OpenAILLMContextFrame): context: OpenAILLMContext = frame.context @@ -475,7 +473,6 @@ def __init__(self, context: OpenAILLMContext): super().__init__(context=context) async def process_frame(self, frame, direction): - await super().process_frame(frame, direction) # Our parent method has already called push_frame(). So we can't interrupt the # flow here and we don't need to call push_frame() ourselves. try: @@ -516,7 +513,6 @@ def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwarg self._pending_image_frame_message = None async def process_frame(self, frame, direction): - await super().process_frame(frame, direction) # See note above about not calling push_frame() here. if isinstance(frame, StartInterruptionFrame): self._function_calls_in_progress.clear() diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index 2b6ff968f..611d6293e 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -148,7 +148,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): async def process_frame( self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM ): - await super().process_frame(frame, direction) # Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline, # messages are only processed by the user context aggregator, which is generally what we want. 
But # we also need to send new messages over the websocket, so the openai realtime API has them diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index e8e2e50dd..d265972bd 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -170,8 +170,6 @@ async def _truncate_current_audio_response(self): # async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, TranscriptionFrame): pass elif isinstance(frame, OpenAILLMContextFrame): diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 00c5d4d34..1612c4657 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -265,8 +265,6 @@ async def _receive_task_handler(self): await self._connect_websocket() async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - # If we received a TTSSpeakFrame and the LLM response included text (it # might be that it's only a function calling response) we pause # processing more frames until we receive a BotStoppedSpeakingFrame. diff --git a/src/pipecat/services/tavus.py b/src/pipecat/services/tavus.py index ff2b7fb87..97788b708 100644 --- a/src/pipecat/services/tavus.py +++ b/src/pipecat/services/tavus.py @@ -92,7 +92,6 @@ async def _encode_audio_and_send( await self._send_audio_message(audio_base64, done=done) async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) if isinstance(frame, TTSStartedFrame): await self.start_processing_metrics() await self.start_ttfb_metrics() diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index cd21fbe0d..fa994cc6b 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -101,8 +101,6 @@ def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwarg self._next_send_time = 0 async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, StartInterruptionFrame): await self._write_frame(frame) self._next_send_time = 0 diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index 711bc7596..66dd8eaa0 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -139,8 +139,6 @@ async def set_client_connection(self, websocket: websockets.WebSocketServerProto self._websocket = websocket async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, StartInterruptionFrame): await self._write_frame(frame) self._next_send_time = 0 diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 7456ef816..8195663c7 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -727,8 +727,6 @@ def vad_analyzer(self) -> VADAnalyzer | None: # async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - if isinstance(frame, UserImageRequestFrame): await self.request_participant_image(frame.user_id) diff --git 
a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index f53c8332f..1c2851ca9 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -16,7 +16,6 @@ AudioRawFrame, CancelFrame, EndFrame, - Frame, InputAudioRawFrame, OutputAudioRawFrame, StartFrame, @@ -334,12 +333,6 @@ async def stop(self, frame: EndFrame): await self._client.disconnect() logger.info("LiveKitInputTransport stopped") - async def process_frame(self, frame: Frame, direction: FrameDirection): - if isinstance(frame, EndFrame): - await self.stop(frame) - else: - await super().process_frame(frame, direction) - async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._client.disconnect() @@ -411,12 +404,6 @@ async def stop(self, frame: EndFrame): await self._client.disconnect() logger.info("LiveKitOutputTransport stopped") - async def process_frame(self, frame: Frame, direction: FrameDirection): - if isinstance(frame, EndFrame): - await self.stop(frame) - else: - await super().process_frame(frame, direction) - async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._client.disconnect()
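Taken together, the two patches mean a service subclass needs no super() calls in any of its lifecycle hooks. A minimal sketch of the resulting pattern, modeled on the assemblyai/deepgram services changed above — MyConnectedService and its _connect()/_disconnect() helpers are hypothetical:

    from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame
    from pipecat.services.ai_services import AIService


    class MyConnectedService(AIService):
        # Hypothetical service, shown only to illustrate the new contract.

        async def start(self, frame: StartFrame):
            # Previously this had to begin with: await super().start(frame).
            # AIService.process_frame() now routes StartFrame through its
            # internal _start() wrapper, which calls this hook for us.
            await self._connect()

        async def stop(self, frame: EndFrame):
            await self._disconnect()

        async def cancel(self, frame: CancelFrame):
            await self._disconnect()

        async def _connect(self):
            ...  # e.g. open a websocket and spawn receive tasks

        async def _disconnect(self):
            ...  # e.g. cancel tasks and close the websocket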