Skip to content

Commit

Permalink
run before save - if unsaved changes in current run
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko authored and devxpy committed Aug 28, 2024
1 parent d9e602a commit afed2b1
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,11 @@ def sentry_event_set_user(self, event, hint):
return event

def refresh_state(self):
_, run_id, uid = extract_query_params(gui.get_query_params())
example_id, run_id, uid = extract_query_params(gui.get_query_params())
if not run_id:
sr = self.get_sr_from_query_params(example_id, run_id, uid)
run_id, uid = sr.run_id, sr.uid

channel = self.realtime_channel_name(run_id, uid)
output = gui.realtime_pull([channel])[0]
if output:
Expand Down Expand Up @@ -377,7 +381,9 @@ def _render_header(self):
and self.request.user
)

if can_user_edit_run and has_unpublished_changes:
if (
can_user_edit_run and has_unpublished_changes
) or self._has_current_run_changed(current_run):
self._render_unpublished_changes_indicator()

with gui.div(className="d-flex align-items-start right-action-icons"):
Expand Down Expand Up @@ -539,6 +545,8 @@ def _render_publish_modal(
is_update_mode: bool = False,
redirect_to: str | None = None,
):
is_example = published_run.saved_run == current_run

if published_run.is_root() and self.is_current_user_admin():
with gui.div(className="text-danger"):
gui.write(
Expand Down Expand Up @@ -647,7 +655,13 @@ def _render_publish_modal(
notes=published_run_notes.strip(),
visibility=published_run_visibility,
)
raise gui.RedirectException(redirect_to or published_run.get_app_url())

if redirect_to:
raise gui.RedirectException(redirect_to)
elif is_example:
modal.close() # implicit gui.rerun to reload the updated run
else:
raise gui.RedirectException(published_run.get_app_url())

def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
Expand Down Expand Up @@ -677,6 +691,17 @@ def _has_published_run_changed(
or published_run.saved_run != saved_run
)

def _has_current_run_changed(self, sr: SavedRun) -> bool:
"""are there unsaved changes that haven't been run?"""
try:
extracted_state = self.RequestModel.parse_obj(gui.session_state)
extracted_sr = self.RequestModel.parse_obj(sr.to_dict())
return extracted_sr != extracted_state
except ValidationError as e:
# don't want page to be inaccessible if ever validation fails - log and continue
sentry_sdk.capture_exception(e)
return False

def _render_options_modal(
self,
*,
Expand Down Expand Up @@ -1575,6 +1600,11 @@ def estimate_run_duration(self) -> int | None:
pass

def on_submit(self):
sr = self._on_submit()
if sr:
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

def _on_submit(self):
try:
sr = self.create_new_run(enable_rate_limits=True)
except ValidationError as e:
Expand All @@ -1588,7 +1618,7 @@ def on_submit(self):

self.call_runner_task(sr)

raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
return sr

def should_submit_after_login(self) -> bool:
return (
Expand Down

0 comments on commit afed2b1

Please sign in to comment.