import logging
import os
import pathlib
import random
import string
import sys

import hydra
import pandas as pd
import wandb
from hydra.core.hydra_config import HydraConfig
from hydra.utils import call, get_original_cwd, instantiate
from omegaconf import DictConfig, OmegaConf
from sklearn.experimental import enable_halving_search_cv
from tqdm import tqdm

from imfas.losses.ranking_loss import spear_halve_loss
from imfas.util import print_cfg, seed_everything, train_test_split

# A logger for this file
log = logging.getLogger(__name__)


def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    return "".join(random.choice(chars) for _ in range(size))


base_dir = os.getcwd()
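
# Training entry point. The Hydra config tree under configs/ drives every step:
# W&B run setup, seeding, optional raw-dataset download, dataset/dataloader
# instantiation, and finally either training the configured model or running
# the HalvingGridSearchCV baseline over the test datasets.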
@hydra.main(config_path="configs", config_name="base")
def pipe_train(cfg: DictConfig) -> None:
    sys.path.append(os.getcwd())
    sys.path.append("..")
    print("base_dir: ", base_dir)
    dict_cfg = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    print_cfg(cfg)

    hydra_cfg = HydraConfig.get()
    hydra_job = (
        os.path.basename(os.path.abspath(os.path.join(hydra_cfg.run.dir, "..")))
        + "_"
        + os.path.basename(hydra_cfg.run.dir)
    )
    log.info(get_original_cwd())
    cfg.wandb.id = hydra_job + "_" + id_generator()
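
    # Initialise the W&B run, then record the exact launch command and the
    # SLURM job id (if one is present in the Hydra job config) for reproducibility.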
    run = wandb.init(**cfg.wandb, config=dict_cfg)

    command = f"{hydra_cfg.job.name}.py " + " ".join(hydra_cfg.overrides.task)
    if not OmegaConf.is_missing(hydra_cfg.job, "id"):
        slurm_id = hydra_cfg.job.id
    else:
        slurm_id = None
    wandb.config.update({"command": command, "slurm_id": slurm_id})

    orig_cwd = get_original_cwd()

    # logging TODO: add logging to each step of the way.
    log.info("Hydra initialized a new config_raw")
    log.debug(str(cfg))

    seed_everything(cfg.seed)
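
    # Data layout: raw benchmark files are expected under <dir_data>/raw/<dataset_name>.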
    dir_data = pathlib.Path(cfg.dataset_raw.dir_data)
    dir_raw = dir_data / "raw"
    dir_dataset_raw = dir_raw / cfg.dataset_raw.dataset_name

    # optionally download / re-subset the dataset
    if cfg.dataset_raw.enable:
        call(cfg.dataset_raw, _recursive_=False)

    dataset_meta_features = instantiate(cfg.dataset.dataset_meta)

    # train/test split, dataset-major
    train_split, test_split = train_test_split(
        len(dataset_meta_features),  # todo refactor - needs to be aware of dropped meta features
        cfg.dataset.split,
    )
    # Create the dataloaders
    train_set = instantiate(cfg.dataset.dataset_class, split=train_split)
    test_set = instantiate(cfg.dataset.dataset_class, split=test_split)
    train_loader = instantiate(cfg.dataset.dataloader_class, dataset=train_set)
    test_loader = instantiate(cfg.dataset.dataloader_class, dataset=test_set)
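
    # Two execution paths: any model other than HalvingGridSearchCV is trained
    # once across the training datasets; the HalvingGridSearchCV baseline is
    # instead re-fitted per test dataset in the else-branch below.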
    # update the input dims and number of algos based on the sampled data
    # if "n_algos" not in cfg.dataset_raw.keys() and cfg.dataset.name != "LCBench":
    if cfg.model._target_.split(".")[-1] != "HalvingGridSearchCV":
        input_dim = dataset_meta_features.df.columns.size
        n_algos = len(train_set.lc.index)
        wandb.config.update({"n_algos": n_algos, "input_dim": input_dim})
        model = instantiate(cfg.model, input_dim=input_dim, algo_dim=n_algos)
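
        # cfg.training points at a training routine; called with the model and
        # both dataloaders, it fits the model and returns a validation score.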
        valid_score = call(
            cfg.training,
            model,
            train_dataloader=train_loader,
            test_dataloader=test_loader,
            _recursive_=False,
        )
    else:
        if cfg.dataset.name == "LCBench":
            cfg.model.param_grid.algo_id = list(range(len(train_set.lc.index)))

        enable_halving_search_cv  # ensures import is not removed in alt + L reformatting

        # model.estimator.slices.split == test_split -- this way datasets are parallel in seeds
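        # Successive-halving baseline: fit HalvingGridSearchCV once per test
        # dataset and score its selection against the final learning-curve
        # performances via spear_halve_loss (a Spearman-based ranking loss).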
        spears = {}
        for d in tqdm(test_split):
            # indexed with 0 and slices.split holds the relevant data id already!
            cfg.model.estimator.slices.split = [d]
            model = instantiate(cfg.model, _convert_="partial")

            # fixme: validation score should not be computed during training!
            valid_score = call(
                cfg.training,
                model,
                train_dataloader=train_loader,
                test_dataloader=test_loader,
                _recursive_=False,
            )

            final_performances = test_set.lc.transformed_df[d][-1]
            spears[d] = spear_halve_loss(valid_score, final_performances).numpy()
        # fixme: spearman is a constant for all test datasets.
        df_spear = pd.DataFrame.from_dict(spears, orient="index")
        print(df_spear)

        if cfg.dataset.name == "LCBench":
            name = "LCBench_raw"
        else:
            name = cfg.dataset_raw.bench

        df_spear.to_csv(f"halving_test_spear_{name}_{cfg.seed}.csv")

if __name__ == "__main__":
    pipe_train()
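
# Example launch (hypothetical overrides; the actual group names depend on the
# configs/ tree):
#   python main.py seed=42 dataset=LCBench model=halving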