-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_results.py
138 lines (114 loc) · 4.33 KB
/
make_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from src.reports.plots import sine_wave_hdi, sine_wave_data, variance_plot, postshape_plot, location_plot, all_hdi
from src.reports.tables import pprint_uci_bench_latex
import argparse
from src.reports.tables import pprint_mnist_latex
from src.constants import INDIST_EVAL_METRICS_NAMES
def make_plot(args):
    """Build the figure(s) selected on the command line.

    Args:
        args: Parsed CLI namespace. Reads ``args.plot`` (which figure),
            ``args.log`` (which logs to load) and ``args.format`` (passed on
            to the plotting helpers that accept it).

    For every plot except ``"data"`` an ``ExpLogger`` is created and loaded
    before the plotting helpers are called.
    """
    plot_kind = args.plot

    # Step 1: pick the experiment logger and the timestamp used by "article".
    if plot_kind == "1d-reg":
        from src.logger.reg_logger import ExpLogger

        article_ts = "2024-09-10 14:58:42.515864"
        logger = ExpLogger("lppd_1d")
    elif plot_kind == "collapse":
        from src.logger.reg_logger import ExpLogger

        article_ts = "2024-09-30 12:29:00.384211"
        logger = ExpLogger("var")

    # Step 2: load logs; the "data" plot does not use a logger at all.
    if plot_kind != "data":
        log_choice = args.log
        if log_choice == "latest":
            logger.load_latest_logs()
        elif log_choice == "article":
            logger.load_logs(article_ts)
        else:
            # Anything else is handed to the logger unchanged.
            logger.load_logs(log_choice)

    # Step 3: draw.
    if plot_kind == "data":
        sine_wave_data(args.format)
    elif plot_kind == "1d-reg":
        all_hdi(logger, args.format)
        sine_wave_hdi(logger, args.format)
    elif plot_kind == "collapse":
        variance_plot(logger, add_std=False)
        postshape_plot(logger, add_std=False)
        location_plot(logger, add_std=False)
def load_log(logger, log_type, ats):
    """Load logs into *logger* according to *log_type* and return it.

    Args:
        logger: Object providing ``load_latest_logs()`` and ``load_logs(ts)``.
        log_type: ``"latest"`` to load the latest logs, ``"article"`` to load
            the logs at *ats*; any other value is passed to
            ``logger.load_logs`` as-is.
        ats: Timestamp passed to ``logger.load_logs`` when *log_type* is
            ``"article"``; ignored otherwise.

    Returns:
        The same *logger* object, after loading.
    """
    if log_type == "latest":
        logger.load_latest_logs()
    elif log_type == "article":
        logger.load_logs(ats)
    else:
        # Use the parameter rather than a module-level ``args``: that global
        # only exists when the file is run as a script, so relying on it
        # raises NameError when this function is imported and called.
        logger.load_logs(log_type)
    return logger
def make_table(args):
    """Build the table selected on the command line and print it.

    Args:
        args: Parsed CLI namespace. Reads ``args.table`` (``"mnist"`` or
            ``"uci"``) and ``args.log`` (forwarded to ``load_log``).

    The value returned by ``pprint_mnist_latex`` / ``pprint_uci_bench_latex``
    is written to stdout with ``print``.
    """
    table_kind = args.table

    if table_kind == "mnist":
        from src.logger.class_logger import ExpLogger

        article_timestamps = (
            "2024-11-21 11:49:32.833326",  # 1 layer MLP
            "2024-11-20 18:30:56.884221",  # 2 layer mlp
        )
        loggers = [
            load_log(ExpLogger("mnist"), args.log, stamp)
            for stamp in article_timestamps
        ]
        latex = pprint_mnist_latex(*loggers, INDIST_EVAL_METRICS_NAMES)
    elif table_kind == "uci":
        from src.logger.reg_logger import ExpLogger

        assert args.log in ("latest", "article"), f"timestamp not support for UCI"

        # One pair of article timestamps per experiment; the two runs of each
        # pair are merged into a single logger.
        experiments = (
            ("uci_std", "2024-09-30 16:26:50.652433", "2024-08-28 13:46:40.875670"),
            ("uci_gap", "2024-09-26 14:44:27.705930", "2024-08-29 07:58:36.740033"),
        )
        loggers = []
        for exp_name, first_ts, second_ts in experiments:
            merged = load_log(ExpLogger(exp_name), args.log, first_ts)
            other = load_log(ExpLogger(exp_name), args.log, second_ts)
            merged.merge(other)
            loggers.append(merged)
        latex = pprint_uci_bench_latex(*loggers)

    print(latex)
def plot_parser(parser):
    """Configure *parser* as the ``plot`` sub-command.

    Adds three positionals and registers ``make_plot`` as the handler:

    * ``plot``   -- which figure to build (required).
    * ``log``    -- optional, defaults to ``"latest"``.
    * ``format`` -- optional, defaults to ``"png"``.

    Args:
        parser: The ``argparse`` (sub)parser to populate; modified in place.
    """
    parser.prog = "Plotter"
    parser.add_argument(
        "plot",
        # "data" is included because make_plot has a dedicated branch for it
        # that would otherwise be unreachable from the command line.
        choices=["data", "1d-reg", "collapse"],
        help="Figure to build.",
    )
    parser.add_argument(
        "log",
        nargs="?",
        type=str,
        default="latest",
        # Deliberately no `choices`: restricting to the literal words
        # article/latest/timestamp made it impossible to pass an actual
        # timestamp, which is what make_plot's fallback branch expects.
        help="Using `article` builds figs from downloaded logs (see README), "
        "`latest` builds figs using the latest timestamp, "
        "any other value is used as a specific timestamp directory name in logs.",
    )
    parser.add_argument("format", nargs="?", type=str, default="png")
    parser.set_defaults(func=make_plot)
def table_parser(parser):
    """Configure *parser* as the ``table`` sub-command.

    Adds the required ``table`` positional and the optional ``log``
    positional (default ``"latest"``), and registers ``make_table`` as the
    handler.

    Args:
        parser: The ``argparse`` (sub)parser to populate; modified in place.
    """
    # NOTE(review): an earlier commented-out help text mentioned "table in
    # E.3.1 of GGN article" -- presumably for `mnist`; confirm before adding
    # it to the user-facing help.
    parser.add_argument(
        "table",
        choices=["mnist", "uci"],
        help="Table to build.",
    )
    parser.add_argument(
        "log",
        nargs="?",
        type=str,
        default="latest",
        # No `choices`: any value other than article/latest is treated as a
        # timestamp directory name.
        help="Using `article` builds tables from downloaded logs (see README), "
        "`latest` builds tables using the latest timestamp, "
        "any other value is used as a timestamp directory name from logs "
        "(timestamps are not supported for `uci`).",
    )
    parser.set_defaults(func=make_table)
if __name__ == "__main__":
    # Command-line entry point: one sub-command for figures, one for tables.
    parser = argparse.ArgumentParser(prog="Visualize results")
    subparsers = parser.add_subparsers()

    plot_cmd = subparsers.add_parser("plot", help="Build figs from the article.")
    plot_parser(plot_cmd)
    table_cmd = subparsers.add_parser("table", help="Build tables from the article.")
    table_parser(table_cmd)

    # `args` is intentionally bound at module level (not inside a helper) so
    # that any code looking it up as a global keeps working.
    args = parser.parse_args()

    # `func` is only present when a sub-command was given.
    if not hasattr(args, "func"):
        parser.print_help()
    else:
        args.func(args)