diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 9e1c679fb..0f1871cca 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -305,7 +305,11 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - _, run_id, uid = extract_query_params(gui.get_query_params()) + example_id, run_id, uid = extract_query_params(gui.get_query_params()) + if not run_id: + sr = self.get_sr_from_query_params(example_id, run_id, uid) + run_id, uid = sr.run_id, sr.uid + channel = self.realtime_channel_name(run_id, uid) output = gui.realtime_pull([channel])[0] if output: @@ -377,7 +381,9 @@ def _render_header(self): and self.request.user ) - if can_user_edit_run and has_unpublished_changes: + if ( + can_user_edit_run and has_unpublished_changes + ) or self._has_current_run_changed(current_run): self._render_unpublished_changes_indicator() with gui.div(className="d-flex align-items-start right-action-icons"): @@ -539,6 +545,8 @@ def _render_publish_modal( is_update_mode: bool = False, redirect_to: str | None = None, ): + is_example = published_run.saved_run == current_run + if published_run.is_root() and self.is_current_user_admin(): with gui.div(className="text-danger"): gui.write( @@ -647,7 +655,13 @@ def _render_publish_modal( notes=published_run_notes.strip(), visibility=published_run_visibility, ) - raise gui.RedirectException(redirect_to or published_run.get_app_url()) + + if redirect_to: + raise gui.RedirectException(redirect_to) + elif is_example: + modal.close() # implicit gui.rerun to reload the updated run + else: + raise gui.RedirectException(published_run.get_app_url()) def _validate_published_run_title(self, title: str): if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: @@ -677,6 +691,17 @@ def _has_published_run_changed( or published_run.saved_run != saved_run ) + def _has_current_run_changed(self, sr: SavedRun) -> bool: + """are there unsaved changes that haven't been run?""" + try: + extracted_state = self.RequestModel.parse_obj(gui.session_state) + extracted_sr = self.RequestModel.parse_obj(sr.to_dict()) + return extracted_sr != extracted_state + except ValidationError as e: + # don't want page to be inaccessible if ever validation fails - log and continue + sentry_sdk.capture_exception(e) + return False + def _render_options_modal( self, *, @@ -1575,6 +1600,11 @@ def estimate_run_duration(self) -> int | None: pass def on_submit(self): + sr = self._on_submit() + if sr: + raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + + def _on_submit(self): try: sr = self.create_new_run(enable_rate_limits=True) except ValidationError as e: @@ -1588,7 +1618,7 @@ def on_submit(self): self.call_runner_task(sr) - raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + return sr def should_submit_after_login(self) -> bool: return (