Skip to content

Commit

Permalink
Add notes, pricing for 11labs with custom key and remove it from db f…
Browse files Browse the repository at this point in the history
…ields
  • Loading branch information
nikochiko committed Nov 7, 2023
1 parent 7994887 commit bb7e8c8
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 22 deletions.
10 changes: 8 additions & 2 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,21 @@ def api_url(self, example_id=None, run_id=None, uid=None) -> furl:
def endpoint(self) -> str:
return f"/v2/{self.slug_versions[0]}/"

def render(self):
def before_render(self):
"""
Side-effects to apply before doing the actual render.
This shouldn't actually render anything to the page.
"""
with sentry_sdk.configure_scope() as scope:
scope.set_extra("base_url", self.app_url())
scope.set_transaction_name(
"/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
)

example_id, run_id, uid = extract_query_params(gooey_get_query_params())
def render(self):
self.before_render()

example_id, run_id, uid = extract_query_params(gooey_get_query_params())
if st.session_state.get(StateKeys.run_status):
channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
output = realtime_pull([channel])[0]
Expand Down
2 changes: 1 addition & 1 deletion daras_ai_v2/text_to_speech_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class TextToSpeechProviders(Enum):
GOOGLE_TTS = "Google Cloud Text-to-Speech"
ELEVEN_LABS = "Eleven Labs (Premium)"
ELEVEN_LABS = "Eleven Labs"
UBERDUCK = "uberduck.ai"
BARK = "Bark (suno-ai)"

Expand Down
1 change: 0 additions & 1 deletion recipes/LipsyncTTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class LipsyncTTSPage(LipsyncPage, TextToSpeechPage):
"elevenlabs_stability": 0.5,
"elevenlabs_similarity_boost": 0.75,
}
private_fields = ["elevenlabs_api_key"]

class RequestModel(BaseModel):
input_face: str
Expand Down
50 changes: 39 additions & 11 deletions recipes/TextToSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class TextToSpeechPage(BasePage):
"elevenlabs_stability": 0.5,
"elevenlabs_similarity_boost": 0.75,
}
private_fields = ["elevenlabs_api_key"]

class RequestModel(BaseModel):
text_prompt: str
Expand Down Expand Up @@ -81,6 +80,16 @@ def fallback_preivew_image(self) -> str | None:
def preview_description(self, state: dict) -> str:
return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app."

def before_render(self):
super().before_render()
if st.session_state.get("tts_provider") == TextToSpeechProviders.ELEVEN_LABS.name:
if elevenlabs_api_key := st.session_state.get("elevenlabs_api_key"):
self.request.session["state"] = dict(elevenlabs_api_key=elevenlabs_api_key)
elif "elevenlabs_api_key" in self.request.session.get("state", {}):
st.session_state["elevenlabs_api_key"] = self.request.session["state"][
"elevenlabs_api_key"
]

