From 3689ff7e50232439f29bc4eab7bc24914be9a7a9 Mon Sep 17 00:00:00 2001 From: Dixing Xu Date: Wed, 27 Nov 2024 22:53:07 +0800 Subject: [PATCH] :recycle: Refactor webui to live render results * live render --- aide/webui/app.py | 185 +++++++++++++++++++--------------------------- 1 file changed, 76 insertions(+), 109 deletions(-) diff --git a/aide/webui/app.py b/aide/webui/app.py index 68f9d73..0064e92 100644 --- a/aide/webui/app.py +++ b/aide/webui/app.py @@ -100,8 +100,6 @@ def run(self): input_col, results_col = st.columns([1, 3]) with input_col: self.render_input_section(results_col) - with results_col: - self.render_results_section() def render_sidebar(self): """ @@ -273,17 +271,46 @@ def run_aide(self, files, goal_text, eval_text, num_steps, results_col): return None experiment = self.initialize_experiment(input_dir, goal_text, eval_text) - placeholders = self.create_results_placeholders(results_col, experiment) + + # Create separate placeholders for progress and config + progress_placeholder = results_col.empty() + config_placeholder = results_col.empty() + results_placeholder = results_col.empty() for step in range(num_steps): st.session_state.current_step = step + 1 progress = (step + 1) / num_steps - self.update_results_placeholders(placeholders, progress) + + # Update progress + with progress_placeholder.container(): + st.markdown( + f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}" + ) + st.progress(progress) + + # Show config only for first step + if step == 0: + with config_placeholder.container(): + st.markdown("### 📋 Configuration") + st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml") + experiment.run(steps=1) - self.clear_run_state(placeholders) + # Show results + with results_placeholder.container(): + self.render_live_results(experiment) + + # Clear config after first step + if step == 0: + config_placeholder.empty() - return self.collect_results(experiment) + # Clear progress after all steps + progress_placeholder.empty() + + # Update session state + st.session_state.is_running = False + st.session_state.results = self.collect_results(experiment) + return st.session_state.results except Exception as e: st.session_state.is_running = False @@ -355,70 +382,6 @@ def initialize_experiment(input_dir, goal_text, eval_text): experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text) return experiment - @staticmethod - def create_results_placeholders(results_col, experiment): - """ - Create placeholders in the results column for dynamic content. - - Args: - results_col (st.delta_generator.DeltaGenerator): The results column. - experiment (Experiment): The Experiment object. - - Returns: - dict: Dictionary of placeholders. - """ - with results_col: - status_placeholder = st.empty() - step_placeholder = st.empty() - config_title_placeholder = st.empty() - config_placeholder = st.empty() - progress_placeholder = st.empty() - - step_placeholder.markdown( - f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}" - ) - config_title_placeholder.markdown("### 📋 Configuration") - config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml") - progress_placeholder.progress(0) - - placeholders = { - "status": status_placeholder, - "step": step_placeholder, - "config_title": config_title_placeholder, - "config": config_placeholder, - "progress": progress_placeholder, - } - return placeholders - - @staticmethod - def update_results_placeholders(placeholders, progress): - """ - Update the placeholders with the current progress. - - Args: - placeholders (dict): Dictionary of placeholders. - progress (float): Current progress value. - """ - placeholders["step"].markdown( - f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}" - ) - placeholders["progress"].progress(progress) - - @staticmethod - def clear_run_state(placeholders): - """ - Clear the running state and placeholders after the experiment. - - Args: - placeholders (dict): Dictionary of placeholders. - """ - st.session_state.is_running = False - placeholders["status"].empty() - placeholders["step"].empty() - placeholders["config_title"].empty() - placeholders["config"].empty() - placeholders["progress"].empty() - @staticmethod def collect_results(experiment): """ @@ -454,41 +417,6 @@ def collect_results(experiment): } return results - def render_results_section(self): - """ - Render the results section with tabs for different outputs. - """ - st.header("Results") - if st.session_state.get("results"): - results = st.session_state.results - - tabs = st.tabs( - [ - "Tree Visualization", - "Best Solution", - "Config", - "Journal", - "Validation Plot", - ] - ) - - with tabs[0]: - self.render_tree_visualization(results) - with tabs[1]: - self.render_best_solution(results) - with tabs[2]: - self.render_config(results) - with tabs[3]: - self.render_journal(results) - with tabs[4]: - # Display best score before the plot - best_metric = self.get_best_metric(results) - if best_metric is not None: - st.metric("Best Validation Score", f"{best_metric:.4f}") - self.render_validation_plot(results) - else: - st.info("No results to display. Please run an experiment.") - @staticmethod def render_tree_visualization(results): """ @@ -576,9 +504,13 @@ def get_best_metric(results): return None @staticmethod - def render_validation_plot(results): + def render_validation_plot(results, step): """ Render the validation score plot. + + Args: + results (dict): The results dictionary + step (int): Current step number for unique key generation """ try: journal_data = json.loads(results["journal"]) @@ -619,12 +551,47 @@ def render_validation_plot(results): paper_bgcolor="rgba(0,0,0,0)", ) - st.plotly_chart(fig, use_container_width=True) + # Only keep the key for plotly_chart + st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}") else: - st.info("No validation metrics available to plot.") + st.info("No validation metrics available to plot") except (json.JSONDecodeError, KeyError): - st.error("Could not parse validation metrics data.") + st.error("Could not parse validation metrics data") + + def render_live_results(self, experiment): + """ + Render live results. + + Args: + experiment (Experiment): The Experiment object + """ + results = self.collect_results(experiment) + + # Create tabs for different result views + tabs = st.tabs( + [ + "Tree Visualization", + "Best Solution", + "Config", + "Journal", + "Validation Plot", + ] + ) + + with tabs[0]: + self.render_tree_visualization(results) + with tabs[1]: + self.render_best_solution(results) + with tabs[2]: + self.render_config(results) + with tabs[3]: + self.render_journal(results) + with tabs[4]: + best_metric = self.get_best_metric(results) + if best_metric is not None: + st.metric("Best Validation Score", f"{best_metric:.4f}") + self.render_validation_plot(results, step=st.session_state.current_step) if __name__ == "__main__":