Skip to content

Commit 8287616

Browse files
committed
Add animation md
1 parent 53cfc2b commit 8287616

4 files changed

+102
-77
lines changed
Binary file not shown.

data/results/paper_plots.py

+101-77
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import copy
2+
import json
3+
import os
24
from typing import Iterable, Optional, Union
35

46
import pandas as pd
@@ -9,6 +11,9 @@
911
from plotly.graph_objs import Figure
1012
import plotly.graph_objs as go
1113
import plotly.express as px
14+
from tqdm import tqdm
15+
from wandb.apis.public import RunArtifacts
16+
from wandb.sdk.wandb_summary import SummarySubDict
1217

1318

1419
def plot_aggregated_lines(df: pd.DataFrame,
@@ -194,15 +199,15 @@ def plot_aggregated_lines(df: pd.DataFrame,
194199
"DAE-AUG": dict(dash='dot'),
195200
"DAE": dict(dash='dash')}
196201
}
197-
color_group = {'system.group': {"K4xC2": color_pallet[0],
198-
"K4": color_pallet[10],
199-
"C2": color_pallet[1], }
202+
color_group = {'system.group': {"K4xC2": color_pallet[1],
203+
"K4": color_pallet[7]}
200204
}
201205

202206
data_path = Path("mini_cheetah_sample_eff_uneven_easy_terrain.csv")
203207
print(data_path)
204208
df = pd.read_csv(data_path)
205-
209+
# Ignore all records of group = C2
210+
df = df[df["system.group"] != "C2"]
206211
fig = plot_aggregated_lines(df,
207212
x="system.train_ratio",
208213
y="state_pred_loss/test",
@@ -212,9 +217,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
212217
area_metric="std",
213218
label_replace={"model.name": "Model",
214219
"system.group": "Group",
215-
"C2": r"$\mathbb{G}_{\Omega}=\mathbb{C}_2$",
216-
"K4": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4$",
217-
"K4xC2": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4 \times \mathbb{"
220+
"C2": r"$\mathbb{G}=\mathbb{C}_2$",
221+
"K4": r"$\mathbb{G}=\mathbb{K}_4$",
222+
"K4xC2": r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
218223
r"C}_2$", }
219224
)
220225
# Set the figure size to a quarter of an A4 page
@@ -233,7 +238,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
233238
df = pd.read_csv(data_path)
234239
STATE_DIM = 42
235240
df['system.obs_state_ratio'] = df['system.obs_state_ratio'] * STATE_DIM
236-
241+
# Ignore all records of group = C2
242+
df = df[df["system.group"] != "C2"]
237243
fig = plot_aggregated_lines(df,
238244
x="system.obs_state_ratio",
239245
y="state_pred_loss/test",
@@ -243,9 +249,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
243249
area_metric="std",
244250
label_replace={"model.name": "Model",
245251
"system.group": "Group",
246-
"C2": r"$\mathbb{G}_{\Omega}=\mathbb{C}_2$",
247-
"K4": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4$",
248-
"K4xC2": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4 \times \mathbb{"
252+
"C2": r"$\mathbb{G}=\mathbb{C}_2$",
253+
"K4": r"$\mathbb{G}=\mathbb{K}_4$",
254+
"K4xC2": r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
249255
r"C}_2$", }
250256
)
251257
# Set the figure size to a quarter of an A4 page
@@ -258,67 +264,85 @@ def plot_aggregated_lines(df: pd.DataFrame,
258264
fig.write_image(data_path.with_suffix(".svg"))
259265
fig.write_image(data_path.with_suffix(".png"))
260266

261-
# MSE vs Time =================================================
262-
data_path = Path("mini_cheetah_mse_vs_time_uneven_easy_terrain.csv")
263-
print(data_path)
264-
267+
# # MSE vs Time =================================================
268+
# print(f"Mini Cheetah MSE vs Time")
265269
# import wandb
266270
# wandb.login()
267271
# api = wandb.Api()
268272
# project_path = "dls-csml/mini_cheetah"
269-
# group_name = "mse_vs_time"
270-
# metric_name = "state_pred_loss_t/test"
273+
# group_name = "mse_vs_time_final"
274+
# metric_name = "state_pred_loss_t"
271275
# runs = api.runs(project_path, {"$and": [{"group": group_name}] })
272276
# print(f"Found {len(runs)} runs for group {group_name}")
273-
274-
df = pd.read_csv(data_path)
275-
# MSE vs time needs a bit of reformat
276-
df_reformatted = pd.DataFrame()
277-
for col in df.columns:
278-
if "time" in col or "step" in col: continue
279-
run_name, var_name = col.split(" - ")
280-
print(var_name)
281-
# var_name = var_name.split("__")[0]
282-
if not var_name == "state_pred_loss_t/test": continue
283-
model_name = "DAE-AUG" if "DAE-AUG" in run_name else "E-DAE" if "E-DAE" in col else "DAE"
284-
system_group = "K4xC2" if "K4xC2" in run_name else "K4" if "K4" in run_name else "C2"
285-
df_run = pd.DataFrame({"Name": run_name, "model.name": model_name, "system.group": system_group,
286-
"time": df["time"], var_name: df[col]})
287-
df_reformatted = pd.concat([df_reformatted, df_run], axis=0)
288-
289-
fig = plot_aggregated_lines(df_reformatted,
290-
x="time",
291-
y="state_pred_loss_t/test",
292-
group_variables=['model.name', 'system.group'],
293-
line_styles=line_styles,
294-
color_group=color_group,
295-
area_metric="std",
296-
label_replace={"model.name": "Model",
297-
"system.group": "Group",
298-
"C2": r"$\mathbb{G}_{\Omega}=\mathbb{C}_2$",
299-
"K4": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4$",
300-
"K4xC2": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4 \times \mathbb{"
301-
r"C}_2$", }
302-
)
303-
# Determine the range of your data on a logarithmic scale
304-
# y_min, y_max = 0.1, 6.0
305-
# log_min, log_max = np.log10(y_min), np.log10(y_max)
306-
# tickvals = [np.round(10 ** x, 2) for x in np.arange(log_min, log_max, 0.1)] # Adjust step for more granularity
307-
# tickvals = sorted(list(set(tickvals + [y_min, y_max]))) # Ensure start and end values are included
308-
# # Format tick labels
309-
# ticktext = [f'{val:.1f}' for val in tickvals]
310-
311-
# Set the figure size to a quarter of an A4 page
312-
fig.update_layout(**layout_config)
313-
fig.update_xaxes(title_text="Prediction horizon [s]")
314-
fig.update_yaxes(title_text="state prediction MSE",
315-
type="log",
316-
)
317-
# fig.show()
318-
fig.write_html(data_path.with_suffix(".html"))
319-
fig.write_image(data_path.with_suffix(".svg"))
320-
fig.write_image(data_path.with_suffix(".png"))
321-
# fig.show()
277+
# df = pd.DataFrame()
278+
# download_path = Path("./artifacts")
279+
# # Iterate over each run
280+
# for i, run in tqdm(enumerate(list(runs))):
281+
# # Access the list of artifacts for the run
282+
# artifacts = run.logged_artifacts()
283+
# print(f"Run {i}")
284+
# df_run = None
285+
# for artifact in artifacts:
286+
# if metric_name in artifact.name:
287+
# # Construct the unique path for this run's artifact
288+
# artifact_file_path = download_path / f"{artifact.name}.json"
289+
# # Check if the file already exists
290+
# if os.path.exists(artifact_file_path):
291+
# print(f"Artifact already downloaded: {artifact_file_path}")
292+
# else:
293+
# print(f"Downloading artifact to : {artifact_file_path}")
294+
# # Download the artifact
295+
# table_dir = artifact.download()
296+
# table_files = list(Path(table_dir).rglob('*test.table.json'))
297+
# if len(table_files) == 0:
298+
# print(f"Run {run.id} {run.name} did not save the state_pred_loss_t table")
299+
# continue
300+
# table_path = table_files[0]
301+
# # Move the file to the specified base directory
302+
# os.rename(table_path, artifact_file_path)
303+
#
304+
# with artifact_file_path.open('r') as file:
305+
# json_dict = json.load(file)
306+
# df_run = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])
307+
# break
308+
# if df_run is not None:
309+
# # Search in this run config values for model.name and system.group values and append to df
310+
# config = run.config
311+
# model_name = config["model"]["name"]
312+
# system_group = config["system"]["group"]
313+
# df_run = df_run.assign(**{"model.name": model_name, "system.group": system_group, "Name": run.name})
314+
#
315+
# df = pd.concat([df, df_run], axis=0)
316+
# else:
317+
# print(f"Run {run.id} {run.name} did not save the state_pred_loss_t table")
318+
#
319+
# # Ignore all records of group = C2
320+
# df = df[df["system.group"] != "C2"]
321+
# fig = plot_aggregated_lines(df,
322+
# x="time",
323+
# y="state_pred_loss_t/test",
324+
# group_variables=['model.name', 'system.group'],
325+
# line_styles=line_styles,
326+
# color_group=color_group,
327+
# area_metric="std",
328+
# label_replace={"model.name": "Model",
329+
# "system.group": "Group",
330+
# "K4": r"$\mathbb{G}=\mathbb{K}_4$",
331+
# "K4xC2": r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
332+
# r"C}_2$", }
333+
# )
334+
# # Set the figure size to a quarter of an A4 page
335+
# fig.update_layout(**layout_config)
336+
# fig.update_xaxes(title_text="Prediction horizon [s]")
337+
# fig.update_yaxes(title_text="state prediction MSE",
338+
# type="log",
339+
# )
340+
# # fig.show()
341+
# data_path = Path("mini_cheetah_mse_vs_time.csv")
342+
# fig.write_html(data_path.with_suffix(".html"))
343+
# fig.write_image(data_path.with_suffix(".svg"))
344+
# fig.write_image(data_path.with_suffix(".png"))
345+
# # fig.show()
322346
# ==========================================================================
323347
# ========================LINEAR EXPERIMENT ================================
324348
# ==========================================================================
@@ -350,8 +374,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
350374
line_width=3,
351375
label_replace={"model.name": "Model",
352376
"system.group": "Group",
353-
"C5": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_5$",
354-
"C10": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_{10}$",
377+
"C5": r"$\large \mathbb{G}=\mathbb{C}_5$",
378+
"C10": r"$\large \mathbb{G}=\mathbb{C}_{10}$",
355379
}
356380

357381
)
@@ -380,8 +404,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
380404
line_width=3,
381405
label_replace={"model.name": "Model",
382406
"system.group": "Group",
383-
"C5": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_5$",
384-
"C10": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_{10}$",
407+
"C5": r"$\large \mathbb{G}=\mathbb{C}_5$",
408+
"C10": r"$\large \mathbb{G}=\mathbb{C}_{10}$",
385409
}
386410

387411
)
@@ -408,8 +432,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
408432
line_width=3,
409433
label_replace={"model.name": "Model",
410434
"system.group": "Group",
411-
"C5": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_5$",
412-
"C10": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_{10}$",
435+
"C5": r"$\large \mathbb{G}=\mathbb{C}_5$",
436+
"C10": r"$\large \mathbb{G}=\mathbb{C}_{10}$",
413437
}
414438

415439
)
@@ -437,8 +461,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
437461
line_width=3,
438462
label_replace={"model.name": "Model",
439463
"system.group": "Group",
440-
"C5": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_5$",
441-
"C10": r"$\large \mathbb{G}_{\Omega}=\mathbb{C}_{10}$",
464+
"C5": r"$\large \mathbb{G}=\mathbb{C}_5$",
465+
"C10": r"$\large \mathbb{G}=\mathbb{C}_{10}$",
442466
}
443467

444468
)
@@ -479,9 +503,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
479503
area_metric="std",
480504
label_replace={"model.name": "Model",
481505
"system.group": "Group",
482-
"C2": r"$\mathbb{G}_{\Omega}=\mathbb{C}_2$",
483-
"K4": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4$",
484-
"K4xC2": r"$\mathbb{G}_{\Omega}=\mathbb{K}_4 \times \mathbb{"
506+
"C2": r"$\mathbb{G}=\mathbb{C}_2$",
507+
"K4": r"$\mathbb{G}=\mathbb{K}_4$",
508+
"K4xC2": r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
485509
r"C}_2$", }
486510
)
487511

Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This md will hold the videos for the deocmpositon of the many trajectories of motion of the system

media/mini_cheetah_symmetry_group.md

Whitespace-only changes.

0 commit comments

Comments
 (0)