From 314f3888d44d44626eaf0f00504a4aa3cfb6dac6 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 29 Feb 2024 19:42:48 +0530 Subject: [PATCH] refactor tts settings: replace complex function with conditional logic based on input flags with functional components --- .../language_model_settings_widgets.py | 2 - .../text_to_speech_settings_widgets.py | 486 +++++++++--------- gooey_ui/components/__init__.py | 2 +- recipes/TextToSpeech.py | 7 +- recipes/VideoBots.py | 10 +- 5 files changed, 245 insertions(+), 262 deletions(-) diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index 278fc1293..4083ece31 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -1,8 +1,6 @@ import gooey_ui as st -from daras_ai_v2.azure_doc_extract import azure_form_recognizer_models from daras_ai_v2.enum_selector_widget import enum_selector -from daras_ai_v2.field_render import field_title_desc, field_desc from daras_ai_v2.language_model import LargeLanguageModels diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index aec7788fe..9a820f7ae 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -142,267 +142,249 @@ class TextToSpeechProviders(Enum): } -def text_to_speech_settings( - page, include_title=True, include_selector=True, include_settings=True -): - if include_title: - st.write( +def text_to_speech_provider_selector(page): + col1, col2 = st.columns(2) + with col1: + tts_provider = enum_selector( + TextToSpeechProviders, + "###### Speech Provider", + key="tts_provider", + ) + with col2: + match tts_provider: + case TextToSpeechProviders.BARK.name: + bark_selector() + case TextToSpeechProviders.GOOGLE_TTS.name: + google_tts_selector() + case TextToSpeechProviders.UBERDUCK.name: + uberduck_selector() + case TextToSpeechProviders.ELEVEN_LABS.name: + elevenlabs_selector(page) + return tts_provider + + +def text_to_speech_settings(page, tts_provider): + match tts_provider: + case TextToSpeechProviders.BARK.name: + pass + case TextToSpeechProviders.GOOGLE_TTS.name: + google_tts_settings() + case TextToSpeechProviders.UBERDUCK.name: + uberduck_settings() + case TextToSpeechProviders.ELEVEN_LABS.name: + elevenlabs_settings() + + +def bark_selector(): + st.selectbox( + label=""" + ###### Bark History Prompt + """, + key="bark_history_prompt", + format_func=BARK_ALLOWED_PROMPTS.__getitem__, + options=BARK_ALLOWED_PROMPTS.keys(), + ) + + +def google_tts_selector(): + voices = google_tts_voices() + st.selectbox( + label=""" + ###### Voice name (Google TTS) + """, + key="google_voice_name", + format_func=voices.__getitem__, + options=voices.keys(), + ) + st.caption( + "*Please refer to the list of voice names [here](https://cloud.google.com/text-to-speech/docs/voices)*" + ) + + +def google_tts_settings(): + st.write(f"##### 🗣️ {TextToSpeechProviders.GOOGLE_TTS.value} Settings") + col1, col2 = st.columns(2) + with col1: + st.slider( """ - ##### 🗣️ Voice Settings + ###### Speaking rate + *`1.0` is the normal native speed of the speaker* + """, + min_value=0.3, + max_value=4.0, + step=0.1, + key="google_speaking_rate", + ) + with col2: + st.slider( """ + ###### Pitch + *Increase/Decrease semitones from the original pitch* + """, + min_value=-20.0, + max_value=20.0, + step=0.25, + key="google_pitch", ) - col1, col2 = st.columns(2) - if include_selector: - with col1: - tts_provider = enum_selector( - TextToSpeechProviders, - "###### Speech Provider", - key="tts_provider", - ) + +def uberduck_selector(): + st.selectbox( + label=""" + ###### Voice name (Uberduck) + """, + key="uberduck_voice_name", + format_func=lambda option: f"{option}", + options=UBERDUCK_VOICES.keys(), + ) + + +def uberduck_settings(): + st.write(f"##### 🗣️ {TextToSpeechProviders.UBERDUCK.value} Settings") + st.slider( + """ + ###### Speaking rate + *`1.0` is the normal native speed of the speaker* + """, + min_value=0.5, + max_value=3.0, + step=0.25, + key="uberduck_speaking_rate", + ) + + +def elevenlabs_selector(page): + if not st.session_state.get("elevenlabs_api_key"): + st.session_state["elevenlabs_api_key"] = page.request.session.get( + SESSION_ELEVENLABS_API_KEY + ) + + elevenlabs_use_custom_key = st.checkbox( + "Use custom API key + Voice ID", + value=bool(st.session_state.get("elevenlabs_api_key")), + ) + if elevenlabs_use_custom_key: + st.session_state["elevenlabs_voice_name"] = None + elevenlabs_api_key = st.text_input( + """ + ###### Your ElevenLabs API key + *Read this + to know how to obtain an API key from + ElevenLabs.* + """, + key="elevenlabs_api_key", + ) + + selected_voice_id = st.session_state.get("elevenlabs_voice_id") + elevenlabs_voices = ( + {selected_voice_id: selected_voice_id} if selected_voice_id else {} + ) + + if elevenlabs_api_key: + try: + elevenlabs_voices = fetch_elevenlabs_voices(elevenlabs_api_key) + except requests.exceptions.HTTPError as e: + st.error(f"Invalid ElevenLabs API key. Failed to fetch voices: {e}") + + st.selectbox( + """ + ###### Voice ID (ElevenLabs) + """, + key="elevenlabs_voice_id", + options=elevenlabs_voices.keys(), + format_func=elevenlabs_voices.__getitem__, + ) else: - tts_provider = st.session_state.get("tts_provider") + st.session_state["elevenlabs_api_key"] = None + st.session_state["elevenlabs_voice_id"] = None + if not ( + page and (page.is_current_user_paying() or page.is_current_user_admin()) + ): + st.caption( + """ + Note: Please purchase Gooey.AI credits to use ElevenLabs voices [here](/account). + Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. + """ + ) - col = col2 if include_selector else st.div() - match tts_provider: - case TextToSpeechProviders.BARK.name: - if not include_settings: - return - - with col: - st.selectbox( - label=""" - ###### Bark History Prompt - """, - key="bark_history_prompt", - format_func=BARK_ALLOWED_PROMPTS.__getitem__, - options=BARK_ALLOWED_PROMPTS.keys(), - ) + st.session_state.update(elevenlabs_api_key=None, elevenlabs_voice_id=None) + st.selectbox( + """ + ###### Voice Name (ElevenLabs) + """, + key="elevenlabs_voice_name", + format_func=str, + options=ELEVEN_LABS_VOICES.keys(), + ) - case TextToSpeechProviders.GOOGLE_TTS.name: - with col2: - if include_selector: - voices = google_tts_voices() - st.selectbox( - label=""" - ###### Voice name (Google TTS) - """, - key="google_voice_name", - format_func=voices.__getitem__, - options=voices.keys(), - ) - st.caption( - "*Please refer to the list of voice names [here](https://cloud.google.com/text-to-speech/docs/voices)*" - ) - - if not include_settings: - return - - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Speaking rate - *`1.0` is the normal native speed of the speaker* - """, - min_value=0.3, - max_value=4.0, - step=0.1, - key="google_speaking_rate", - ) - with col2: - st.slider( - """ - ###### Pitch - *Increase/Decrease semitones from the original pitch* - """, - min_value=-20.0, - max_value=20.0, - step=0.25, - key="google_pitch", - ) + page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get( + "elevenlabs_api_key" + ) - case TextToSpeechProviders.UBERDUCK.name: - with col2: - if include_selector: - st.selectbox( - label=""" - ###### Voice name (Uberduck) - """, - key="uberduck_voice_name", - format_func=lambda option: f"{option}", - options=UBERDUCK_VOICES.keys(), - ) - - if not include_settings: - return - - with col: - st.slider( - """ - ###### Speaking rate - *`1.0` is the normal native speed of the speaker* - """, - min_value=0.5, - max_value=3.0, - step=0.25, - key="uberduck_speaking_rate", - ) + st.selectbox( + """ + ###### Voice Model + """, + key="elevenlabs_model", + format_func=ELEVEN_LABS_MODELS.__getitem__, + options=ELEVEN_LABS_MODELS.keys(), + ) - case TextToSpeechProviders.ELEVEN_LABS.name: - if include_selector: - with col2: - st.selectbox( - """ - ###### Voice Name (ElevenLabs) - """, - key="elevenlabs_voice_name", - format_func=str, - options=ELEVEN_LABS_VOICES.keys(), - ) - - if not include_settings: - return - - with col: - if not st.session_state.get("elevenlabs_api_key"): - st.session_state["elevenlabs_api_key"] = page.request.session.get( - SESSION_ELEVENLABS_API_KEY - ) - - elevenlabs_use_custom_key = st.checkbox( - "Use custom ElevenLabs API key + Voice ID", - value=bool(st.session_state.get("elevenlabs_api_key")), - ) - - if elevenlabs_use_custom_key: - st.session_state["elevenlabs_voice_name"] = None - elevenlabs_api_key = st.text_input( - """ - ###### Your ElevenLabs API key - *Read this - to know how to obtain an API key from - ElevenLabs.* - """, - key="elevenlabs_api_key", - ) - - selected_voice_id = st.session_state.get("elevenlabs_voice_id") - elevenlabs_voices = ( - {selected_voice_id: selected_voice_id} - if selected_voice_id - else {} - ) - - if elevenlabs_api_key: - try: - elevenlabs_voices = fetch_elevenlabs_voices( - elevenlabs_api_key - ) - except requests.exceptions.HTTPError as e: - st.error( - f"Invalid ElevenLabs API key. Failed to fetch voices: {e}" - ) - - st.selectbox( - """ - ###### Voice ID (ElevenLabs) - """, - key="elevenlabs_voice_id", - options=elevenlabs_voices.keys(), - format_func=elevenlabs_voices.__getitem__, - ) - else: - st.session_state["elevenlabs_api_key"] = None - st.session_state["elevenlabs_voice_id"] = None - if not ( - page - and ( - page.is_current_user_paying() - or page.is_current_user_admin() - ) - ): - st.caption( - """ - Note: Please purchase Gooey.AI credits to use ElevenLabs voices [here](/account). - Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. - """ - ) - - st.session_state.update( - elevenlabs_api_key=None, elevenlabs_voice_id=None - ) - - page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get( - "elevenlabs_api_key" - ) - - st.selectbox( - """ - ###### Voice Model - """, - key="elevenlabs_model", - format_func=ELEVEN_LABS_MODELS.__getitem__, - options=ELEVEN_LABS_MODELS.keys(), - ) - - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Stability - *A lower stability provides a broader emotional range. - A value lower than 0.3 can lead to too much instability. - [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#stability).* - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_stability", - ) - with col2: - st.slider( - """ - ###### Similarity Boost - *Dictates how hard the model should try to replicate the original voice. - [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#similarity).* - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_similarity_boost", - ) - - if st.session_state.get("elevenlabs_model") == "eleven_multilingual_v2": - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Style Exaggeration - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_style", - value=0.0, - ) - with col2: - st.checkbox( - "Speaker Boost", - key="elevenlabs_speaker_boost", - value=True, - ) - - with st.expander( - "Eleven Labs Supported Languages", - style={"fontSize": "0.9rem", "textDecoration": "underline"}, - ): - st.caption( - "With Multilingual V2 voice model", style={"fontSize": "0.8rem"} - ) - st.caption( - ", ".join(ELEVEN_LABS_SUPPORTED_LANGS), style={"fontSize": "0.8rem"} - ) + +def elevenlabs_settings(): + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Stability + *A lower stability provides a broader emotional range. + A value lower than 0.3 can lead to too much instability. + [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#stability).* + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_stability", + ) + with col2: + st.slider( + """ + ###### Similarity Boost + *Dictates how hard the model should try to replicate the original voice. + [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#similarity).* + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_similarity_boost", + ) + + if st.session_state.get("elevenlabs_model") == "eleven_multilingual_v2": + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Style Exaggeration + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_style", + value=0.0, + ) + with col2: + st.checkbox( + "Speaker Boost", + key="elevenlabs_speaker_boost", + value=True, + ) + + with st.expander( + "Eleven Labs Supported Languages", + style={"fontSize": "0.9rem", "textDecoration": "underline"}, + ): + st.caption("With Multilingual V2 voice model", style={"fontSize": "0.8rem"}) + st.caption(", ".join(ELEVEN_LABS_SUPPORTED_LANGS), style={"fontSize": "0.8rem"}) @redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 6322cf629..5cd8d9721 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -396,7 +396,7 @@ def multiselect( def selectbox( label: str, - options: typing.Sequence[T], + options: typing.Iterable[T], format_func: typing.Callable[[T], typing.Any] = _default_format, key: str = None, help: str = None, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 7d8f301fe..716dc5f1d 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -21,6 +21,7 @@ ELEVEN_LABS_MODELS, text_to_speech_settings, TextToSpeechProviders, + text_to_speech_provider_selector, ) DEFAULT_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png" @@ -109,6 +110,7 @@ def render_form_v2(self): """, key="text_prompt", ) + text_to_speech_provider_selector(self) def fields_to_save(self): fields = super().fields_to_save() @@ -117,10 +119,11 @@ def fields_to_save(self): return fields def validate_form_v2(self): - assert st.session_state["text_prompt"], "Text input cannot be empty" + assert st.session_state.get("text_prompt"), "Text input cannot be empty" + assert st.session_state.get("tts_provider"), "Please select a TTS provider" def render_settings(self): - text_to_speech_settings(page=self) + text_to_speech_settings(self, st.session_state.get("tts_provider")) def get_raw_price(self, state: dict): tts_provider = self._get_tts_provider(state) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index c9706df49..adf7301b3 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -72,6 +72,7 @@ from daras_ai_v2.text_to_speech_settings_widgets import ( TextToSpeechProviders, text_to_speech_settings, + text_to_speech_provider_selector, ) from daras_ai_v2.vector_search import DocSearchRequest from recipes.DocSearch import ( @@ -348,9 +349,7 @@ def render_form_v2(self): st.session_state["tts_provider"] = None enable_video = False else: - text_to_speech_settings( - page=self, include_title=False, include_settings=False - ) + text_to_speech_provider_selector(self) st.write("---") if not "__enable_video" in st.session_state: st.session_state["__enable_video"] = bool( @@ -411,8 +410,9 @@ def render_usage_guide(self): youtube_video("-j2su1r8pEg") def render_settings(self): - if st.session_state.get("__enable_audio"): - text_to_speech_settings(page=self, include_selector=False) + tts_provider = st.session_state.get("tts_provider") + if tts_provider: + text_to_speech_settings(self, tts_provider) if st.session_state.get("__enable_video"): lipsync_settings()