Skip to content

Commit

Permalink
Refactor run change detection and submission logic
Browse files Browse the repository at this point in the history
- Always fetch the run_id & uid from the current run instead of query params when subscribing to realtime channel
- Rename `_has_current_run_changed` ->  `_has_request_changed` + add caching
- Handle state reload after redirect to same url in react code
- Remove un-used parameter `redirect_to`
- Rename `submit` -> `submit_and_redirect`
- Render header after inputs so that changes can be detected
- Don't save published run if saved run creation fails, close dialog and show error instead
- Remove checks for published_run being None since we have published_run always
- Show published run save buttons immediately after making a change
- Avoid init-ing state.variables = {} un-necessarily
  • Loading branch information
devxpy committed Aug 29, 2024
1 parent 320f026 commit 9b7f636
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 72 deletions.
127 changes: 57 additions & 70 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import hashlib
import html
import inspect
import json
Expand Down Expand Up @@ -139,8 +140,7 @@ class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
title="🧩 Developer Tools and Functions",
)
variables: dict[str, typing.Any] = Field(
None,
variables: dict[str, typing.Any] | None = Field(
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)
Expand Down Expand Up @@ -305,12 +305,8 @@ def sentry_event_set_user(self, event, hint):
return event

def refresh_state(self):
example_id, run_id, uid = extract_query_params(gui.get_query_params())
if not run_id:
sr = self.get_sr_from_query_params(example_id, run_id, uid)
run_id, uid = sr.run_id, sr.uid

channel = self.realtime_channel_name(run_id, uid)
sr = self.get_current_sr()
channel = self.realtime_channel_name(sr.run_id, sr.uid)
output = gui.realtime_pull([channel])[0]
if output:
gui.session_state.update(output)
Expand All @@ -330,7 +326,7 @@ def render(self):
self.render_report_form()
return

self._render_header()
header_placeholder = gui.div()
gui.newline()

with gui.nav_tabs():
Expand All @@ -341,14 +337,19 @@ def render(self):
with gui.nav_tab_content():
self.render_selected_tab()

with header_placeholder:
self._render_header()

def _render_header(self):
current_run = self.get_current_sr()
published_run = self.get_current_published_run()
is_example = published_run and published_run.saved_run == current_run
is_example = published_run.saved_run == current_run
is_root_example = is_example and published_run.is_root()
tbreadcrumbs = get_title_breadcrumbs(
self, current_run, published_run, tab=self.tab
)
can_edit = self.can_user_edit_run(current_run, published_run)
request_changed = self._has_request_changed()

with gui.div(className="d-flex justify-content-between mt-4"):
with gui.div(className="d-lg-flex d-block align-items-center"):
Expand All @@ -365,45 +366,36 @@ def _render_header(self):
)

if is_example:
assert published_run
author = published_run.created_by
else:
author = self.run_user or current_run.get_creator()
if not is_root_example:
self.render_author(author)

with gui.div(className="d-flex align-items-center"):
can_user_edit_run = self.can_user_edit_run(current_run, published_run)
has_unpublished_changes = (
published_run
and published_run.saved_run != current_run
and self.request
and self.request.user
)

if (
can_user_edit_run and has_unpublished_changes
) or self._has_current_run_changed(current_run):
if request_changed or (
can_edit and published_run.saved_run != current_run
):
self._render_unpublished_changes_indicator()

