Skip to content

Commit

Permalink
fix report button re-usability fuckup
Browse files Browse the repository at this point in the history
put download button inside the audio/video/image components
consistent datetime format
fix half ass caption showing up when the recipe is running
make download_button re-use the button code
  • Loading branch information
devxpy committed Feb 12, 2024
1 parent 6f75eca commit 64b7dfc
Show file tree
Hide file tree
Showing 17 changed files with 101 additions and 101 deletions.
14 changes: 7 additions & 7 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,8 @@ def to_df_format(
.replace(tzinfo=None)
)
row |= {
"Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"),
"First Sent": first_time.strftime("%b %d, %Y %I:%M %p"),
"Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT),
"First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT),
"A7": not convo.d7(),
"A30": not convo.d30(),
"R1": last_time - first_time < datetime.timedelta(days=1),
Expand Down Expand Up @@ -926,7 +926,7 @@ def to_df_format(
"Message (EN)": message.content,
"Sent": message.created_at.astimezone(tz)
.replace(tzinfo=None)
.strftime("%b %d, %Y %I:%M %p"),
.strftime(settings.SHORT_DATETIME_FORMAT),
"Feedback": (
message.feedbacks.first().get_display_text()
if message.feedbacks.first()
Expand Down Expand Up @@ -968,7 +968,7 @@ def to_df_analysis_format(
"Answer (EN)": message.content,
"Sent": message.created_at.astimezone(tz)
.replace(tzinfo=None)
.strftime("%b %d, %Y %I:%M %p"),
.strftime(settings.SHORT_DATETIME_FORMAT),
"Analysis JSON": message.analysis_result,
}
rows.append(row)
Expand Down Expand Up @@ -1153,16 +1153,16 @@ def to_df_format(
"Question Sent": feedback.message.get_previous_by_created_at()
.created_at.astimezone(tz)
.replace(tzinfo=None)
.strftime("%b %d, %Y %I:%M %p"),
.strftime(settings.SHORT_DATETIME_FORMAT),
"Answer (EN)": feedback.message.content,
"Answer Sent": feedback.message.created_at.astimezone(tz)
.replace(tzinfo=None)
.strftime("%b %d, %Y %I:%M %p"),
.strftime(settings.SHORT_DATETIME_FORMAT),
"Rating": Feedback.Rating(feedback.rating).label,
"Feedback (EN)": feedback.text_english,
"Feedback Sent": feedback.created_at.astimezone(tz)
.replace(tzinfo=None)
.strftime("%b %d, %Y %I:%M %p"),
.strftime(settings.SHORT_DATETIME_FORMAT),
"Question Answered": feedback.message.question_answered,
}
rows.append(row)
Expand Down
55 changes: 28 additions & 27 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,11 +1348,11 @@ def _render_output_col(self, submitted: bool):
# render outputs
self.render_output()

if run_state != "waiting":
if run_state != RecipeRunState.running:
self._render_after_output()

def _render_completed_output(self):
run_time = st.session_state.get(StateKeys.run_time, 0)
pass

def _render_failed_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
Expand All @@ -1368,12 +1368,10 @@ def render_extra_waiting_output(self):
if not estimated_run_time:
return
if created_at := st.session_state.get("created_at"):
if isinstance(created_at, datetime.datetime):
start_time = created_at
else:
start_time = datetime.datetime.fromisoformat(created_at)
if isinstance(created_at, str):
created_at = datetime.datetime.fromisoformat(created_at)
with st.countdown_timer(
end_time=start_time + datetime.timedelta(seconds=estimated_run_time),
end_time=created_at + datetime.timedelta(seconds=estimated_run_time),
delay_text="Sorry for the wait. Your run is taking longer than we expected.",
):
if self.is_current_user_owner() and self.request.user.email:
Expand Down Expand Up @@ -1514,34 +1512,17 @@ def clear_outputs(self):
st.session_state.pop(field_name, None)

def _render_after_output(self):
self._render_report_button()

if "seed" in self.RequestModel.schema_json():
randomize = st.button(
'<i class="fa-solid fa-recycle"></i> Regenerate', type="tertiary"
)
if randomize:
st.session_state[StateKeys.pressed_randomize] = True
st.experimental_rerun()
caption = ""
caption += f'\\\nGenerated in <span style="color: black;">{st.session_state.get(StateKeys.run_time, 0):.2f}s</span>'
if "seed" in self.RequestModel.schema_json():
seed = st.session_state.get("seed")
caption += f' with seed <span style="color: black;">{seed}</span> '
created_at = st.session_state.get(
StateKeys.created_at, datetime.datetime.today()
)
if not isinstance(created_at, datetime.datetime):
created_at = datetime.datetime.fromisoformat(created_at)
format_created_at = created_at.strftime("%d %b %Y %-I:%M%p")
caption += f' at <span style="color: black;">{format_created_at}</span>'
st.caption(caption, unsafe_allow_html=True)

def render_buttons(self, url: str):
st.download_button(
label='<i class="fa-regular fa-download"></i> Download',
url=url,
type="secondary",
)
self._render_report_button()
render_output_caption()

def state_to_doc(self, state: dict):
ret = {
Expand Down Expand Up @@ -1918,6 +1899,26 @@ def is_current_user_owner(self) -> bool:
)


def render_output_caption():
caption = ""

run_time = st.session_state.get(StateKeys.run_time, 0)
if run_time:
caption += f'Generated in <span style="color: black;">{run_time :.2f}s</span>'

if seed := st.session_state.get("seed"):
caption += f' with seed <span style="color: black;">{seed}</span> '

created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today())
if created_at:
if isinstance(created_at, str):
created_at = datetime.datetime.fromisoformat(created_at)
format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT)
caption += f' at <span style="color: black;">{format_created_at}</span>'

