
Commit

Merge pull request #29 from dexhunter/refactor/webui
♻️ Refactor webui to live render results
ZhengyaoJiang authored Nov 27, 2024
2 parents f3092ac + 3689ff7 commit 113f409
Showing 1 changed file with 76 additions and 109 deletions.
185 changes: 76 additions & 109 deletions aide/webui/app.py
@@ -100,8 +100,6 @@ def run(self):
        input_col, results_col = st.columns([1, 3])
        with input_col:
            self.render_input_section(results_col)
        with results_col:
            self.render_results_section()

    def render_sidebar(self):
        """
@@ -273,17 +271,46 @@ def run_aide(self, files, goal_text, eval_text, num_steps, results_col):
                return None

            experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
            placeholders = self.create_results_placeholders(results_col, experiment)

            # Create separate placeholders for progress and config
            progress_placeholder = results_col.empty()
            config_placeholder = results_col.empty()
            results_placeholder = results_col.empty()

            for step in range(num_steps):
                st.session_state.current_step = step + 1
                progress = (step + 1) / num_steps
                self.update_results_placeholders(placeholders, progress)

                # Update progress
                with progress_placeholder.container():
                    st.markdown(
                        f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
                    )
                    st.progress(progress)

                # Show config only for first step
                if step == 0:
                    with config_placeholder.container():
                        st.markdown("### 📋 Configuration")
                        st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")

                experiment.run(steps=1)

            self.clear_run_state(placeholders)
                # Show results
                with results_placeholder.container():
                    self.render_live_results(experiment)

                # Clear config after first step
                if step == 0:
                    config_placeholder.empty()

            return self.collect_results(experiment)
            # Clear progress after all steps
            progress_placeholder.empty()

            # Update session state
            st.session_state.is_running = False
            st.session_state.results = self.collect_results(experiment)
            return st.session_state.results

        except Exception as e:
            st.session_state.is_running = False
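
The new loop above replaces the placeholder-dictionary helpers with three st.empty() slots that are overwritten on every step: the progress display, a one-time config dump, and the live results. A self-contained sketch of that overwrite-in-place pattern, with a time.sleep standing in for experiment.run(steps=1) (all names here are illustrative):

import time
import streamlit as st

def run_steps(num_steps, results_col):
    progress_placeholder = results_col.empty()
    results_placeholder = results_col.empty()

    for step in range(num_steps):
        # Re-filling the same placeholder replaces its previous content,
        # so the page updates in place instead of growing.
        with progress_placeholder.container():
            st.markdown(f"### Running Step {step + 1}/{num_steps}")
            st.progress((step + 1) / num_steps)

        time.sleep(0.5)  # stand-in for experiment.run(steps=1)

        with results_placeholder.container():
            st.write(f"Latest results after step {step + 1}")

    # Remove the progress widgets once all steps have finished.
    progress_placeholder.empty()

_, results_col = st.columns([1, 3])
run_steps(5, results_col)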
@@ -355,70 +382,6 @@ def initialize_experiment(input_dir, goal_text, eval_text):
        experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
        return experiment
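
Going only by the calls visible in this diff, the Experiment object is built once from the inputs and then advanced one step at a time so the UI can refresh between steps. A hedged sketch of driving the same objects without Streamlit (the import path, directory, and goal text are assumptions, not taken from this file):

from omegaconf import OmegaConf

from aide import Experiment  # import path assumed; the diff only shows the class name

experiment = Experiment(
    data_dir="./input",                  # placeholder data directory
    goal="Predict the target column",    # placeholder goal text
    eval="accuracy",                     # placeholder evaluation metric
)

# Same config dump the webui shows on the first step.
print(OmegaConf.to_yaml(experiment.cfg))

for step in range(5):
    experiment.run(steps=1)  # advance a single step, as the live loop does
    print(f"finished step {step + 1}")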

    @staticmethod
    def create_results_placeholders(results_col, experiment):
        """
        Create placeholders in the results column for dynamic content.
        Args:
            results_col (st.delta_generator.DeltaGenerator): The results column.
            experiment (Experiment): The Experiment object.
        Returns:
            dict: Dictionary of placeholders.
        """
        with results_col:
            status_placeholder = st.empty()
            step_placeholder = st.empty()
            config_title_placeholder = st.empty()
            config_placeholder = st.empty()
            progress_placeholder = st.empty()

        step_placeholder.markdown(
            f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
        )
        config_title_placeholder.markdown("### 📋 Configuration")
        config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
        progress_placeholder.progress(0)

        placeholders = {
            "status": status_placeholder,
            "step": step_placeholder,
            "config_title": config_title_placeholder,
            "config": config_placeholder,
            "progress": progress_placeholder,
        }
        return placeholders

    @staticmethod
    def update_results_placeholders(placeholders, progress):
        """
        Update the placeholders with the current progress.
        Args:
            placeholders (dict): Dictionary of placeholders.
            progress (float): Current progress value.
        """
        placeholders["step"].markdown(
            f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
        )
        placeholders["progress"].progress(progress)

    @staticmethod
    def clear_run_state(placeholders):
        """
        Clear the running state and placeholders after the experiment.
        Args:
            placeholders (dict): Dictionary of placeholders.
        """
        st.session_state.is_running = False
        placeholders["status"].empty()
        placeholders["step"].empty()
        placeholders["config_title"].empty()
        placeholders["config"].empty()
        placeholders["progress"].empty()

    @staticmethod
    def collect_results(experiment):
        """
@@ -454,41 +417,6 @@ def collect_results(experiment):
        }
        return results

    def render_results_section(self):
        """
        Render the results section with tabs for different outputs.
        """
        st.header("Results")
        if st.session_state.get("results"):
            results = st.session_state.results

            tabs = st.tabs(
                [
                    "Tree Visualization",
                    "Best Solution",
                    "Config",
                    "Journal",
                    "Validation Plot",
                ]
            )

            with tabs[0]:
                self.render_tree_visualization(results)
            with tabs[1]:
                self.render_best_solution(results)
            with tabs[2]:
                self.render_config(results)
            with tabs[3]:
                self.render_journal(results)
            with tabs[4]:
                # Display best score before the plot
                best_metric = self.get_best_metric(results)
                if best_metric is not None:
                    st.metric("Best Validation Score", f"{best_metric:.4f}")
                self.render_validation_plot(results)
        else:
            st.info("No results to display. Please run an experiment.")

    @staticmethod
    def render_tree_visualization(results):
        """
@@ -576,9 +504,13 @@ def get_best_metric(results):
            return None

    @staticmethod
    def render_validation_plot(results):
    def render_validation_plot(results, step):
        """
        Render the validation score plot.
        Args:
            results (dict): The results dictionary
            step (int): Current step number for unique key generation
        """
        try:
            journal_data = json.loads(results["journal"])
@@ -619,12 +551,47 @@ def render_validation_plot(results):
                    paper_bgcolor="rgba(0,0,0,0)",
                )

                st.plotly_chart(fig, use_container_width=True)
                # Only keep the key for plotly_chart
                st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}")
            else:
                st.info("No validation metrics available to plot.")
                st.info("No validation metrics available to plot")

        except (json.JSONDecodeError, KeyError):
            st.error("Could not parse validation metrics data.")
            st.error("Could not parse validation metrics data")

    def render_live_results(self, experiment):
        """
        Render live results.
        Args:
            experiment (Experiment): The Experiment object
        """
        results = self.collect_results(experiment)

        # Create tabs for different result views
        tabs = st.tabs(
            [
                "Tree Visualization",
                "Best Solution",
                "Config",
                "Journal",
                "Validation Plot",
            ]
        )

        with tabs[0]:
            self.render_tree_visualization(results)
        with tabs[1]:
            self.render_best_solution(results)
        with tabs[2]:
            self.render_config(results)
        with tabs[3]:
            self.render_journal(results)
        with tabs[4]:
            best_metric = self.get_best_metric(results)
            if best_metric is not None:
                st.metric("Best Validation Score", f"{best_metric:.4f}")
            self.render_validation_plot(results, step=st.session_state.current_step)
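
render_live_results rebuilds the whole tab group from a fresh collect_results snapshot each time it is called, so every step shows the latest journal state. A stripped-down sketch of that refresh pattern with dummy data (the snapshot helper and its keys are illustrative stand-ins for collect_results):

import time
import streamlit as st

def snapshot(step):
    # Illustrative stand-in for collect_results(experiment).
    return {"journal": f"{step} node(s) explored so far", "config": "steps: 5"}

results_placeholder = st.empty()

for step in range(1, 4):
    results = snapshot(step)
    with results_placeholder.container():
        tabs = st.tabs(["Journal", "Config"])
        with tabs[0]:
            st.text(results["journal"])
        with tabs[1]:
            st.code(results["config"], language="yaml")
    time.sleep(0.5)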


if __name__ == "__main__":
