diff --git a/CHANGELOG.md b/CHANGELOG.md index cc1622fc6..408094dde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added `MarkdownRemovalProcessor`. This processor removes markdown formatting - from a TextFrame. It's intended to be used between the LLM and TTS in order - to remove markdown from the text the TTS speaks. +- Added a new util called `MarkdownTextFilter` which is a subclass of a new + base class called `BaseTextFilter`. This is a configurable utility which + is intended to filter text received by TTS services. - Added new `RTVIUserLLMTextProcessor`. This processor will send an RTVI `user-llm-text` message with the user content's that was sent to the LLM. diff --git a/src/pipecat/processors/text/markdown_remover.py b/src/pipecat/processors/text/markdown_remover.py deleted file mode 100644 index 3c08d5fb2..000000000 --- a/src/pipecat/processors/text/markdown_remover.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright (c) 2024, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import re - -from markdown import Markdown - -from pipecat.frames.frames import Frame, TextFrame -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor - - -class MarkdownRemovalProcessor(FrameProcessor): - """Removes Markdown formatting from text in TextFrames. - - Converts Markdown to plain text while preserving the overall structure, - including leading and trailing spaces. Handles special cases like - asterisks and table formatting. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._md = Markdown() - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, TextFrame): - cleaned_text = self._remove_markdown(frame.text) - await self.push_frame(TextFrame(text=cleaned_text)) - else: - await self.push_frame(frame, direction) - - def _remove_markdown(self, markdown_string: str) -> str: - # Replace newlines with spaces to handle cases with leading newlines - markdown_string = markdown_string.replace("\n", " ") - - # Preserve numbered list items with a unique marker, §NUM§ - markdown_string = re.sub(r"^(\d+\.)\s", r"§NUM§\1 ", markdown_string) - - # Preserve leading/trailing spaces with a unique marker, § - # Critical for word-by-word streaming in bot-tts-text - preserved_markdown = re.sub( - r"^( +)|\s+$", lambda m: "§" * len(m.group(0)), markdown_string, flags=re.MULTILINE - ) - - # Convert markdown to HTML - md = Markdown() - html = md.convert(preserved_markdown) - - # Remove HTML tags - text = re.sub("<[^<]+?>", "", html) - - # Replace HTML entities - text = text.replace(" ", " ") - text = text.replace("<", "<") - text = text.replace(">", ">") - text = text.replace("&", "&") - - # Remove leading/trailing asterisks - # Necessary for bot-tts-text, as they appear as literal asterisks - text = re.sub(r"^\*{1,2}|\*{1,2}$", "", text) - - # Remove Markdown table formatting - text = re.sub(r"\|", "", text) - text = re.sub(r"^\s*[-:]+\s*$", "", text, flags=re.MULTILINE) - - # Restore numbered list items - text = text.replace("§NUM§", "") - - # Restore leading and trailing spaces - text = re.sub("§", " ", text) - - return text diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 3deabf68c..20a587912 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -37,6 +37,7 @@ from pipecat.transcriptions.language import Language from pipecat.utils.audio import calculate_audio_volume from pipecat.utils.string import match_endofsentence +from pipecat.utils.text.base_text_filter import BaseTextFilter from pipecat.utils.time import seconds_to_nanoseconds from pipecat.utils.utils import exp_smoothing @@ -172,6 +173,7 @@ def __init__( stop_frame_timeout_s: float = 1.0, # TTS output sample rate sample_rate: int = 16000, + text_filter: Optional[BaseTextFilter] = None, **kwargs, ): super().__init__(**kwargs) @@ -182,6 +184,7 @@ def __init__( self._sample_rate: int = sample_rate self._voice_id: str = "" self._settings: Dict[str, Any] = {} + self._text_filter: Optional[BaseTextFilter] = text_filter self._stop_frame_task: Optional[asyncio.Task] = None self._stop_frame_queue: asyncio.Queue = asyncio.Queue() @@ -242,6 +245,8 @@ async def _update_settings(self, settings: Dict[str, Any]): self.set_model_name(value) elif key == "voice": self.set_voice(value) + elif key == "text_filter" and self._text_filter: + self._text_filter.update_settings(value) else: logger.warning(f"Unknown setting for TTS service: {key}") @@ -312,6 +317,8 @@ async def _push_tts_frames(self, text: str): return await self.start_processing_metrics() + if self._text_filter: + text = self._text_filter.filter(text) await self.process_generator(self.run_tts(text)) await self.stop_processing_metrics() if self._push_text_frames: diff --git a/src/pipecat/processors/text/__init__.py b/src/pipecat/utils/text/__init__.py similarity index 100% rename from src/pipecat/processors/text/__init__.py rename to src/pipecat/utils/text/__init__.py diff --git a/src/pipecat/utils/text/base_text_filter.py b/src/pipecat/utils/text/base_text_filter.py new file mode 100644 index 000000000..69d5d4fe1 --- /dev/null +++ b/src/pipecat/utils/text/base_text_filter.py @@ -0,0 +1,18 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import ABC, abstractmethod +from typing import Any, Mapping + + +class BaseTextFilter(ABC): + @abstractmethod + def update_settings(self, settings: Mapping[str, Any]): + pass + + @abstractmethod + def filter(self, text: str) -> str: + pass diff --git a/src/pipecat/utils/text/markdown_text_filter.py b/src/pipecat/utils/text/markdown_text_filter.py new file mode 100644 index 000000000..3018b8788 --- /dev/null +++ b/src/pipecat/utils/text/markdown_text_filter.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import re +from typing import Any, Mapping + +from markdown import Markdown +from pydantic import BaseModel + +from pipecat.utils.text.base_text_filter import BaseTextFilter + + +class MarkdownTextFilter(BaseTextFilter): + """Removes Markdown formatting from text in TextFrames. + + Converts Markdown to plain text while preserving the overall structure, + including leading and trailing spaces. Handles special cases like + asterisks and table formatting. + """ + + class InputParams(BaseModel): + enable_text_filter: bool = True + + def __init__(self, params: InputParams = InputParams(), **kwargs): + super().__init__(**kwargs) + self._settings = params + + def update_settings(self, settings: Mapping[str, Any]): + for key, value in settings.items(): + if hasattr(self._settings, key): + setattr(self._settings, key, value) + + def filter(self, text: str) -> str: + if self._settings.enable_text_filter: + # Replace newlines with spaces only when there's no text before or after + text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE) + + # Remove repeated sequences of 5 or more characters + text = re.sub(r"(\S)(\1{4,})", "", text) + + # Preserve numbered list items with a unique marker, §NUM§ + text = re.sub(r"^(\d+\.)\s", r"§NUM§\1 ", text) + + # Preserve leading/trailing spaces with a unique marker, § + # Critical for word-by-word streaming in bot-tts-text + preserved_markdown = re.sub( + r"^( +)|\s+$", lambda m: "§" * len(m.group(0)), text, flags=re.MULTILINE + ) + + # Convert markdown to HTML + md = Markdown() + html = md.convert(preserved_markdown) + + # Remove HTML tags + filtered_text = re.sub("<[^<]+?>", "", html) + + # Replace HTML entities + filtered_text = filtered_text.replace(" ", " ") + filtered_text = filtered_text.replace("<", "<") + filtered_text = filtered_text.replace(">", ">") + filtered_text = filtered_text.replace("&", "&") + + # Remove double asterisks (consecutive without any exceptions) + filtered_text = re.sub(r"\*\*", "", filtered_text) + + # Remove single asterisks at the start or end of words + filtered_text = re.sub(r"(^|\s)\*|\*($|\s)", r"\1\2", filtered_text) + + # Remove Markdown table formatting + filtered_text = re.sub(r"\|", "", filtered_text) + filtered_text = re.sub(r"^\s*[-:]+\s*$", "", filtered_text, flags=re.MULTILINE) + + # Restore numbered list items + filtered_text = filtered_text.replace("§NUM§", "") + + # Restore leading and trailing spaces + filtered_text = re.sub("§", " ", filtered_text) + + return filtered_text + else: + return text