st.caption(caption, unsafe_allow_html=True)


def get_example_request_body(
request_model: typing.Type[BaseModel],
state: dict,
Expand Down
2 changes: 2 additions & 0 deletions daras_ai_v2/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@

es_formats.DATETIME_FORMAT = DATETIME_FORMAT

SHORT_DATETIME_FORMAT = "%b %d, %Y %-I:%M %p"

# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/

Expand Down
72 changes: 38 additions & 34 deletions gooey_ui/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def image(
caption: str = None,
alt: str = None,
href: str = None,
show_download_button: bool = False,
**props,
):
if isinstance(src, np.ndarray):
Expand All @@ -241,9 +242,18 @@ def image(
**props,
),
).mount()
if show_download_button:
download_button(
label='<i class="fa-regular fa-download"></i> Download', url=src
)


def video(src: str, caption: str = None, autoplay: bool = False):
def video(
src: str,
caption: str = None,
autoplay: bool = False,
show_download_button: bool = False,
):
autoplay_props = {}
if autoplay:
autoplay_props = {
Expand All @@ -266,15 +276,23 @@ def video(src: str, caption: str = None, autoplay: bool = False):
name="video",
props=dict(src=src, caption=dedent(caption), **autoplay_props),
).mount()
if show_download_button:
download_button(
label='<i class="fa-regular fa-download"></i> Download', url=src
)


def audio(src: str, caption: str = None):
def audio(src: str, caption: str = None, show_download_button: bool = False):
if not src:
return
state.RenderTreeNode(
name="audio",
props=dict(src=src, caption=dedent(caption)),
).mount()
if show_download_button:
download_button(
label='<i class="fa-regular fa-download"></i> Download', url=src
)


