1
1
import copy
2
+ import json
3
+ import os
2
4
from typing import Iterable , Optional , Union
3
5
4
6
import pandas as pd
9
11
from plotly .graph_objs import Figure
10
12
import plotly .graph_objs as go
11
13
import plotly .express as px
14
+ from tqdm import tqdm
15
+ from wandb .apis .public import RunArtifacts
16
+ from wandb .sdk .wandb_summary import SummarySubDict
12
17
13
18
14
19
def plot_aggregated_lines (df : pd .DataFrame ,
@@ -194,15 +199,15 @@ def plot_aggregated_lines(df: pd.DataFrame,
194
199
"DAE-AUG" : dict (dash = 'dot' ),
195
200
"DAE" : dict (dash = 'dash' )}
196
201
}
197
- color_group = {'system.group' : {"K4xC2" : color_pallet [0 ],
198
- "K4" : color_pallet [10 ],
199
- "C2" : color_pallet [1 ], }
202
+ color_group = {'system.group' : {"K4xC2" : color_pallet [1 ],
203
+ "K4" : color_pallet [7 ]}
200
204
}
201
205
202
206
data_path = Path ("mini_cheetah_sample_eff_uneven_easy_terrain.csv" )
203
207
print (data_path )
204
208
df = pd .read_csv (data_path )
205
-
209
+ # Ignore all records of group = C2
210
+ df = df [df ["system.group" ] != "C2" ]
206
211
fig = plot_aggregated_lines (df ,
207
212
x = "system.train_ratio" ,
208
213
y = "state_pred_loss/test" ,
@@ -212,9 +217,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
212
217
area_metric = "std" ,
213
218
label_replace = {"model.name" : "Model" ,
214
219
"system.group" : "Group" ,
215
- "C2" : r"$\mathbb{G}_{\Omega} =\mathbb{C}_2$" ,
216
- "K4" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4$" ,
217
- "K4xC2" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4 \times \mathbb{"
220
+ "C2" : r"$\mathbb{G}=\mathbb{C}_2$" ,
221
+ "K4" : r"$\mathbb{G}=\mathbb{K}_4$" ,
222
+ "K4xC2" : r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
218
223
r"C}_2$" , }
219
224
)
220
225
# Set the figure size to a quarter of an A4 page
@@ -233,7 +238,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
233
238
df = pd .read_csv (data_path )
234
239
STATE_DIM = 42
235
240
df ['system.obs_state_ratio' ] = df ['system.obs_state_ratio' ] * STATE_DIM
236
-
241
+ # Ignore all records of group = C2
242
+ df = df [df ["system.group" ] != "C2" ]
237
243
fig = plot_aggregated_lines (df ,
238
244
x = "system.obs_state_ratio" ,
239
245
y = "state_pred_loss/test" ,
@@ -243,9 +249,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
243
249
area_metric = "std" ,
244
250
label_replace = {"model.name" : "Model" ,
245
251
"system.group" : "Group" ,
246
- "C2" : r"$\mathbb{G}_{\Omega} =\mathbb{C}_2$" ,
247
- "K4" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4$" ,
248
- "K4xC2" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4 \times \mathbb{"
252
+ "C2" : r"$\mathbb{G}=\mathbb{C}_2$" ,
253
+ "K4" : r"$\mathbb{G}=\mathbb{K}_4$" ,
254
+ "K4xC2" : r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
249
255
r"C}_2$" , }
250
256
)
251
257
# Set the figure size to a quarter of an A4 page
@@ -258,67 +264,85 @@ def plot_aggregated_lines(df: pd.DataFrame,
258
264
fig .write_image (data_path .with_suffix (".svg" ))
259
265
fig .write_image (data_path .with_suffix (".png" ))
260
266
261
- # MSE vs Time =================================================
262
- data_path = Path ("mini_cheetah_mse_vs_time_uneven_easy_terrain.csv" )
263
- print (data_path )
264
-
267
+ # # MSE vs Time =================================================
268
+ # print(f"Mini Cheetah MSE vs Time")
265
269
# import wandb
266
270
# wandb.login()
267
271
# api = wandb.Api()
268
272
# project_path = "dls-csml/mini_cheetah"
269
- # group_name = "mse_vs_time "
270
- # metric_name = "state_pred_loss_t/test "
273
+ # group_name = "mse_vs_time_final "
274
+ # metric_name = "state_pred_loss_t"
271
275
# runs = api.runs(project_path, {"$and": [{"group": group_name}] })
272
276
# print(f"Found {len(runs)} runs for group {group_name}")
273
-
274
- df = pd .read_csv (data_path )
275
- # MSE vs time needs a bit of reformat
276
- df_reformatted = pd .DataFrame ()
277
- for col in df .columns :
278
- if "time" in col or "step" in col : continue
279
- run_name , var_name = col .split (" - " )
280
- print (var_name )
281
- # var_name = var_name.split("__")[0]
282
- if not var_name == "state_pred_loss_t/test" : continue
283
- model_name = "DAE-AUG" if "DAE-AUG" in run_name else "E-DAE" if "E-DAE" in col else "DAE"
284
- system_group = "K4xC2" if "K4xC2" in run_name else "K4" if "K4" in run_name else "C2"
285
- df_run = pd .DataFrame ({"Name" : run_name , "model.name" : model_name , "system.group" : system_group ,
286
- "time" : df ["time" ], var_name : df [col ]})
287
- df_reformatted = pd .concat ([df_reformatted , df_run ], axis = 0 )
288
-
289
- fig = plot_aggregated_lines (df_reformatted ,
290
- x = "time" ,
291
- y = "state_pred_loss_t/test" ,
292
- group_variables = ['model.name' , 'system.group' ],
293
- line_styles = line_styles ,
294
- color_group = color_group ,
295
- area_metric = "std" ,
296
- label_replace = {"model.name" : "Model" ,
297
- "system.group" : "Group" ,
298
- "C2" : r"$\mathbb{G}_{\Omega}=\mathbb{C}_2$" ,
299
- "K4" : r"$\mathbb{G}_{\Omega}=\mathbb{K}_4$" ,
300
- "K4xC2" : r"$\mathbb{G}_{\Omega}=\mathbb{K}_4 \times \mathbb{"
301
- r"C}_2$" , }
302
- )
303
- # Determine the range of your data on a logarithmic scale
304
- # y_min, y_max = 0.1, 6.0
305
- # log_min, log_max = np.log10(y_min), np.log10(y_max)
306
- # tickvals = [np.round(10 ** x, 2) for x in np.arange(log_min, log_max, 0.1)] # Adjust step for more granularity
307
- # tickvals = sorted(list(set(tickvals + [y_min, y_max]))) # Ensure start and end values are included
308
- # # Format tick labels
309
- # ticktext = [f'{val:.1f}' for val in tickvals]
310
-
311
- # Set the figure size to a quarter of an A4 page
312
- fig .update_layout (** layout_config )
313
- fig .update_xaxes (title_text = "Prediction horizon [s]" )
314
- fig .update_yaxes (title_text = "state prediction MSE" ,
315
- type = "log" ,
316
- )
317
- # fig.show()
318
- fig .write_html (data_path .with_suffix (".html" ))
319
- fig .write_image (data_path .with_suffix (".svg" ))
320
- fig .write_image (data_path .with_suffix (".png" ))
321
- # fig.show()
277
+ # df = pd.DataFrame()
278
+ # download_path = Path("./artifacts")
279
+ # # Iterate over each run
280
+ # for i, run in tqdm(enumerate(list(runs))):
281
+ # # Access the list of artifacts for the run
282
+ # artifacts = run.logged_artifacts()
283
+ # print(f"Run {i}")
284
+ # df_run = None
285
+ # for artifact in artifacts:
286
+ # if metric_name in artifact.name:
287
+ # # Construct the unique path for this run's artifact
288
+ # artifact_file_path = download_path / f"{artifact.name}.json"
289
+ # # Check if the file already exists
290
+ # if os.path.exists(artifact_file_path):
291
+ # print(f"Artifact already downloaded: {artifact_file_path}")
292
+ # else:
293
+ # print(f"Downloading artifact to : {artifact_file_path}")
294
+ # # Download the artifact
295
+ # table_dir = artifact.download()
296
+ # table_files = list(Path(table_dir).rglob('*test.table.json'))
297
+ # if len(table_files) == 0:
298
+ # print(f"Run {run.id} {run.name} did not save the state_pred_loss_t table")
299
+ # continue
300
+ # table_path = table_files[0]
301
+ # # Move the file to the specified base directory
302
+ # os.rename(table_path, artifact_file_path)
303
+ #
304
+ # with artifact_file_path.open('r') as file:
305
+ # json_dict = json.load(file)
306
+ # df_run = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])
307
+ # break
308
+ # if df_run is not None:
309
+ # # Search in this run config values for model.name and system.group values and append to df
310
+ # config = run.config
311
+ # model_name = config["model"]["name"]
312
+ # system_group = config["system"]["group"]
313
+ # df_run = df_run.assign(**{"model.name": model_name, "system.group": system_group, "Name": run.name})
314
+ #
315
+ # df = pd.concat([df, df_run], axis=0)
316
+ # else:
317
+ # print(f"Run {run.id} {run.name} did not save the state_pred_loss_t table")
318
+ #
319
+ # # Ignore all records of group = C2
320
+ # df = df[df["system.group"] != "C2"]
321
+ # fig = plot_aggregated_lines(df,
322
+ # x="time",
323
+ # y="state_pred_loss_t/test",
324
+ # group_variables=['model.name', 'system.group'],
325
+ # line_styles=line_styles,
326
+ # color_group=color_group,
327
+ # area_metric="std",
328
+ # label_replace={"model.name": "Model",
329
+ # "system.group": "Group",
330
+ # "K4": r"$\mathbb{G}=\mathbb{K}_4$",
331
+ # "K4xC2": r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
332
+ # r"C}_2$", }
333
+ # )
334
+ # # Set the figure size to a quarter of an A4 page
335
+ # fig.update_layout(**layout_config)
336
+ # fig.update_xaxes(title_text="Prediction horizon [s]")
337
+ # fig.update_yaxes(title_text="state prediction MSE",
338
+ # type="log",
339
+ # )
340
+ # # fig.show()
341
+ # data_path = Path("mini_cheetah_mse_vs_time.csv")
342
+ # fig.write_html(data_path.with_suffix(".html"))
343
+ # fig.write_image(data_path.with_suffix(".svg"))
344
+ # fig.write_image(data_path.with_suffix(".png"))
345
+ # # fig.show()
322
346
# ==========================================================================
323
347
# ========================LINEAR EXPERIMENT ================================
324
348
# ==========================================================================
@@ -350,8 +374,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
350
374
line_width = 3 ,
351
375
label_replace = {"model.name" : "Model" ,
352
376
"system.group" : "Group" ,
353
- "C5" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_5$" ,
354
- "C10" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_{10}$" ,
377
+ "C5" : r"$\large \mathbb{G}=\mathbb{C}_5$" ,
378
+ "C10" : r"$\large \mathbb{G}=\mathbb{C}_{10}$" ,
355
379
}
356
380
357
381
)
@@ -380,8 +404,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
380
404
line_width = 3 ,
381
405
label_replace = {"model.name" : "Model" ,
382
406
"system.group" : "Group" ,
383
- "C5" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_5$" ,
384
- "C10" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_{10}$" ,
407
+ "C5" : r"$\large \mathbb{G}=\mathbb{C}_5$" ,
408
+ "C10" : r"$\large \mathbb{G}=\mathbb{C}_{10}$" ,
385
409
}
386
410
387
411
)
@@ -408,8 +432,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
408
432
line_width = 3 ,
409
433
label_replace = {"model.name" : "Model" ,
410
434
"system.group" : "Group" ,
411
- "C5" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_5$" ,
412
- "C10" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_{10}$" ,
435
+ "C5" : r"$\large \mathbb{G}=\mathbb{C}_5$" ,
436
+ "C10" : r"$\large \mathbb{G}=\mathbb{C}_{10}$" ,
413
437
}
414
438
415
439
)
@@ -437,8 +461,8 @@ def plot_aggregated_lines(df: pd.DataFrame,
437
461
line_width = 3 ,
438
462
label_replace = {"model.name" : "Model" ,
439
463
"system.group" : "Group" ,
440
- "C5" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_5$" ,
441
- "C10" : r"$\large \mathbb{G}_{\Omega} =\mathbb{C}_{10}$" ,
464
+ "C5" : r"$\large \mathbb{G}=\mathbb{C}_5$" ,
465
+ "C10" : r"$\large \mathbb{G}=\mathbb{C}_{10}$" ,
442
466
}
443
467
444
468
)
@@ -479,9 +503,9 @@ def plot_aggregated_lines(df: pd.DataFrame,
479
503
area_metric = "std" ,
480
504
label_replace = {"model.name" : "Model" ,
481
505
"system.group" : "Group" ,
482
- "C2" : r"$\mathbb{G}_{\Omega} =\mathbb{C}_2$" ,
483
- "K4" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4$" ,
484
- "K4xC2" : r"$\mathbb{G}_{\Omega} =\mathbb{K}_4 \times \mathbb{"
506
+ "C2" : r"$\mathbb{G}=\mathbb{C}_2$" ,
507
+ "K4" : r"$\mathbb{G}=\mathbb{K}_4$" ,
508
+ "K4xC2" : r"$\mathbb{G}=\mathbb{K}_4 \times \mathbb{"
485
509
r"C}_2$" , }
486
510
)
487
511
0 commit comments