From 64b7dfcefe19b8da6f3bb3caa84b313786a43799 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 11:57:49 +0530 Subject: [PATCH] fix report button re-usability fuckup put download button inside the audio/video/image components consistent datetime format fix half ass caption showing up when the recipe is running make download_button re-use the button code --- bots/models.py | 14 +++---- daras_ai_v2/base.py | 55 ++++++++++++------------- daras_ai_v2/settings.py | 2 + gooey_ui/components/__init__.py | 72 +++++++++++++++++---------------- recipes/CompareText2Img.py | 5 ++- recipes/CompareUpscaler.py | 3 +- recipes/DeforumSD.py | 3 +- recipes/FaceInpainting.py | 3 +- recipes/GoogleImageGen.py | 3 +- recipes/ImageSegmentation.py | 3 +- recipes/Img2Img.py | 6 +-- recipes/Lipsync.py | 3 +- recipes/LipsyncTTS.py | 8 +++- recipes/ObjectInpainting.py | 3 +- recipes/QRCodeGenerator.py | 3 +- recipes/Text2Audio.py | 9 +++-- recipes/TextToSpeech.py | 7 +--- 17 files changed, 101 insertions(+), 101 deletions(-) diff --git a/bots/models.py b/bots/models.py index 205dfb5ff..558512e8c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -700,8 +700,8 @@ def to_df_format( .replace(tzinfo=None) ) row |= { - "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"), - "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"), + "Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT), + "First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT), "A7": not convo.d7(), "A30": not convo.d30(), "R1": last_time - first_time < datetime.timedelta(days=1), @@ -926,7 +926,7 @@ def to_df_format( "Message (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Feedback": ( message.feedbacks.first().get_display_text() if message.feedbacks.first() @@ -968,7 +968,7 @@ def to_df_analysis_format( "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Analysis JSON": message.analysis_result, } rows.append(row) @@ -1153,16 +1153,16 @@ def to_df_format( "Question Sent": feedback.message.get_previous_by_created_at() .created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Answer (EN)": feedback.message.content, "Answer Sent": feedback.message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Rating": Feedback.Rating(feedback.rating).label, "Feedback (EN)": feedback.text_english, "Feedback Sent": feedback.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Question Answered": feedback.message.question_answered, } rows.append(row) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 28756f96a..0a885bf0a 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1348,11 +1348,11 @@ def _render_output_col(self, submitted: bool): # render outputs self.render_output() - if run_state != "waiting": + if run_state != RecipeRunState.running: self._render_after_output() def _render_completed_output(self): - run_time = st.session_state.get(StateKeys.run_time, 0) + pass def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1368,12 +1368,10 @@ def render_extra_waiting_output(self): if not estimated_run_time: return if created_at := st.session_state.get("created_at"): - if isinstance(created_at, datetime.datetime): - start_time = created_at - else: - start_time = datetime.datetime.fromisoformat(created_at) + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) with st.countdown_timer( - end_time=start_time + datetime.timedelta(seconds=estimated_run_time), + end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: @@ -1514,6 +1512,8 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): + self._render_report_button() + if "seed" in self.RequestModel.schema_json(): randomize = st.button( ' Regenerate', type="tertiary" @@ -1521,27 +1521,8 @@ def _render_after_output(self): if randomize: st.session_state[StateKeys.pressed_randomize] = True st.experimental_rerun() - caption = "" - caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' - if "seed" in self.RequestModel.schema_json(): - seed = st.session_state.get("seed") - caption += f' with seed {seed} ' - created_at = st.session_state.get( - StateKeys.created_at, datetime.datetime.today() - ) - if not isinstance(created_at, datetime.datetime): - created_at = datetime.datetime.fromisoformat(created_at) - format_created_at = created_at.strftime("%d %b %Y %-I:%M%p") - caption += f' at {format_created_at}' - st.caption(caption, unsafe_allow_html=True) - def render_buttons(self, url: str): - st.download_button( - label=' Download', - url=url, - type="secondary", - ) - self._render_report_button() + render_output_caption() def state_to_doc(self, state: dict): ret = { @@ -1918,6 +1899,26 @@ def is_current_user_owner(self) -> bool: ) +def render_output_caption(): + caption = "" + + run_time = st.session_state.get(StateKeys.run_time, 0) + if run_time: + caption += f'Generated in {run_time :.2f}s' + + if seed := st.session_state.get("seed"): + caption += f' with seed {seed} ' + + created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today()) + if created_at: + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT) + caption += f' at {format_created_at}' + + st.caption(caption, unsafe_allow_html=True) + + def get_example_request_body( request_model: typing.Type[BaseModel], state: dict, diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index deda81500..512fb3834 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -162,6 +162,8 @@ es_formats.DATETIME_FORMAT = DATETIME_FORMAT +SHORT_DATETIME_FORMAT = "%b %d, %Y %-I:%M %p" + # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 7e435b725..6322cf629 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -217,6 +217,7 @@ def image( caption: str = None, alt: str = None, href: str = None, + show_download_button: bool = False, **props, ): if isinstance(src, np.ndarray): @@ -241,9 +242,18 @@ def image( **props, ), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def video(src: str, caption: str = None, autoplay: bool = False): +def video( + src: str, + caption: str = None, + autoplay: bool = False, + show_download_button: bool = False, +): autoplay_props = {} if autoplay: autoplay_props = { @@ -266,15 +276,23 @@ def video(src: str, caption: str = None, autoplay: bool = False): name="video", props=dict(src=src, caption=dedent(caption), **autoplay_props), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def audio(src: str, caption: str = None): +def audio(src: str, caption: str = None, show_download_button: bool = False): if not src: return state.RenderTreeNode( name="audio", props=dict(src=src, caption=dedent(caption)), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) def text_area( @@ -415,8 +433,9 @@ def selectbox( return value -def button( +def download_button( label: str, + url: str, key: str = None, help: str = None, *, @@ -424,43 +443,26 @@ def button( disabled: bool = False, **props, ) -> bool: - """ - Example: - st.button("Primary", key="test0", type="primary") - st.button("Secondary", key="test1") - st.button("Tertiary", key="test3", type="tertiary") - st.button("Link Button", key="test3", type="link") - """ - if not key: - key = md5_values("button", label, help, type, props) - className = f"btn-{type} " + props.pop("className", "") - state.RenderTreeNode( - name="gui-button", - props=dict( - type="submit", - value="yes", - name=key, - label=dedent(label), - help=help, - disabled=disabled, - className=className, - **props, - ), - ).mount() - return bool(state.session_state.pop(key, False)) - - -form_submit_button = button + return button( + component="download-button", + url=url, + label=label, + key=key, + help=help, + type=type, + disabled=disabled, + **props, + ) -def download_button( +def button( label: str, - url: str, key: str = None, help: str = None, *, type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", disabled: bool = False, + component: typing.Literal["download-button", "gui-button"] = "gui-button", **props, ) -> bool: """ @@ -474,11 +476,10 @@ def download_button( key = md5_values("button", label, help, type, props) className = f"btn-{type} " + props.pop("className", "") state.RenderTreeNode( - name="download-button", + name=component, props=dict( type="submit", value="yes", - url=url, name=key, label=dedent(label), help=help, @@ -490,6 +491,9 @@ def download_button( return bool(state.session_state.pop(key, False)) +form_submit_button = button + + def expander(label: str, *, expanded: bool = False, **props): node = state.RenderTreeNode( name="expander", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 47347bb65..dc5ea1ae2 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -246,8 +246,9 @@ def _render_outputs(self, state): for key in selected_models: output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: - st.image(img, caption=Text2ImgModels[key].value) - self.render_buttons(img) + st.image( + img, caption=Text2ImgModels[key].value, show_download_button=True + ) def preview_description(self, state: dict) -> str: return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc." diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 4aad4ffb2..13f62de91 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -107,8 +107,7 @@ def _render_outputs(self, state): img: dict = state.get("output_images", {}).get(key) if not img: continue - st.image(img, caption=UpscalerModels[key].value) - self.render_buttons(img) + st.image(img, caption=UpscalerModels[key].value, show_download_button=True) def get_raw_price(self, state: dict) -> int: selected_models = state.get("selected_models", []) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index d971e708c..35a71a12c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -424,8 +424,7 @@ def render_output(self): output_video = st.session_state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) def estimate_run_duration(self): # in seconds diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index f6a550d48..6d787bba6 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -198,8 +198,7 @@ def render_output(self): if output_images: st.write("#### Output Image") for url in output_images: - st.image(url) - self.render_buttons(url) + st.image(url, show_download_button=True) else: st.div() diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index fd5d2db80..278128a37 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -206,8 +206,7 @@ def render_output(self): out_imgs = st.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="#### Generated Image") - self.render_buttons(img) + st.image(img, caption="#### Generated Image", show_download_button=True) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 39713cb2b..c247fdeda 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -342,8 +342,7 @@ def render_example(self, state: dict): with col1: input_image = state.get("input_image") if input_image: - st.image(input_image, caption="Input Photo") - self.render_buttons(input_image) + st.image(input_image, caption="Input Photo", show_download_button=True) else: st.div() diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 7334ebb31..e2713aaa2 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -127,12 +127,12 @@ def render_usage_guide(self): youtube_video("narcZNyuNAg") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") output_images = st.session_state.get("output_images", []) + if not output_images: + return st.write("#### Output Image") for img in output_images: - st.image(img) - self.render_buttons(img) + st.image(img, show_download_button=True) def render_example(self, state: dict): col1, col2 = st.columns(2) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 8ea341aad..c2ac64369 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -85,8 +85,7 @@ def render_example(self, state: dict): output_video = state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) else: st.div() diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index d463d7483..1fc0f6c64 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -129,8 +129,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: def render_example(self, state: dict): output_video = state.get("output_video") if output_video: - st.video(output_video, caption="#### Output Video", autoplay=True) - self.render_buttons(output_video) + st.video( + output_video, + caption="#### Output Video", + autoplay=True, + show_download_button=True, + ) else: st.div() diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index c2c3b29c4..a1e0c2449 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -201,8 +201,7 @@ def render_output(self): if output_images: for url in output_images: - st.image(url, caption=f"{text_prompt}") - self.render_buttons(url) + st.image(url, caption=f"{text_prompt}", show_download_button=True) else: st.div() diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 1fe283825..1512e4523 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -436,7 +436,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): if max_count: output_images = output_images[:max_count] for img in output_images: - st.image(img) + st.image(img, show_download_button=True) qr_code_data = ( state.get(QrSources.qr_code_data.name) or state.get(QrSources.qr_code_input_image.name) @@ -458,7 +458,6 @@ def _render_outputs(self, state: dict, max_count: int | None = None): st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") else: st.caption(f"{shortened_url} → {qr_code_data}") - self.render_buttons(img) def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 4d68b8eb5..f7ddf5aab 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(self, st.session_state) + _render_output(st.session_state) def render_example(self, state: dict): col1, col2 = st.columns(2) @@ -141,10 +141,11 @@ def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." -def _render_output(self, state): +def _render_output(state): selected_models = state.get("selected_models", []) for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: - st.audio(audio, caption=Text2AudioModels[key].value) - self.render_buttons(audio) + st.audio( + audio, caption=Text2AudioModels[key].value, show_download_button=True + ) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index b90bb8a4b..7494c992d 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -131,13 +131,8 @@ def render_usage_guide(self): # loom_video("2d853b7442874b9cbbf3f27b98594add") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") audio_url = st.session_state.get("audio_url") - if audio_url: - st.audio(audio_url) - self.render_buttons(audio_url) - else: - st.div() + st.audio(audio_url, show_download_button=True) def _get_elevenlabs_price(self, state: dict): _, is_user_provided_key = self._get_elevenlabs_api_key(state)