def render_description(self):
st.write(
"""
Expand All @@ -103,6 +112,12 @@ def render_form_v2(self):
key="text_prompt",
)

def fields_to_save(self):
fields = super().fields_to_save()
if "elevenlabs_api_key" in fields:
fields.remove("elevenlabs_api_key")
return fields

def validate_form_v2(self):
assert st.session_state["text_prompt"], "Text input cannot be empty"

Expand Down Expand Up @@ -130,9 +145,13 @@ def render_output(self):
st.div()

def _get_eleven_labs_price(self, state: dict):
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079
_, is_user_provided_key = self._get_elevenlabs_api_key(state)
if is_user_provided_key:
return 0
else:
text = state.get("text_prompt", "")
# 0.079 credits / character ~ 4 credits / 10 words
return len(text) * 0.079

def _get_tts_provider(self, state: dict):
tts_provider = state.get("tts_provider", TextToSpeechProviders.UBERDUCK.name)
Expand All @@ -142,9 +161,15 @@ def _get_tts_provider(self, state: dict):
def additional_notes(self):
tts_provider = st.session_state.get("tts_provider")
if tts_provider == TextToSpeechProviders.ELEVEN_LABS.name:
return """
*Eleven Labs cost ≈ 4 credits per 10 words*
"""
_, is_user_provided_key = self._get_elevenlabs_api_key(st.session_state)
if is_user_provided_key:
return """
*Eleven Labs cost ≈ No additional credit charge given we'll use your API key*
"""
else:
return """
*Eleven Labs cost ≈ 4 credits per 10 words*
"""
else:
return ""

Expand Down Expand Up @@ -242,7 +267,7 @@ def run(self, state: dict):
)

case TextToSpeechProviders.ELEVEN_LABS:
xi_api_key = self._get_elevenlabs_api_key(state)
xi_api_key, _ = self._get_elevenlabs_api_key(state)
voice_model = self._get_elevenlabs_voice_model(state)
voice_id = self._get_elevenlabs_voice_id(state)

Expand Down Expand Up @@ -290,17 +315,20 @@ def _get_elevenlabs_voice_id(self, state: dict[str, str]):
assert voice_name in ELEVEN_LABS_VOICES, f"Invalid voice_name: {voice_name}"
return ELEVEN_LABS_VOICES[voice_name] # voice_name -> voice_id

def _get_elevenlabs_api_key(self, state: dict[str, str]):
def _get_elevenlabs_api_key(self, state: dict[str, str]) -> tuple[str, bool]:
"""
Returns the 11labs API key and whether it is a user-provided key or the default key
"""
# ElevenLabs is available for non-paying users with their own API key
if state.get("elevenlabs_api_key"):
return state["elevenlabs_api_key"]
return state["elevenlabs_api_key"], True
else:
assert (
self.is_current_user_paying() or self.is_current_user_admin()
), """
Please purchase Gooey.AI credits to use ElevenLabs voices <a href="/account">here</a>.
"""
return settings.ELEVEN_LABS_API_KEY
return settings.ELEVEN_LABS_API_KEY, False

def related_workflows(self) -> list:
from recipes.VideoBots import VideoBotsPage
Expand Down
16 changes: 14 additions & 2 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ class VideoBotsPage(BasePage):
"use_url_shortener": False,
"dense_weight": 1.0,
}
private_fields = ["elevenlabs_api_key"]

class RequestModel(BaseModel):
input_prompt: str
Expand Down Expand Up @@ -276,6 +275,16 @@ class ResponseModel(BaseModel):
def preview_image(self, state: dict) -> str | None:
return DEFAULT_COPILOT_META_IMG

def before_render(self):
super().before_render()
if st.session_state.get("tts_provider") == TextToSpeechProviders.ELEVEN_LABS.name:
if elevenlabs_api_key := st.session_state.get("elevenlabs_api_key"):
self.request.session["state"] = dict(elevenlabs_api_key=elevenlabs_api_key)
elif "elevenlabs_api_key" in self.request.session.get("state", {}):
st.session_state["elevenlabs_api_key"] = self.request.session["state"][
"elevenlabs_api_key"
]

def related_workflows(self):
from recipes.LipsyncTTS import LipsyncTTSPage
from recipes.CompareText2Img import CompareText2ImgPage
Expand Down Expand Up @@ -403,7 +412,10 @@ def render_settings(self):
lipsync_settings()

def fields_to_save(self) -> [str]:
return super().fields_to_save() + ["landbot_url"]
fields = super().fields_to_save() + ["landbot_url"]
if "elevenlabs_api_key" in fields:
fields.remove("elevenlabs_api_key")
return fields

def render_example(self, state: dict):
input_prompt = state.get("input_prompt")
Expand Down
5 changes: 0 additions & 5 deletions routers/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,6 @@ def st_page(
state.update(db_state)
for k, v in page.sane_defaults.items():
state.setdefault(k, v)

# pop private fields if not the owner
if not page.is_current_user_owner():
for k in page.private_fields:
state.pop(k, None)
if state is None:
raise HTTPException(status_code=404)

Expand Down

0 comments on commit bb7e8c8

Please sign in to comment.