diff --git a/bots/models.py b/bots/models.py
index 205dfb5ff..558512e8c 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -700,8 +700,8 @@ def to_df_format(
.replace(tzinfo=None)
)
row |= {
- "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"),
- "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"),
+ "Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT),
+ "First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT),
"A7": not convo.d7(),
"A30": not convo.d30(),
"R1": last_time - first_time < datetime.timedelta(days=1),
@@ -926,7 +926,7 @@ def to_df_format(
"Message (EN)": message.content,
"Sent": message.created_at.astimezone(tz)
.replace(tzinfo=None)
- .strftime("%b %d, %Y %I:%M %p"),
+ .strftime(settings.SHORT_DATETIME_FORMAT),
"Feedback": (
message.feedbacks.first().get_display_text()
if message.feedbacks.first()
@@ -968,7 +968,7 @@ def to_df_analysis_format(
"Answer (EN)": message.content,
"Sent": message.created_at.astimezone(tz)
.replace(tzinfo=None)
- .strftime("%b %d, %Y %I:%M %p"),
+ .strftime(settings.SHORT_DATETIME_FORMAT),
"Analysis JSON": message.analysis_result,
}
rows.append(row)
@@ -1153,16 +1153,16 @@ def to_df_format(
"Question Sent": feedback.message.get_previous_by_created_at()
.created_at.astimezone(tz)
.replace(tzinfo=None)
- .strftime("%b %d, %Y %I:%M %p"),
+ .strftime(settings.SHORT_DATETIME_FORMAT),
"Answer (EN)": feedback.message.content,
"Answer Sent": feedback.message.created_at.astimezone(tz)
.replace(tzinfo=None)
- .strftime("%b %d, %Y %I:%M %p"),
+ .strftime(settings.SHORT_DATETIME_FORMAT),
"Rating": Feedback.Rating(feedback.rating).label,
"Feedback (EN)": feedback.text_english,
"Feedback Sent": feedback.created_at.astimezone(tz)
.replace(tzinfo=None)
- .strftime("%b %d, %Y %I:%M %p"),
+ .strftime(settings.SHORT_DATETIME_FORMAT),
"Question Answered": feedback.message.question_answered,
}
rows.append(row)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 28756f96a..0a885bf0a 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1348,11 +1348,11 @@ def _render_output_col(self, submitted: bool):
# render outputs
self.render_output()
- if run_state != "waiting":
+ if run_state != RecipeRunState.running:
self._render_after_output()
def _render_completed_output(self):
- run_time = st.session_state.get(StateKeys.run_time, 0)
+ pass
def _render_failed_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
@@ -1368,12 +1368,10 @@ def render_extra_waiting_output(self):
if not estimated_run_time:
return
if created_at := st.session_state.get("created_at"):
- if isinstance(created_at, datetime.datetime):
- start_time = created_at
- else:
- start_time = datetime.datetime.fromisoformat(created_at)
+ if isinstance(created_at, str):
+ created_at = datetime.datetime.fromisoformat(created_at)
with st.countdown_timer(
- end_time=start_time + datetime.timedelta(seconds=estimated_run_time),
+ end_time=created_at + datetime.timedelta(seconds=estimated_run_time),
delay_text="Sorry for the wait. Your run is taking longer than we expected.",
):
if self.is_current_user_owner() and self.request.user.email:
@@ -1514,6 +1512,8 @@ def clear_outputs(self):
st.session_state.pop(field_name, None)
def _render_after_output(self):
+ self._render_report_button()
+
if "seed" in self.RequestModel.schema_json():
randomize = st.button(
' Regenerate', type="tertiary"
@@ -1521,27 +1521,8 @@ def _render_after_output(self):
if randomize:
st.session_state[StateKeys.pressed_randomize] = True
st.experimental_rerun()
- caption = ""
- caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s'
- if "seed" in self.RequestModel.schema_json():
- seed = st.session_state.get("seed")
- caption += f' with seed {seed} '
- created_at = st.session_state.get(
- StateKeys.created_at, datetime.datetime.today()
- )
- if not isinstance(created_at, datetime.datetime):
- created_at = datetime.datetime.fromisoformat(created_at)
- format_created_at = created_at.strftime("%d %b %Y %-I:%M%p")
- caption += f' at {format_created_at}'
- st.caption(caption, unsafe_allow_html=True)
- def render_buttons(self, url: str):
- st.download_button(
- label=' Download',
- url=url,
- type="secondary",
- )
- self._render_report_button()
+ render_output_caption()
def state_to_doc(self, state: dict):
ret = {
@@ -1918,6 +1899,26 @@ def is_current_user_owner(self) -> bool:
)
+def render_output_caption():
+ caption = ""
+
+ run_time = st.session_state.get(StateKeys.run_time, 0)
+ if run_time:
+ caption += f'Generated in {run_time:.2f}s'
+
+ if seed := st.session_state.get("seed"):
+ caption += f' with seed {seed} '
+
+ created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today())
+ if created_at:
+ if isinstance(created_at, str):
+ created_at = datetime.datetime.fromisoformat(created_at)
+ format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT)
+ caption += f' at {format_created_at}'
+
+ st.caption(caption, unsafe_allow_html=True)
+
+
def get_example_request_body(
request_model: typing.Type[BaseModel],
state: dict,
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index deda81500..512fb3834 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -162,6 +162,8 @@
es_formats.DATETIME_FORMAT = DATETIME_FORMAT
+SHORT_DATETIME_FORMAT = "%b %d, %Y %-I:%M %p"  # NOTE(review): %-I is a glibc/BSD strftime extension — raises on Windows; name also shadows Django's SHORT_DATETIME_FORMAT (Django date-format syntax, not strftime) — confirm nothing feeds this to Django's formatter
+
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py
index 7e435b725..6322cf629 100644
--- a/gooey_ui/components/__init__.py
+++ b/gooey_ui/components/__init__.py
@@ -217,6 +217,7 @@ def image(
caption: str = None,
alt: str = None,
href: str = None,
+ show_download_button: bool = False,
**props,
):
if isinstance(src, np.ndarray):
@@ -241,9 +242,18 @@ def image(
**props,
),
).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
-def video(src: str, caption: str = None, autoplay: bool = False):
+def video(
+ src: str,
+ caption: str = None,
+ autoplay: bool = False,
+ show_download_button: bool = False,
+):
autoplay_props = {}
if autoplay:
autoplay_props = {
@@ -266,15 +276,23 @@ def video(src: str, caption: str = None, autoplay: bool = False):
name="video",
props=dict(src=src, caption=dedent(caption), **autoplay_props),
).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
-def audio(src: str, caption: str = None):
+def audio(src: str, caption: str = None, show_download_button: bool = False):
if not src:
return
state.RenderTreeNode(
name="audio",
props=dict(src=src, caption=dedent(caption)),
).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
def text_area(
@@ -415,8 +433,9 @@ def selectbox(
return value
-def button(
+def download_button(
label: str,
+ url: str,
key: str = None,
help: str = None,
*,
@@ -424,43 +443,26 @@ def button(
disabled: bool = False,
**props,
) -> bool:
- """
- Example:
- st.button("Primary", key="test0", type="primary")
- st.button("Secondary", key="test1")
- st.button("Tertiary", key="test3", type="tertiary")
- st.button("Link Button", key="test3", type="link")
- """
- if not key:
- key = md5_values("button", label, help, type, props)
- className = f"btn-{type} " + props.pop("className", "")
- state.RenderTreeNode(
- name="gui-button",
- props=dict(
- type="submit",
- value="yes",
- name=key,
- label=dedent(label),
- help=help,
- disabled=disabled,
- className=className,
- **props,
- ),
- ).mount()
- return bool(state.session_state.pop(key, False))
-
-
-form_submit_button = button
+ return button(
+ component="download-button",
+ url=url,
+ label=label,
+ key=key,
+ help=help,
+ type=type,
+ disabled=disabled,
+ **props,
+ )
-def download_button(
+def button(
label: str,
- url: str,
key: str = None,
help: str = None,
*,
type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
disabled: bool = False,
+ component: typing.Literal["download-button", "gui-button"] = "gui-button",
**props,
) -> bool:
"""
@@ -474,11 +476,10 @@ def download_button(
key = md5_values("button", label, help, type, props)
className = f"btn-{type} " + props.pop("className", "")
state.RenderTreeNode(
- name="download-button",
+ name=component,
props=dict(
type="submit",
value="yes",
- url=url,
name=key,
label=dedent(label),
help=help,
@@ -490,6 +491,9 @@ def download_button(
return bool(state.session_state.pop(key, False))
+form_submit_button = button
+
+
def expander(label: str, *, expanded: bool = False, **props):
node = state.RenderTreeNode(
name="expander",
diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py
index 47347bb65..dc5ea1ae2 100644
--- a/recipes/CompareText2Img.py
+++ b/recipes/CompareText2Img.py
@@ -246,8 +246,9 @@ def _render_outputs(self, state):
for key in selected_models:
output_images: dict = state.get("output_images", {}).get(key, [])
for img in output_images:
- st.image(img, caption=Text2ImgModels[key].value)
- self.render_buttons(img)
+ st.image(
+ img, caption=Text2ImgModels[key].value, show_download_button=True
+ )
def preview_description(self, state: dict) -> str:
return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc."
diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py
index 4aad4ffb2..13f62de91 100644
--- a/recipes/CompareUpscaler.py
+++ b/recipes/CompareUpscaler.py
@@ -107,8 +107,7 @@ def _render_outputs(self, state):
img: dict = state.get("output_images", {}).get(key)
if not img:
continue
- st.image(img, caption=UpscalerModels[key].value)
- self.render_buttons(img)
+ st.image(img, caption=UpscalerModels[key].value, show_download_button=True)
def get_raw_price(self, state: dict) -> int:
selected_models = state.get("selected_models", [])
diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py
index d971e708c..35a71a12c 100644
--- a/recipes/DeforumSD.py
+++ b/recipes/DeforumSD.py
@@ -424,8 +424,7 @@ def render_output(self):
output_video = st.session_state.get("output_video")
if output_video:
st.write("#### Output Video")
- st.video(output_video, autoplay=True)
- self.render_buttons(output_video)
+ st.video(output_video, autoplay=True, show_download_button=True)
def estimate_run_duration(self):
# in seconds
diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py
index f6a550d48..6d787bba6 100644
--- a/recipes/FaceInpainting.py
+++ b/recipes/FaceInpainting.py
@@ -198,8 +198,7 @@ def render_output(self):
if output_images:
st.write("#### Output Image")
for url in output_images:
- st.image(url)
- self.render_buttons(url)
+ st.image(url, show_download_button=True)
else:
st.div()
diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py
index fd5d2db80..278128a37 100644
--- a/recipes/GoogleImageGen.py
+++ b/recipes/GoogleImageGen.py
@@ -206,8 +206,7 @@ def render_output(self):
out_imgs = st.session_state.get("output_images")
if out_imgs:
for img in out_imgs:
- st.image(img, caption="#### Generated Image")
- self.render_buttons(img)
+ st.image(img, caption="#### Generated Image", show_download_button=True)
else:
st.div()
diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py
index 39713cb2b..c247fdeda 100644
--- a/recipes/ImageSegmentation.py
+++ b/recipes/ImageSegmentation.py
@@ -342,8 +342,7 @@ def render_example(self, state: dict):
with col1:
input_image = state.get("input_image")
if input_image:
- st.image(input_image, caption="Input Photo")
- self.render_buttons(input_image)
+ st.image(input_image, caption="Input Photo", show_download_button=True)
else:
st.div()
diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py
index 7334ebb31..e2713aaa2 100644
--- a/recipes/Img2Img.py
+++ b/recipes/Img2Img.py
@@ -127,12 +127,12 @@ def render_usage_guide(self):
youtube_video("narcZNyuNAg")
def render_output(self):
- text_prompt = st.session_state.get("text_prompt", "")
output_images = st.session_state.get("output_images", [])
+ if not output_images:
+ return
st.write("#### Output Image")
for img in output_images:
- st.image(img)
- self.render_buttons(img)
+ st.image(img, show_download_button=True)
def render_example(self, state: dict):
col1, col2 = st.columns(2)
diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py
index 8ea341aad..c2ac64369 100644
--- a/recipes/Lipsync.py
+++ b/recipes/Lipsync.py
@@ -85,8 +85,7 @@ def render_example(self, state: dict):
output_video = state.get("output_video")
if output_video:
st.write("#### Output Video")
- st.video(output_video, autoplay=True)
- self.render_buttons(output_video)
+ st.video(output_video, autoplay=True, show_download_button=True)
else:
st.div()
diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py
index d463d7483..1fc0f6c64 100644
--- a/recipes/LipsyncTTS.py
+++ b/recipes/LipsyncTTS.py
@@ -129,8 +129,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
def render_example(self, state: dict):
output_video = state.get("output_video")
if output_video:
- st.video(output_video, caption="#### Output Video", autoplay=True)
- self.render_buttons(output_video)
+ st.video(
+ output_video,
+ caption="#### Output Video",
+ autoplay=True,
+ show_download_button=True,
+ )
else:
st.div()
diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py
index c2c3b29c4..a1e0c2449 100644
--- a/recipes/ObjectInpainting.py
+++ b/recipes/ObjectInpainting.py
@@ -201,8 +201,7 @@ def render_output(self):
if output_images:
for url in output_images:
- st.image(url, caption=f"{text_prompt}")
- self.render_buttons(url)
+ st.image(url, caption=f"{text_prompt}", show_download_button=True)
else:
st.div()
diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py
index 1fe283825..1512e4523 100644
--- a/recipes/QRCodeGenerator.py
+++ b/recipes/QRCodeGenerator.py
@@ -436,7 +436,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None):
if max_count:
output_images = output_images[:max_count]
for img in output_images:
- st.image(img)
+ st.image(img, show_download_button=True)
qr_code_data = (
state.get(QrSources.qr_code_data.name)
or state.get(QrSources.qr_code_input_image.name)
@@ -458,7 +458,6 @@ def _render_outputs(self, state: dict, max_count: int | None = None):
st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})")
else:
st.caption(f"{shortened_url} → {qr_code_data}")
- self.render_buttons(img)
def run(self, state: dict) -> typing.Iterator[str | None]:
request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state)
diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py
index 4d68b8eb5..f7ddf5aab 100644
--- a/recipes/Text2Audio.py
+++ b/recipes/Text2Audio.py
@@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
)
def render_output(self):
- _render_output(self, st.session_state)
+ _render_output(st.session_state)
def render_example(self, state: dict):
col1, col2 = st.columns(2)
@@ -141,10 +141,11 @@ def preview_description(self, state: dict) -> str:
return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)."
-def _render_output(self, state):
+def _render_output(state):
selected_models = state.get("selected_models", [])
for key in selected_models:
output: dict = state.get("output_audios", {}).get(key, [])
for audio in output:
- st.audio(audio, caption=Text2AudioModels[key].value)
- self.render_buttons(audio)
+ st.audio(
+ audio, caption=Text2AudioModels[key].value, show_download_button=True
+ )
diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py
index b90bb8a4b..7494c992d 100644
--- a/recipes/TextToSpeech.py
+++ b/recipes/TextToSpeech.py
@@ -131,13 +131,8 @@ def render_usage_guide(self):
# loom_video("2d853b7442874b9cbbf3f27b98594add")
def render_output(self):
- text_prompt = st.session_state.get("text_prompt", "")
audio_url = st.session_state.get("audio_url")
- if audio_url:
- st.audio(audio_url)
- self.render_buttons(audio_url)
- else:
- st.div()
+ st.audio(audio_url, show_download_button=True)
def _get_elevenlabs_price(self, state: dict):
_, is_user_provided_key = self._get_elevenlabs_api_key(state)