Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run before save - if unsaved changes in current run #445

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 60 additions & 40 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import hashlib
import html
import inspect
import json
Expand Down Expand Up @@ -139,8 +140,7 @@ class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
title="🧩 Developer Tools and Functions",
)
variables: dict[str, typing.Any] = Field(
None,
variables: dict[str, typing.Any] | None = Field(
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)
Expand Down Expand Up @@ -305,8 +305,8 @@ def sentry_event_set_user(self, event, hint):
return event

def refresh_state(self):
_, run_id, uid = extract_query_params(gui.get_query_params())
channel = self.realtime_channel_name(run_id, uid)
sr = self.get_current_sr()
channel = self.realtime_channel_name(sr.run_id, sr.uid)
output = gui.realtime_pull([channel])[0]
if output:
gui.session_state.update(output)
Expand All @@ -326,7 +326,7 @@ def render(self):
self.render_report_form()
return

self._render_header()
header_placeholder = gui.div()
gui.newline()

with gui.nav_tabs():
Expand All @@ -337,14 +337,19 @@ def render(self):
with gui.nav_tab_content():
self.render_selected_tab()

with header_placeholder:
self._render_header()

def _render_header(self):
current_run = self.get_current_sr()
published_run = self.get_current_published_run()
is_example = published_run and published_run.saved_run == current_run
is_example = published_run.saved_run == current_run
is_root_example = is_example and published_run.is_root()
tbreadcrumbs = get_title_breadcrumbs(
self, current_run, published_run, tab=self.tab
)
can_edit = self.can_user_edit_run(current_run, published_run)
request_changed = self._has_request_changed()

with gui.div(className="d-flex justify-content-between mt-4"):
with gui.div(className="d-lg-flex d-block align-items-center"):
Expand All @@ -361,43 +366,34 @@ def _render_header(self):
)

if is_example:
assert published_run
author = published_run.created_by
else:
author = self.run_user or current_run.get_creator()
if not is_root_example:
self.render_author(author)

with gui.div(className="d-flex align-items-center"):
can_user_edit_run = self.can_user_edit_run(current_run, published_run)
has_unpublished_changes = (
published_run
and published_run.saved_run != current_run
and self.request
and self.request.user
)

if can_user_edit_run and has_unpublished_changes:
if request_changed or (can_edit and not is_example):
self._render_unpublished_changes_indicator()

with gui.div(className="d-flex align-items-start right-action-icons"):
gui.html(
"""
<style>
.right-action-icons .btn {
padding: 6px;
}
</style>
"""
<style>
.right-action-icons .btn {
padding: 6px;
}
</style>
"""
)

if published_run and can_user_edit_run:
self._render_published_run_buttons(
show_save_buttons = request_changed or can_edit
if show_save_buttons:
self._render_published_run_save_buttons(
current_run=current_run,
published_run=published_run,
)

self._render_social_buttons(show_button_text=not can_user_edit_run)
self._render_social_buttons(show_button_text=not show_save_buttons)

if tbreadcrumbs.has_breadcrumbs() or self.run_user:
# only render title here if the above row was not empty
Expand Down Expand Up @@ -451,11 +447,10 @@ def _render_unpublished_changes_indicator(self):
gui.html("Unpublished changes")

def _render_social_buttons(self, show_button_text: bool = False):
button_text = (
'<span class="d-none d-lg-inline"> Copy Link</span>'
if show_button_text
else ""
)
if show_button_text:
button_text = '<span class="d-none d-lg-inline"> Copy Link</span>'
else:
button_text = ""

copy_to_clipboard_button(
f'<i class="fa-regular fa-link"></i>{button_text}',
Expand All @@ -464,12 +459,11 @@ def _render_social_buttons(self, show_button_text: bool = False):
className="mb-0 ms-lg-2",
)

def _render_published_run_buttons(
def _render_published_run_save_buttons(
self,
*,
current_run: SavedRun,
published_run: PublishedRun,
redirect_to: str | None = None,
):
is_update_mode = (
self.is_current_user_admin()
Expand Down Expand Up @@ -527,7 +521,6 @@ def _render_published_run_buttons(
published_run=published_run,
modal=publish_modal,
is_update_mode=is_update_mode,
redirect_to=redirect_to,
)

def _render_publish_modal(
Expand All @@ -537,7 +530,6 @@ def _render_publish_modal(
published_run: PublishedRun,
modal: gui.Modal,
is_update_mode: bool = False,
redirect_to: str | None = None,
):
if published_run.is_root() and self.is_current_user_admin():
with gui.div(className="text-danger"):
Expand Down Expand Up @@ -625,6 +617,11 @@ def _render_publish_modal(
gui.error(str(e))
return

if self._has_request_changed():
current_run = self.on_submit()
if not current_run:
modal.close()

if is_update_mode:
updates = dict(
saved_run=current_run,
Expand All @@ -647,7 +644,7 @@ def _render_publish_modal(
notes=published_run_notes.strip(),
visibility=published_run_visibility,
)
raise gui.RedirectException(redirect_to or published_run.get_app_url())
raise gui.RedirectException(published_run.get_app_url())

def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
Expand Down Expand Up @@ -677,6 +674,25 @@ def _has_published_run_changed(
or published_run.saved_run != saved_run
)

def _has_request_changed(self) -> bool:
if gui.session_state.get("--has-request-changed"):
return True

try:
curr_req = self.RequestModel.parse_obj(gui.session_state)
except ValidationError:
# if the request model fails to parse, the request has likely changed
return True

curr_hash = hashlib.md5(curr_req.json(sort_keys=True).encode()).hexdigest()
prev_hash = gui.session_state.setdefault("--prev-request-hash", curr_hash)

if curr_hash != prev_hash:
gui.session_state["--has-request-changed"] = True # cache it for next time
return True
else:
return False

def _render_options_modal(
self,
*,
Expand Down Expand Up @@ -1512,7 +1528,7 @@ def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = Fals
submitted = True

if submitted or self.should_submit_after_login():
self.on_submit()
self.submit_and_redirect()

run_state = self.get_run_state(gui.session_state)
match run_state:
Expand Down Expand Up @@ -1574,6 +1590,12 @@ def render_extra_waiting_output(self):
def estimate_run_duration(self) -> int | None:
pass

def submit_and_redirect(self):
sr = self.on_submit()
if not sr:
return
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

def on_submit(self):
try:
sr = self.create_new_run(enable_rate_limits=True)
Expand All @@ -1585,10 +1607,8 @@ def on_submit(self):
gui.session_state[StateKeys.run_status] = None
gui.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return

self.call_runner_task(sr)

raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
return sr

def should_submit_after_login(self) -> bool:
return (
Expand Down
4 changes: 3 additions & 1 deletion daras_ai_v2/prompt_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def render_title_desc():
- set(gui.session_state.keys()) # dont show other session state variables
)

gui.session_state[key] = new_vars = {}
new_vars = {}
if all_var_names:
gui.session_state[key] = new_vars
title_shown = False
for name in sorted(all_var_names):
var_key = f"--{key}:{name}"
Expand Down
2 changes: 1 addition & 1 deletion recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def on_send(
gui.session_state["input_images"] = new_input_images or None
gui.session_state["input_documents"] = new_input_documents or None

self.on_submit()
self.submit_and_redirect()

def render_steps(self):
if gui.session_state.get("tts_provider"):
Expand Down
Loading