From f3092ac48ec79585cc9cba6b3596ce9d1acdfe9f Mon Sep 17 00:00:00 2001 From: "Dixing (Dex) Xu" Date: Wed, 27 Nov 2024 22:13:08 +0800 Subject: [PATCH] :sparkles: Add validation plot and score for webui (#28) * :sparkles: Add validation plot and score for webui * Add validation plot * Add validation score * Update style.css * :art: Put the best validation score under the tab * :rotating_light: update lint and example text --- aide/webui/app.py | 127 ++++++++++++++++++++++++++++++++++++++----- aide/webui/style.css | 2 +- 2 files changed, 113 insertions(+), 16 deletions(-) diff --git a/aide/webui/app.py b/aide/webui/app.py index e43d465..68f9d73 100644 --- a/aide/webui/app.py +++ b/aide/webui/app.py @@ -158,24 +158,35 @@ def handle_file_upload(self): Returns: list: List of uploaded or example files. """ - if st.button( - "Load Example Experiment", type="primary", use_container_width=True - ): - st.session_state.example_files = self.load_example_files() + # Only show file uploader if no example files are loaded + if not st.session_state.get("example_files"): + uploaded_files = st.file_uploader( + "Upload Data Files", + accept_multiple_files=True, + type=["csv", "txt", "json", "md"], + label_visibility="collapsed", + ) + + if uploaded_files: + st.session_state.pop( + "example_files", None + ) # Remove example files if any + return uploaded_files + + # Only show example button if no files are uploaded + if st.button( + "Load Example Experiment", type="primary", use_container_width=True + ): + st.session_state.example_files = self.load_example_files() if st.session_state.get("example_files"): st.info("Example files loaded! Click 'Run AIDE' to proceed.") with st.expander("View Loaded Files", expanded=False): for file in st.session_state.example_files: st.text(f"📄 {file['name']}") - uploaded_files = st.session_state.example_files - else: - uploaded_files = st.file_uploader( - "Upload Data Files", - accept_multiple_files=True, - type=["csv", "txt", "json", "md"], - ) - return uploaded_files + return st.session_state.example_files + + return [] # Return empty list if no files are uploaded or loaded def handle_user_inputs(self): """ @@ -187,12 +198,12 @@ def handle_user_inputs(self): goal_text = st.text_area( "Goal", value=st.session_state.get("goal", ""), - placeholder="Example: Predict house prices", + placeholder="Example: Predict the sales price for each house", ) eval_text = st.text_area( "Evaluation Criteria", value=st.session_state.get("eval", ""), - placeholder="Example: Use RMSE metric", + placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.", ) num_steps = st.slider( "Number of Steps", @@ -450,7 +461,16 @@ def render_results_section(self): st.header("Results") if st.session_state.get("results"): results = st.session_state.results - tabs = st.tabs(["Tree Visualization", "Best Solution", "Config", "Journal"]) + + tabs = st.tabs( + [ + "Tree Visualization", + "Best Solution", + "Config", + "Journal", + "Validation Plot", + ] + ) with tabs[0]: self.render_tree_visualization(results) @@ -460,6 +480,12 @@ def render_results_section(self): self.render_config(results) with tabs[3]: self.render_journal(results) + with tabs[4]: + # Display best score before the plot + best_metric = self.get_best_metric(results) + if best_metric is not None: + st.metric("Best Validation Score", f"{best_metric:.4f}") + self.render_validation_plot(results) else: st.info("No results to display. Please run an experiment.") @@ -529,6 +555,77 @@ def render_journal(results): else: st.info("No journal available.") + @staticmethod + def get_best_metric(results): + """ + Extract the best validation metric from results. + """ + try: + journal_data = json.loads(results["journal"]) + metrics = [] + for node in journal_data: + if node["metric"] is not None: + try: + # Convert string metric to float + metric_value = float(node["metric"]) + metrics.append(metric_value) + except (ValueError, TypeError): + continue + return max(metrics) if metrics else None + except (json.JSONDecodeError, KeyError): + return None + + @staticmethod + def render_validation_plot(results): + """ + Render the validation score plot. + """ + try: + journal_data = json.loads(results["journal"]) + steps = [] + metrics = [] + + for node in journal_data: + if node["metric"] is not None and node["metric"].lower() != "none": + try: + metric_value = float(node["metric"]) + steps.append(node["step"]) + metrics.append(metric_value) + except (ValueError, TypeError): + continue + + if metrics: + import plotly.graph_objects as go + + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=steps, + y=metrics, + mode="lines+markers", + name="Validation Score", + line=dict(color="#F04370"), + marker=dict(color="#F04370"), + ) + ) + + fig.update_layout( + title="Validation Score Progress", + xaxis_title="Step", + yaxis_title="Validation Score", + template="plotly_white", + hovermode="x unified", + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", + ) + + st.plotly_chart(fig, use_container_width=True) + else: + st.info("No validation metrics available to plot.") + + except (json.JSONDecodeError, KeyError): + st.error("Could not parse validation metrics data.") + if __name__ == "__main__": app = WebUI() diff --git a/aide/webui/style.css b/aide/webui/style.css index 0219363..d364651 100644 --- a/aide/webui/style.css +++ b/aide/webui/style.css @@ -1,7 +1,7 @@ /* Main colors */ :root { --background: #F2F0E7; - --background-shaded: #EBE8DD; + --background-shaded: #FFFFFF; --card: #FFFFFF; --primary: #0D0F18; --accent: #F04370;