def text_area(
Expand Down Expand Up @@ -415,52 +433,36 @@ def selectbox(
return value


def button(
def download_button(
label: str,
url: str,
key: str = None,
help: str = None,
*,
type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
disabled: bool = False,
**props,
) -> bool:
"""
Example:
st.button("Primary", key="test0", type="primary")
st.button("Secondary", key="test1")
st.button("Tertiary", key="test3", type="tertiary")
st.button("Link Button", key="test3", type="link")
"""
if not key:
key = md5_values("button", label, help, type, props)
className = f"btn-{type} " + props.pop("className", "")
state.RenderTreeNode(
name="gui-button",
props=dict(
type="submit",
value="yes",
name=key,
label=dedent(label),
help=help,
disabled=disabled,
className=className,
**props,
),
).mount()
return bool(state.session_state.pop(key, False))


form_submit_button = button
return button(
component="download-button",
url=url,
label=label,
key=key,
help=help,
type=type,
disabled=disabled,
**props,
)


def download_button(
def button(
label: str,
url: str,
key: str = None,
help: str = None,
*,
type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
disabled: bool = False,
component: typing.Literal["download-button", "gui-button"] = "gui-button",
**props,
) -> bool:
"""
Expand All @@ -474,11 +476,10 @@ def download_button(
key = md5_values("button", label, help, type, props)
className = f"btn-{type} " + props.pop("className", "")
state.RenderTreeNode(
name="download-button",
name=component,
props=dict(
type="submit",
value="yes",
url=url,
name=key,
label=dedent(label),
help=help,
Expand All @@ -490,6 +491,9 @@ def download_button(
return bool(state.session_state.pop(key, False))


form_submit_button = button


def expander(label: str, *, expanded: bool = False, **props):
node = state.RenderTreeNode(
name="expander",
Expand Down
5 changes: 3 additions & 2 deletions recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def _render_outputs(self, state):
for key in selected_models:
output_images: dict = state.get("output_images", {}).get(key, [])
for img in output_images:
st.image(img, caption=Text2ImgModels[key].value)
self.render_buttons(img)
st.image(
img, caption=Text2ImgModels[key].value, show_download_button=True
)

def preview_description(self, state: dict) -> str:
return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc."
Expand Down
3 changes: 1 addition & 2 deletions recipes/CompareUpscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def _render_outputs(self, state):
img: dict = state.get("output_images", {}).get(key)
if not img:
continue
st.image(img, caption=UpscalerModels[key].value)
self.render_buttons(img)
st.image(img, caption=UpscalerModels[key].value, show_download_button=True)

def get_raw_price(self, state: dict) -> int:
selected_models = state.get("selected_models", [])
Expand Down
3 changes: 1 addition & 2 deletions recipes/DeforumSD.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ def render_output(self):
output_video = st.session_state.get("output_video")
if output_video:
st.write("#### Output Video")
st.video(output_video, autoplay=True)
self.render_buttons(output_video)
st.video(output_video, autoplay=True, show_download_button=True)

def estimate_run_duration(self):
# in seconds
Expand Down
3 changes: 1 addition & 2 deletions recipes/FaceInpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ def render_output(self):
if output_images:
st.write("#### Output Image")
for url in output_images:
st.image(url)
self.render_buttons(url)
st.image(url, show_download_button=True)
else:
st.div()

Expand Down
3 changes: 1 addition & 2 deletions recipes/GoogleImageGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ def render_output(self):
out_imgs = st.session_state.get("output_images")
if out_imgs:
for img in out_imgs:
st.image(img, caption="#### Generated Image")
self.render_buttons(img)
st.image(img, caption="#### Generated Image", show_download_button=True)
else:
st.div()

Expand Down
3 changes: 1 addition & 2 deletions recipes/ImageSegmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,7 @@ def render_example(self, state: dict):
with col1:
input_image = state.get("input_image")
if input_image:
st.image(input_image, caption="Input Photo")
self.render_buttons(input_image)
st.image(input_image, caption="Input Photo", show_download_button=True)
else:
st.div()

Expand Down
6 changes: 3 additions & 3 deletions recipes/Img2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def render_usage_guide(self):
youtube_video("narcZNyuNAg")

def render_output(self):
text_prompt = st.session_state.get("text_prompt", "")
output_images = st.session_state.get("output_images", [])
if not output_images:
return
st.write("#### Output Image")
for img in output_images:
st.image(img)
self.render_buttons(img)
st.image(img, show_download_button=True)

def render_example(self, state: dict):
col1, col2 = st.columns(2)
Expand Down
3 changes: 1 addition & 2 deletions recipes/Lipsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def render_example(self, state: dict):
output_video = state.get("output_video")
if output_video:
st.write("#### Output Video")
st.video(output_video, autoplay=True)
self.render_buttons(output_video)
st.video(output_video, autoplay=True, show_download_button=True)
else:
st.div()

Expand Down
8 changes: 6 additions & 2 deletions recipes/LipsyncTTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
def render_example(self, state: dict):
output_video = state.get("output_video")
if output_video:
st.video(output_video, caption="#### Output Video", autoplay=True)
self.render_buttons(output_video)
st.video(
output_video,
caption="#### Output Video",
autoplay=True,
show_download_button=True,
)
else:
st.div()

Expand Down
Loading

0 comments on commit 64b7dfc

Please sign in to comment.