From 1117c2148365caca2202bef57792438ba8cfb27e Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Mon, 16 Dec 2024 15:24:58 -0500
Subject: [PATCH] Refactor TranscriptProcessor into user and assistant
 processors

---
 CHANGELOG.md                                  |   3 +-
 .../28a-transcription-processor-openai.py     |  23 +-
 .../28b-transcript-processor-anthropic.py     |  31 ++-
 .../28c-transcription-processor-gemini.py     |  41 ++--
 src/pipecat/frames/frames.py                  |   7 -
 .../processors/aggregators/llm_response.py    |   6 -
 .../processors/transcript_processor.py        | 221 +++++++++++-------
 src/pipecat/services/google.py                |   5 -
 8 files changed, 182 insertions(+), 155 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e784cdba..e3f2c2dc7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,9 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   - Messages emitted with ISO 8601 timestamps indicating when they were spoken.
   - Supports all LLM formats (OpenAI, Anthropic, Google) via standard message
     format.
+  - Shared event handling for both user and assistant transcript updates.
   - New examples: `28a-transcription-processor-openai.py`,
     `28b-transcription-processor-anthropic.py`, and
-    `28c-transcription-processor-gemini.py`.
+    `28c-transcription-processor-gemini.py`

 - Add support for more languages to ElevenLabs (Arabic, Croatian, Filipino,
   Tamil) and PlayHT (Afrikans, Albanian, Amharic, Arabic, Bengali, Croatian,

diff --git a/examples/foundational/28a-transcription-processor-openai.py b/examples/foundational/28a-transcription-processor-openai.py
index 1e8463b69..0966c882f 100644
--- a/examples/foundational/28a-transcription-processor-openai.py
+++ b/examples/foundational/28a-transcription-processor-openai.py
@@ -15,13 +15,14 @@
 from runner import configure

 from pipecat.audio.vad.silero import SileroVADAnalyzer
-from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
+from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
 from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
 from pipecat.processors.transcript_processor import TranscriptProcessor
 from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.deepgram import DeepgramSTTService
 from pipecat.services.openai import OpenAILLMService
 from pipecat.transports.services.daily import DailyParams, DailyTransport

@@ -57,12 +58,6 @@ async def on_transcript_update(
             timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
             logger.info(f"{timestamp}{msg.role}: {msg.content}")

-        # # Log the full transcript
-        # logger.info("Full transcript:")
-        # for msg in self.messages:
-        #     timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
-        #     logger.info(f"{timestamp}{msg.role}: {msg.content}")
-

 async def main():
     async with aiohttp.ClientSession() as session:
@@ -70,16 +65,18 @@ async def main():

         transport = DailyTransport(
             room_url,
-            token,
+            None,
             "Respond bot",
             DailyParams(
                 audio_out_enabled=True,
-                transcription_enabled=True,
                 vad_enabled=True,
                 vad_analyzer=SileroVADAnalyzer(),
+                vad_audio_passthrough=True,
             ),
         )

+        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+
         tts = CartesiaTTSService(
             api_key=os.getenv("CARTESIA_API_KEY"),
             voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
@@ -101,23 +98,25 @@ async def main():
         context_aggregator = llm.create_context_aggregator(context)

         # Create transcript processor and handler
-        transcript_processor = TranscriptProcessor()
+        transcript = TranscriptProcessor()
         transcript_handler = TranscriptHandler()

         # Register event handler for transcript updates
-        @transcript_processor.event_handler("on_transcript_update")
+        @transcript.event_handler("on_transcript_update")
         async def on_transcript_update(processor, frame):
             await transcript_handler.on_transcript_update(processor, frame)

         pipeline = Pipeline(
             [
                 transport.input(),  # Transport user input
+                stt,  # STT
+                transcript.user(),  # User transcripts
                 context_aggregator.user(),  # User responses
                 llm,  # LLM
                 tts,  # TTS
                 transport.output(),  # Transport bot output
                 context_aggregator.assistant(),  # Assistant spoken responses
-                transcript_processor,  # Process transcripts
+                transcript.assistant(),  # Assistant transcripts
             ]
         )

diff --git a/examples/foundational/28b-transcript-processor-anthropic.py b/examples/foundational/28b-transcript-processor-anthropic.py
index 626206c5f..066828652 100644
--- a/examples/foundational/28b-transcript-processor-anthropic.py
+++ b/examples/foundational/28b-transcript-processor-anthropic.py
@@ -15,7 +15,7 @@
 from runner import configure

 from pipecat.audio.vad.silero import SileroVADAnalyzer
-from pipecat.frames.frames import LLMMessagesFrame, TranscriptionMessage, TranscriptionUpdateFrame
+from pipecat.frames.frames import TranscriptionMessage, TranscriptionUpdateFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -23,6 +23,7 @@
 from pipecat.processors.transcript_processor import TranscriptProcessor
 from pipecat.services.anthropic import AnthropicLLMService
 from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.deepgram import DeepgramSTTService
 from pipecat.transports.services.daily import DailyParams, DailyTransport

 load_dotenv(override=True)
@@ -57,12 +58,6 @@ async def on_transcript_update(
             timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
             logger.info(f"{timestamp}{msg.role}: {msg.content}")

-        # # Log the full transcript
-        # logger.info("Full transcript:")
-        # for msg in self.messages:
-        #     timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
-        #     logger.info(f"{timestamp}{msg.role}: {msg.content}")
-

 async def main():
     async with aiohttp.ClientSession() as session:
@@ -70,16 +65,18 @@ async def main():

         transport = DailyTransport(
             room_url,
-            token,
+            None,
             "Respond bot",
             DailyParams(
                 audio_out_enabled=True,
-                transcription_enabled=True,
                 vad_enabled=True,
                 vad_analyzer=SileroVADAnalyzer(),
+                vad_audio_passthrough=True,
             ),
         )

+        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+
         tts = CartesiaTTSService(
             api_key=os.getenv("CARTESIA_API_KEY"),
             voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
@@ -101,23 +98,20 @@ async def main():
         context_aggregator = llm.create_context_aggregator(context)

         # Create transcript processor and handler
-        transcript_processor = TranscriptProcessor()
+        transcript = TranscriptProcessor()
         transcript_handler = TranscriptHandler()

-        # Register event handler for transcript updates
-        @transcript_processor.event_handler("on_transcript_update")
-        async def on_transcript_update(processor, frame):
-            await transcript_handler.on_transcript_update(processor, frame)
-
         pipeline = Pipeline(
             [
                 transport.input(),  # Transport user input
+                stt,  # STT
+                transcript.user(),  # User transcripts
                 context_aggregator.user(),  # User responses
                 llm,  # LLM
                 tts,  # TTS
                 transport.output(),  # Transport bot output
                 context_aggregator.assistant(),  # Assistant spoken responses
-                transcript_processor,  # Process transcripts
+                transcript.assistant(),  # Assistant transcripts
             ]
         )

@@ -129,6 +123,11 @@ async def on_first_participant_joined(transport, participant):
             # Kick off the conversation.
             await task.queue_frames([context_aggregator.user().get_context_frame()])

+        # Register event handler for transcript updates
+        @transcript.event_handler("on_transcript_update")
+        async def on_transcript_update(processor, frame):
+            await transcript_handler.on_transcript_update(processor, frame)
+
         runner = PipelineRunner()

         await runner.run(task)

diff --git a/examples/foundational/28c-transcription-processor-gemini.py b/examples/foundational/28c-transcription-processor-gemini.py
index bf9448199..6c8118c57 100644
--- a/examples/foundational/28c-transcription-processor-gemini.py
+++ b/examples/foundational/28c-transcription-processor-gemini.py
@@ -22,6 +22,7 @@
 from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
 from pipecat.processors.transcript_processor import TranscriptProcessor
 from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.deepgram import DeepgramSTTService
 from pipecat.services.google import GoogleLLMService
 from pipecat.services.openai import OpenAILLMContext
 from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -58,12 +59,6 @@ async def on_transcript_update(
             timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
             logger.info(f"{timestamp}{msg.role}: {msg.content}")

-        # # Log the full transcript
-        # logger.info("Full transcript:")
-        # for msg in self.messages:
-        #     timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
-        #     logger.info(f"{timestamp}{msg.role}: {msg.content}")
-

 async def main():
     async with aiohttp.ClientSession() as session:
@@ -71,16 +66,18 @@ async def main():

         transport = DailyTransport(
             room_url,
-            token,
+            None,
             "Respond bot",
             DailyParams(
                 audio_out_enabled=True,
-                transcription_enabled=True,
                 vad_enabled=True,
                 vad_analyzer=SileroVADAnalyzer(),
+                vad_audio_passthrough=True,
             ),
         )

+        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+
         tts = CartesiaTTSService(
             api_key=os.getenv("CARTESIA_API_KEY"),
             voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
@@ -104,23 +101,20 @@ async def main():
         context_aggregator = llm.create_context_aggregator(context)

         # Create transcript processor and handler
-        transcript_processor = TranscriptProcessor()
+        transcript = TranscriptProcessor()
         transcript_handler = TranscriptHandler()

-        # Register event handler for transcript updates
-        @transcript_processor.event_handler("on_transcript_update")
-        async def on_transcript_update(processor, frame):
-            await transcript_handler.on_transcript_update(processor, frame)
-
         pipeline = Pipeline(
             [
-                transport.input(),
-                context_aggregator.user(),
-                llm,
-                tts,
-                transport.output(),
-                context_aggregator.assistant(),
-                transcript_processor,
+                transport.input(),  # Transport user input
+                stt,  # STT
+                transcript.user(),  # User transcripts
+                context_aggregator.user(),  # User responses
+                llm,  # LLM
+                tts,  # TTS
+                transport.output(),  # Transport bot output
+                context_aggregator.assistant(),  # Assistant spoken responses
+                transcript.assistant(),  # Assistant transcripts
             ]
         )

@@ -139,6 +133,11 @@ async def on_first_participant_joined(transport, participant):
             # Kick off the conversation.
             await task.queue_frames([context_aggregator.user().get_context_frame()])

+        # Register event handler for transcript updates
+        @transcript.event_handler("on_transcript_update")
+        async def on_transcript_update(processor, frame):
+            await transcript_handler.on_transcript_update(processor, frame)
+
         runner = PipelineRunner()

         await runner.run(task)

diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py
index d02112a6f..e9d942b4e 100644
--- a/src/pipecat/frames/frames.py
+++ b/src/pipecat/frames/frames.py
@@ -207,13 +207,6 @@ def __str__(self):
         return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"


-@dataclass
-class OpenAILLMContextUserTimestampFrame(DataFrame):
-    """Timestamp information for user message in LLM context."""
-
-    timestamp: str
-
-
 @dataclass
 class OpenAILLMContextAssistantTimestampFrame(DataFrame):
     """Timestamp information for assistant message in LLM context."""

diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py
index 612375da2..479746471 100644
--- a/src/pipecat/processors/aggregators/llm_response.py
+++ b/src/pipecat/processors/aggregators/llm_response.py
@@ -15,7 +15,6 @@
     LLMMessagesFrame,
     LLMMessagesUpdateFrame,
     LLMSetToolsFrame,
-    OpenAILLMContextUserTimestampFrame,
     StartInterruptionFrame,
     TextFrame,
     TranscriptionFrame,
@@ -27,7 +26,6 @@
     OpenAILLMContextFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
-from pipecat.utils.time import time_now_iso8601


 class LLMResponseAggregator(FrameProcessor):
@@ -291,10 +289,6 @@ async def _push_aggregation(self):
                 frame = OpenAILLMContextFrame(self._context)
                 await self.push_frame(frame)

-                # Push timestamp frame with current time
-                timestamp_frame = OpenAILLMContextUserTimestampFrame(timestamp=time_now_iso8601())
-                await self.push_frame(timestamp_frame)
-
             # Reset our accumulator state.
             self._reset()

diff --git a/src/pipecat/processors/transcript_processor.py b/src/pipecat/processors/transcript_processor.py
index be53cd79a..a95e502a8 100644
--- a/src/pipecat/processors/transcript_processor.py
+++ b/src/pipecat/processors/transcript_processor.py
@@ -4,7 +4,8 @@
 # SPDX-License-Identifier: BSD 2-Clause License
 #

-from typing import List, Optional
+from abc import ABC, abstractmethod
+from typing import List

 from loguru import logger

@@ -12,7 +13,7 @@
     ErrorFrame,
     Frame,
     OpenAILLMContextAssistantTimestampFrame,
-    OpenAILLMContextUserTimestampFrame,
+    TranscriptionFrame,
     TranscriptionMessage,
     TranscriptionUpdateFrame,
 )
@@ -20,55 +21,72 @@
 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor


-class TranscriptProcessor(FrameProcessor):
-    """Processes LLM context frames to generate timestamped conversation transcripts.
+class BaseTranscriptProcessor(FrameProcessor, ABC):
+    """Base class for processing conversation transcripts.

-    This processor monitors OpenAILLMContextFrame frames and their corresponding
-    timestamp frames to build a chronological conversation transcript. Messages are
-    stored by role until their matching timestamp frame arrives, then emitted via
-    TranscriptionUpdateFrame.
+    Provides common functionality for handling transcript messages and updates.
+    """

-    Each LLM context (OpenAI, Anthropic, Google) provides conversion to the standard format:
-    [
-        {
-            "role": "user",
-            "content": [{"type": "text", "text": "Hi, how are you?"}]
-        },
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": "Great! And you?"}]
-        }
-    ]
+    def __init__(self, **kwargs):
+        """Initialize processor with empty message store."""
+        super().__init__(**kwargs)
+        self._processed_messages: List[TranscriptionMessage] = []
+        self._register_event_handler("on_transcript_update")

-    Events:
-        on_transcript_update: Emitted when timestamped messages are available.
-            Args: TranscriptionUpdateFrame containing timestamped messages.
+    async def _emit_update(self, messages: List[TranscriptionMessage]):
+        """Emit transcript updates for new messages.

-    Example:
-        ```python
-        transcript_processor = TranscriptProcessor()
+        Args:
+            messages: New messages to emit in update
+        """
+        if messages:
+            self._processed_messages.extend(messages)
+            update_frame = TranscriptionUpdateFrame(messages=messages)
+            await self._call_event_handler("on_transcript_update", update_frame)
+            await self.push_frame(update_frame)

-        @transcript_processor.event_handler("on_transcript_update")
-        async def on_transcript_update(processor, frame):
-            for msg in frame.messages:
-                print(f"[{msg.timestamp}] {msg.role}: {msg.content}")
-        ```
-    """
+    @abstractmethod
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process incoming frames to build conversation transcript.

-    def __init__(self, **kwargs):
-        """Initialize the transcript processor.
+        Args:
+            frame: Input frame to process
+            direction: Frame processing direction
+        """
+        await super().process_frame(frame, direction)
+
+
+class UserTranscriptProcessor(BaseTranscriptProcessor):
+    """Processes user transcription frames into timestamped conversation messages."""
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process TranscriptionFrames into user conversation messages.

         Args:
-            **kwargs: Additional arguments passed to FrameProcessor
+            frame: Input frame to process
+            direction: Frame processing direction
         """
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, TranscriptionFrame):
+            message = TranscriptionMessage(
+                role="user", content=frame.text, timestamp=frame.timestamp
+            )
+            await self._emit_update([message])
+
+        await self.push_frame(frame, direction)
+
+
+class AssistantTranscriptProcessor(BaseTranscriptProcessor):
+    """Processes assistant LLM context frames into timestamped conversation messages."""
+
+    def __init__(self, **kwargs):
+        """Initialize processor with empty message stores."""
         super().__init__(**kwargs)
-        self._processed_messages: List[TranscriptionMessage] = []
-        self._register_event_handler("on_transcript_update")
-        self._pending_user_messages: List[TranscriptionMessage] = []
         self._pending_assistant_messages: List[TranscriptionMessage] = []

     def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
-        """Extract conversation messages from standard format.
+        """Extract assistant messages from the OpenAI standard message format.

         Args:
             messages: List of messages in OpenAI format, which can be either:
@@ -80,21 +98,14 @@ def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
         """
         result = []
         for msg in messages:
-            # Only process user and assistant messages
-            if msg["role"] not in ("user", "assistant"):
-                continue
-
-            if "content" not in msg:
-                logger.warning(f"Message missing content field: {msg}")
+            if msg["role"] != "assistant":
                 continue

             content = msg.get("content")
             if isinstance(content, str):
-                # Handle simple string content
                 if content:
-                    result.append(TranscriptionMessage(role=msg["role"], content=content))
+                    result.append(TranscriptionMessage(role="assistant", content=content))
             elif isinstance(content, list):
-                # Handle structured content
                 text_parts = []
                 for part in content:
                     if isinstance(part, dict) and part.get("type") == "text":
@@ -102,13 +113,13 @@ def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:

                 if text_parts:
                     result.append(
-                        TranscriptionMessage(role=msg["role"], content=" ".join(text_parts))
+                        TranscriptionMessage(role="assistant", content=" ".join(text_parts))
                     )

         return result

     def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]:
-        """Find messages in current that aren't in self._processed_messages.
+        """Find unprocessed messages from current list.

         Args:
             current: List of current messages
@@ -126,28 +137,15 @@ def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[Transc
         return current[processed_len:]

     async def process_frame(self, frame: Frame, direction: FrameDirection):
-        """Process frames to build a timestamped conversation transcript.
-
-        Handles three frame types in sequence:
-        1. OpenAILLMContextFrame: Contains new messages to be timestamped
-        2. OpenAILLMContextUserTimestampFrame: Timestamp for user messages
-        3. OpenAILLMContextAssistantTimestampFrame: Timestamp for assistant messages
-
-        Messages are stored by role until their corresponding timestamp frame arrives.
-        When a timestamp frame is received, the matching messages are timestamped and
-        emitted in chronological order via TranscriptionUpdateFrame.
+        """Process frames into assistant conversation messages.

         Args:
-            frame: The frame to process
+            frame: Input frame to process
             direction: Frame processing direction
-
-        Raises:
-            ErrorFrame: If message processing fails
         """
         await super().process_frame(frame, direction)

         if isinstance(frame, OpenAILLMContextFrame):
-            # Extract and store messages by role
             standard_messages = []
             for msg in frame.context.messages:
                 converted = frame.context.to_standard_messages(msg)
@@ -155,34 +153,83 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):

             current_messages = self._extract_messages(standard_messages)
             new_messages = self._find_new_messages(current_messages)
-
-            # Store new messages by role
-            for msg in new_messages:
-                if msg.role == "user":
-                    self._pending_user_messages.append(msg)
-                elif msg.role == "assistant":
-                    self._pending_assistant_messages.append(msg)
-
-        elif isinstance(frame, OpenAILLMContextUserTimestampFrame):
-            # Process pending user messages with timestamp
-            if self._pending_user_messages:
-                for msg in self._pending_user_messages:
-                    msg.timestamp = frame.timestamp
-                self._processed_messages.extend(self._pending_user_messages)
-                update_frame = TranscriptionUpdateFrame(messages=self._pending_user_messages)
-                await self._call_event_handler("on_transcript_update", update_frame)
-                await self.push_frame(update_frame)
-                self._pending_user_messages = []
+            self._pending_assistant_messages.extend(new_messages)

         elif isinstance(frame, OpenAILLMContextAssistantTimestampFrame):
-            # Process pending assistant messages with timestamp
             if self._pending_assistant_messages:
                 for msg in self._pending_assistant_messages:
                     msg.timestamp = frame.timestamp
-                self._processed_messages.extend(self._pending_assistant_messages)
-                update_frame = TranscriptionUpdateFrame(messages=self._pending_assistant_messages)
-                await self._call_event_handler("on_transcript_update", update_frame)
-                await self.push_frame(update_frame)
+                await self._emit_update(self._pending_assistant_messages)
                 self._pending_assistant_messages = []

         await self.push_frame(frame, direction)
+
+
+class TranscriptProcessor:
+    """Factory for creating and managing transcript processors.
+
+    Provides unified access to user and assistant transcript processors
+    with shared event handling.
+
+    Example:
+        ```python
+        transcript = TranscriptProcessor()
+
+        pipeline = Pipeline(
+            [
+                transport.input(),
+                stt,
+                transcript.user(),  # User transcripts
+                context_aggregator.user(),
+                llm,
+                tts,
+                transport.output(),
+                context_aggregator.assistant(),
+                transcript.assistant(),  # Assistant transcripts
+            ]
+        )
+
+        @transcript.event_handler("on_transcript_update")
+        async def handle_update(processor, frame):
+            print(f"New messages: {frame.messages}")
+        ```
+    """
+
+    def __init__(self, **kwargs):
+        """Initialize factory with user and assistant processors."""
+        self._user_processor = UserTranscriptProcessor(**kwargs)
+        self._assistant_processor = AssistantTranscriptProcessor(**kwargs)
+        self._event_handlers = {}
+
+    def user(self) -> UserTranscriptProcessor:
+        """Get the user transcript processor."""
+        return self._user_processor
+
+    def assistant(self) -> AssistantTranscriptProcessor:
+        """Get the assistant transcript processor."""
+        return self._assistant_processor
+
+    def event_handler(self, event_name: str):
+        """Register event handler for both processors.
+
+        Args:
+            event_name: Name of event to handle
+
+        Returns:
+            Decorator function that registers handler with both processors
+        """
+
+        def decorator(handler):
+            self._event_handlers[event_name] = handler
+
+            @self._user_processor.event_handler(event_name)
+            async def user_handler(processor, frame):
+                return await handler(processor, frame)
+
+            @self._assistant_processor.event_handler(event_name)
+            async def assistant_handler(processor, frame):
+                return await handler(processor, frame)
+
+            return handler
+
+        return decorator

diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py
index c7d32eff3..6bbf1d000 100644
--- a/src/pipecat/services/google.py
+++ b/src/pipecat/services/google.py
@@ -24,7 +24,6 @@
     LLMMessagesFrame,
     LLMUpdateSettingsFrame,
     OpenAILLMContextAssistantTimestampFrame,
-    OpenAILLMContextUserTimestampFrame,
     TextFrame,
     TTSAudioRawFrame,
     TTSStartedFrame,
@@ -234,10 +233,6 @@ async def _push_aggregation(self):
                 frame = OpenAILLMContextFrame(self._context)
                 await self.push_frame(frame)

-                # Push timestamp frame with current time
-                timestamp_frame = OpenAILLMContextUserTimestampFrame(timestamp=time_now_iso8601())
-                await self.push_frame(timestamp_frame)
-
             # Reset our accumulator state.
             self._reset()
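Note for reviewers: the snippet below is a minimal usage sketch of the refactored API, distilled from the 28a-28c examples in this patch; it is not part of the diff. The surrounding services (`stt`, `llm`, `tts`, `transport`, `context_aggregator`) are assumed to be configured as in those examples, and the handler body is illustrative.

```python
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.transcript_processor import TranscriptProcessor

transcript = TranscriptProcessor()

# transcript.user() turns TranscriptionFrames from the STT stage into user
# messages; transcript.assistant() buffers assistant messages extracted from
# the LLM context and stamps them when the assistant timestamp frame arrives.
pipeline = Pipeline(
    [
        transport.input(),
        stt,
        transcript.user(),
        context_aggregator.user(),
        llm,
        tts,
        transport.output(),
        context_aggregator.assistant(),
        transcript.assistant(),
    ]
)

# A single registration now fans out to both processors; the `processor`
# argument tells you which side emitted the update.
@transcript.event_handler("on_transcript_update")
async def on_transcript_update(processor, frame):
    for msg in frame.messages:
        timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
        print(f"{timestamp}{msg.role}: {msg.content}")
```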