with gui.div(className="d-flex align-items-start right-action-icons"):
gui.html(
"""
<style>
.right-action-icons .btn {
padding: 6px;
}
</style>
"""
<style>
.right-action-icons .btn {
padding: 6px;
}
</style>
"""
)

if published_run and can_user_edit_run:
self._render_published_run_buttons(
show_save_buttons = request_changed or can_edit
if show_save_buttons:
self._render_published_run_save_buttons(
current_run=current_run,
published_run=published_run,
)

self._render_social_buttons(show_button_text=not can_user_edit_run)
self._render_social_buttons(show_button_text=not show_save_buttons)

if tbreadcrumbs.has_breadcrumbs() or self.run_user:
# only render title here if the above row was not empty
Expand Down Expand Up @@ -457,11 +449,10 @@ def _render_unpublished_changes_indicator(self):
gui.html("Unpublished changes")

def _render_social_buttons(self, show_button_text: bool = False):
button_text = (
'<span class="d-none d-lg-inline"> Copy Link</span>'
if show_button_text
else ""
)
if show_button_text:
button_text = '<span class="d-none d-lg-inline"> Copy Link</span>'
else:
button_text = ""

copy_to_clipboard_button(
f'<i class="fa-regular fa-link"></i>{button_text}',
Expand All @@ -470,12 +461,11 @@ def _render_social_buttons(self, show_button_text: bool = False):
className="mb-0 ms-lg-2",
)

def _render_published_run_buttons(
def _render_published_run_save_buttons(
self,
*,
current_run: SavedRun,
published_run: PublishedRun,
redirect_to: str | None = None,
):
is_update_mode = (
self.is_current_user_admin()
Expand Down Expand Up @@ -533,7 +523,6 @@ def _render_published_run_buttons(
published_run=published_run,
modal=publish_modal,
is_update_mode=is_update_mode,
redirect_to=redirect_to,
)

def _render_publish_modal(
Expand All @@ -543,10 +532,7 @@ def _render_publish_modal(
published_run: PublishedRun,
modal: gui.Modal,
is_update_mode: bool = False,
redirect_to: str | None = None,
):
is_example = published_run.saved_run == current_run

if published_run.is_root() and self.is_current_user_admin():
with gui.div(className="text-danger"):
gui.write(
Expand Down Expand Up @@ -633,10 +619,10 @@ def _render_publish_modal(
gui.error(str(e))
return

if self._has_current_run_changed(current_run):
sr = self._on_submit()
if sr:
current_run = sr
if self._has_request_changed():
current_run = self.on_submit()
if not current_run:
modal.close()

if is_update_mode:
updates = dict(
Expand All @@ -660,13 +646,7 @@ def _render_publish_modal(
notes=published_run_notes.strip(),
visibility=published_run_visibility,
)

if redirect_to:
raise gui.RedirectException(redirect_to)
elif is_example:
modal.close() # implicit gui.rerun to reload the updated run
else:
raise gui.RedirectException(published_run.get_app_url())
raise gui.RedirectException(published_run.get_app_url())

def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
Expand Down Expand Up @@ -696,15 +676,23 @@ def _has_published_run_changed(
or published_run.saved_run != saved_run
)

def _has_current_run_changed(self, sr: SavedRun) -> bool:
"""are there unsaved changes that haven't been run?"""
def _has_request_changed(self) -> bool:
if gui.session_state.get("--has-request-changed"):
return True

try:
extracted_state = self.RequestModel.parse_obj(gui.session_state)
extracted_sr = self.RequestModel.parse_obj(sr.to_dict())
return extracted_sr != extracted_state
except ValidationError as e:
# don't want page to be inaccessible if ever validation fails - log and continue
sentry_sdk.capture_exception(e)
curr_req = self.RequestModel.parse_obj(gui.session_state)
except ValidationError:
# if the request model fails to parse, the request has likely changed
return True

curr_hash = hashlib.md5(curr_req.json(sort_keys=True).encode()).hexdigest()
prev_hash = gui.session_state.setdefault("--prev-request-hash", curr_hash)

if curr_hash != prev_hash:
gui.session_state["--has-request-changed"] = True # cache it for next time
return True
else:
return False

def _render_options_modal(
Expand Down Expand Up @@ -1542,7 +1530,7 @@ def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = Fals
submitted = True

if submitted or self.should_submit_after_login():
self.on_submit()
self.submit_and_redirect()

run_state = self.get_run_state(gui.session_state)
match run_state:
Expand Down Expand Up @@ -1604,12 +1592,13 @@ def render_extra_waiting_output(self):
def estimate_run_duration(self) -> int | None:
pass

def on_submit(self):
sr = self._on_submit()
if sr:
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
def submit_and_redirect(self):
sr = self.on_submit()
if not sr:
return
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

def _on_submit(self):
def on_submit(self):
try:
sr = self.create_new_run(enable_rate_limits=True)
except ValidationError as e:
Expand All @@ -1620,9 +1609,7 @@ def _on_submit(self):
gui.session_state[StateKeys.run_status] = None
gui.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return

self.call_runner_task(sr)

return sr

def should_submit_after_login(self) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion daras_ai_v2/prompt_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def render_title_desc():
- set(gui.session_state.keys()) # dont show other session state variables
)

gui.session_state[key] = new_vars = {}
new_vars = {}
if all_var_names:
gui.session_state[key] = new_vars
title_shown = False
for name in sorted(all_var_names):
var_key = f"--{key}:{name}"
Expand Down
2 changes: 1 addition & 1 deletion recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def on_send(
gui.session_state["input_images"] = new_input_images or None
gui.session_state["input_documents"] = new_input_documents or None

self.on_submit()
self.submit_and_redirect()

def render_steps(self):
if gui.session_state.get("tts_provider"):
Expand Down

0 comments on commit 9b7f636

Please sign in to comment.