diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 9e1c679fb..d3a52175d 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1,4 +1,5 @@ import datetime +import hashlib import html import inspect import json @@ -139,8 +140,7 @@ class RequestModel(BaseModel): functions: list[RecipeFunction] | None = Field( title="🧩 Developer Tools and Functions", ) - variables: dict[str, typing.Any] = Field( - None, + variables: dict[str, typing.Any] | None = Field( title="⌥ Variables", description="Variables to be used as Jinja prompt templates and in functions as arguments", ) @@ -305,8 +305,8 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - _, run_id, uid = extract_query_params(gui.get_query_params()) - channel = self.realtime_channel_name(run_id, uid) + sr = self.get_current_sr() + channel = self.realtime_channel_name(sr.run_id, sr.uid) output = gui.realtime_pull([channel])[0] if output: gui.session_state.update(output) @@ -326,7 +326,7 @@ def render(self): self.render_report_form() return - self._render_header() + header_placeholder = gui.div() gui.newline() with gui.nav_tabs(): @@ -337,14 +337,19 @@ def render(self): with gui.nav_tab_content(): self.render_selected_tab() + with header_placeholder: + self._render_header() + def _render_header(self): current_run = self.get_current_sr() published_run = self.get_current_published_run() - is_example = published_run and published_run.saved_run == current_run + is_example = published_run.saved_run == current_run is_root_example = is_example and published_run.is_root() tbreadcrumbs = get_title_breadcrumbs( self, current_run, published_run, tab=self.tab ) + can_edit = self.can_user_edit_run(current_run, published_run) + request_changed = self._has_request_changed() with gui.div(className="d-flex justify-content-between mt-4"): with gui.div(className="d-lg-flex d-block align-items-center"): @@ -361,7 +366,6 @@ def _render_header(self): ) if is_example: - assert published_run author = published_run.created_by else: author = self.run_user or current_run.get_creator() @@ -369,35 +373,27 @@ def _render_header(self): self.render_author(author) with gui.div(className="d-flex align-items-center"): - can_user_edit_run = self.can_user_edit_run(current_run, published_run) - has_unpublished_changes = ( - published_run - and published_run.saved_run != current_run - and self.request - and self.request.user - ) - - if can_user_edit_run and has_unpublished_changes: + if request_changed or (can_edit and not is_example): self._render_unpublished_changes_indicator() with gui.div(className="d-flex align-items-start right-action-icons"): gui.html( """ - - """ + + """ ) - if published_run and can_user_edit_run: - self._render_published_run_buttons( + show_save_buttons = request_changed or can_edit + if show_save_buttons: + self._render_published_run_save_buttons( current_run=current_run, published_run=published_run, ) - - self._render_social_buttons(show_button_text=not can_user_edit_run) + self._render_social_buttons(show_button_text=not show_save_buttons) if tbreadcrumbs.has_breadcrumbs() or self.run_user: # only render title here if the above row was not empty @@ -451,11 +447,10 @@ def _render_unpublished_changes_indicator(self): gui.html("Unpublished changes") def _render_social_buttons(self, show_button_text: bool = False): - button_text = ( - ' Copy Link' - if show_button_text - else "" - ) + if show_button_text: + button_text = ' Copy Link' + else: + button_text = "" copy_to_clipboard_button( f'{button_text}', @@ -464,12 +459,11 @@ def _render_social_buttons(self, show_button_text: bool = False): className="mb-0 ms-lg-2", ) - def _render_published_run_buttons( + def _render_published_run_save_buttons( self, *, current_run: SavedRun, published_run: PublishedRun, - redirect_to: str | None = None, ): is_update_mode = ( self.is_current_user_admin() @@ -527,7 +521,6 @@ def _render_published_run_buttons( published_run=published_run, modal=publish_modal, is_update_mode=is_update_mode, - redirect_to=redirect_to, ) def _render_publish_modal( @@ -537,7 +530,6 @@ def _render_publish_modal( published_run: PublishedRun, modal: gui.Modal, is_update_mode: bool = False, - redirect_to: str | None = None, ): if published_run.is_root() and self.is_current_user_admin(): with gui.div(className="text-danger"): @@ -625,6 +617,11 @@ def _render_publish_modal( gui.error(str(e)) return + if self._has_request_changed(): + current_run = self.on_submit() + if not current_run: + modal.close() + if is_update_mode: updates = dict( saved_run=current_run, @@ -647,7 +644,7 @@ def _render_publish_modal( notes=published_run_notes.strip(), visibility=published_run_visibility, ) - raise gui.RedirectException(redirect_to or published_run.get_app_url()) + raise gui.RedirectException(published_run.get_app_url()) def _validate_published_run_title(self, title: str): if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: @@ -677,6 +674,25 @@ def _has_published_run_changed( or published_run.saved_run != saved_run ) + def _has_request_changed(self) -> bool: + if gui.session_state.get("--has-request-changed"): + return True + + try: + curr_req = self.RequestModel.parse_obj(gui.session_state) + except ValidationError: + # if the request model fails to parse, the request has likely changed + return True + + curr_hash = hashlib.md5(curr_req.json(sort_keys=True).encode()).hexdigest() + prev_hash = gui.session_state.setdefault("--prev-request-hash", curr_hash) + + if curr_hash != prev_hash: + gui.session_state["--has-request-changed"] = True # cache it for next time + return True + else: + return False + def _render_options_modal( self, *, @@ -1512,7 +1528,7 @@ def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = Fals submitted = True if submitted or self.should_submit_after_login(): - self.on_submit() + self.submit_and_redirect() run_state = self.get_run_state(gui.session_state) match run_state: @@ -1574,6 +1590,12 @@ def render_extra_waiting_output(self): def estimate_run_duration(self) -> int | None: pass + def submit_and_redirect(self): + sr = self.on_submit() + if not sr: + return + raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + def on_submit(self): try: sr = self.create_new_run(enable_rate_limits=True) @@ -1585,10 +1607,8 @@ def on_submit(self): gui.session_state[StateKeys.run_status] = None gui.session_state[StateKeys.error_msg] = e.detail.get("error", "") return - self.call_runner_task(sr) - - raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + return sr def should_submit_after_login(self) -> bool: return ( diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index c315da996..b81932533 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -55,7 +55,9 @@ def render_title_desc(): - set(gui.session_state.keys()) # dont show other session state variables ) - gui.session_state[key] = new_vars = {} + new_vars = {} + if all_var_names: + gui.session_state[key] = new_vars title_shown = False for name in sorted(all_var_names): var_key = f"--{key}:{name}" diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index edc3cadf3..41e7a33a8 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -662,7 +662,7 @@ def on_send( gui.session_state["input_images"] = new_input_images or None gui.session_state["input_documents"] = new_input_documents or None - self.on_submit() + self.submit_and_redirect() def render_steps(self): if gui.session_state.get("tts_provider"):