Skip to content

Commit

Permalink
✨ Add validation plot and score for webui (#28)
Browse files Browse the repository at this point in the history
* ✨ Add validation plot and score for webui

* Add validation plot
* Add validation score
* Update style.css

* 🎨 Put the best validation score under the tab

* 🚨 update lint and example text
  • Loading branch information
dexhunter authored Nov 27, 2024
1 parent 5c7fa16 commit f3092ac
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 16 deletions.
127 changes: 112 additions & 15 deletions aide/webui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,35 @@ def handle_file_upload(self):
Returns:
list: List of uploaded or example files.
"""
if st.button(
"Load Example Experiment", type="primary", use_container_width=True
):
st.session_state.example_files = self.load_example_files()
# Only show file uploader if no example files are loaded
if not st.session_state.get("example_files"):
uploaded_files = st.file_uploader(
"Upload Data Files",
accept_multiple_files=True,
type=["csv", "txt", "json", "md"],
label_visibility="collapsed",
)

if uploaded_files:
st.session_state.pop(
"example_files", None
) # Remove example files if any
return uploaded_files

# Only show example button if no files are uploaded
if st.button(
"Load Example Experiment", type="primary", use_container_width=True
):
st.session_state.example_files = self.load_example_files()

if st.session_state.get("example_files"):
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
with st.expander("View Loaded Files", expanded=False):
for file in st.session_state.example_files:
st.text(f"📄 {file['name']}")
uploaded_files = st.session_state.example_files
else:
uploaded_files = st.file_uploader(
"Upload Data Files",
accept_multiple_files=True,
type=["csv", "txt", "json", "md"],
)
return uploaded_files
return st.session_state.example_files

return [] # Return empty list if no files are uploaded or loaded

def handle_user_inputs(self):
"""
Expand All @@ -187,12 +198,12 @@ def handle_user_inputs(self):
goal_text = st.text_area(
"Goal",
value=st.session_state.get("goal", ""),
placeholder="Example: Predict house prices",
placeholder="Example: Predict the sales price for each house",
)
eval_text = st.text_area(
"Evaluation Criteria",
value=st.session_state.get("eval", ""),
placeholder="Example: Use RMSE metric",
placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
)
num_steps = st.slider(
"Number of Steps",
Expand Down Expand Up @@ -450,7 +461,16 @@ def render_results_section(self):
st.header("Results")
if st.session_state.get("results"):
results = st.session_state.results
tabs = st.tabs(["Tree Visualization", "Best Solution", "Config", "Journal"])

tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

with tabs[0]:
self.render_tree_visualization(results)
Expand All @@ -460,6 +480,12 @@ def render_results_section(self):
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
# Display best score before the plot
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results)
else:
st.info("No results to display. Please run an experiment.")

Expand Down Expand Up @@ -529,6 +555,77 @@ def render_journal(results):
else:
st.info("No journal available.")

@staticmethod
def get_best_metric(results):
"""
Extract the best validation metric from results.
"""
try:
journal_data = json.loads(results["journal"])
metrics = []
for node in journal_data:
if node["metric"] is not None:
try:
# Convert string metric to float
metric_value = float(node["metric"])
metrics.append(metric_value)
except (ValueError, TypeError):
continue
return max(metrics) if metrics else None
except (json.JSONDecodeError, KeyError):
return None

@staticmethod
def render_validation_plot(results):
"""
Render the validation score plot.
"""
try:
journal_data = json.loads(results["journal"])
steps = []
metrics = []

for node in journal_data:
if node["metric"] is not None and node["metric"].lower() != "none":
try:
metric_value = float(node["metric"])
steps.append(node["step"])
metrics.append(metric_value)
except (ValueError, TypeError):
continue

if metrics:
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(
go.Scatter(
x=steps,
y=metrics,
mode="lines+markers",
name="Validation Score",
line=dict(color="#F04370"),
marker=dict(color="#F04370"),
)
)

fig.update_layout(
title="Validation Score Progress",
xaxis_title="Step",
yaxis_title="Validation Score",
template="plotly_white",
hovermode="x unified",
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)

st.plotly_chart(fig, use_container_width=True)
else:
st.info("No validation metrics available to plot.")

except (json.JSONDecodeError, KeyError):
st.error("Could not parse validation metrics data.")


if __name__ == "__main__":
app = WebUI()
Expand Down
2 changes: 1 addition & 1 deletion aide/webui/style.css
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Main colors */
:root {
--background: #F2F0E7;
--background-shaded: #EBE8DD;
--background-shaded: #FFFFFF;
--card: #FFFFFF;
--primary: #0D0F18;
--accent: #F04370;
Expand Down

0 comments on commit f3092ac

Please sign in to comment.