#!/usr/bin/env python
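"""
Training and evaluation entry point for Einet models (simple-einet), built on
PyTorch Lightning, configured via Hydra, and logged to Weights & Biases.
"""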
import omegaconf
import time
import wandb
from hydra.core.hydra_config import HydraConfig
import logging
from omegaconf import DictConfig, OmegaConf, open_dict
import os

from rich.traceback import install

install()

import hydra
import lightning.pytorch as pl
import torch.utils.data
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import StochasticWeightAveraging, RichProgressBar, ModelSummary
from lightning.pytorch.loggers import WandbLogger

from exp_utils import (
    load_from_checkpoint,
    plot_distribution,
)
from models_pl import SpnDiscriminative, SpnGenerative
from simple_einet.dist import Dist
from simple_einet.data import build_dataloader
from simple_einet.sampling_utils import init_einet_stats

# A logger for this file
logger = logging.getLogger(__name__)


def main(cfg: DictConfig):
    """
    Main function for training and evaluating an Einet.

    Args:
        cfg: Config file.
    """
    preprocess_cfg(cfg)

    # Get hydra config
    hydra_cfg = HydraConfig.get()
    run_dir = hydra_cfg.runtime.output_dir
    logger.info("Working directory : {}".format(os.getcwd()))

    # Save config
    with open(os.path.join(run_dir, "config.yaml"), "w") as f:
        OmegaConf.save(config=cfg, f=f)

    # Save run_dir in config (use open_dict to make config writable)
    with open_dict(cfg):
        cfg.run_dir = run_dir

    logger.info("\n" + OmegaConf.to_yaml(cfg, resolve=True))
    logger.info("Run dir: " + run_dir)

    if not cfg.wandb:
        os.environ["WANDB_MODE"] = "offline"
    # Ensure that everything is properly seeded
    seed_everything(cfg.seed, workers=True)

    # Setup devices
    if torch.cuda.is_available():
        accelerator = "gpu"
        if type(cfg.gpu) == int:
            devices = [int(cfg.gpu)]
        else:
            devices = [int(g) for g in cfg.gpu]
    else:
        accelerator = "cpu"
        devices = 1
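    # NOTE: Lightning accepts both forms of `devices`: a list of GPU indices
    # when running on "gpu", or an integer process count when running on "cpu".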
    logger.info("Training model...")

    # Create dataloader
    normalize = cfg.dist in [Dist.NORMAL, Dist.NORMAL_RAT, Dist.MULTIVARIATE_NORMAL]
    train_loader, val_loader, test_loader = build_dataloader(
        dataset_name=cfg.dataset,
        data_dir=cfg.data_dir,
        batch_size=cfg.batch_size,
        num_workers=min(cfg.num_workers, os.cpu_count()),
        loop=False,
        normalize=normalize,
        seed=cfg.seed,
    )
    # Create wandb logger
    cfg_container = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    logger_wandb = WandbLogger(
        name=cfg.tag,
        project=cfg.project_name,
        group=cfg.group_tag,
        offline=not cfg.wandb,
        config=cfg_container,
        reinit=True,
        save_dir=run_dir,
        settings=wandb.Settings(start_method="thread"),
    )
    # Load or create model
    if cfg.load_and_eval:
        model = load_from_checkpoint(
            run_dir,
            load_fn=SpnGenerative.load_from_checkpoint,
            args=cfg,
        )
    else:
        if cfg.classification:
            model = SpnDiscriminative(cfg, steps_per_epoch=len(train_loader))
        else:
            model = SpnGenerative(cfg, steps_per_epoch=len(train_loader))

        if cfg.torch_compile:  # Doesn't seem to work with einsum yet
            # Raise an error since einsum doesn't seem to work with compilation yet
            # model = torch.compile(model)
            raise NotImplementedError("Torch compilation not yet supported with einsum.")

        if cfg.mixture:
            # If we chose a mixture of einets, we need to initialize the mixture weights
            logger.info("Initializing Einet mixture weights")
            model.spn.initialize(dataloader=train_loader, device=devices[0])

        if cfg.init_leaf_data:
            logger.info("Initializing leaf distributions from data statistics")
            init_einet_stats(model.spn, train_loader)
    # Setup callbacks
    callbacks = []

    # Store number of model parameters
    summary = ModelSummary(max_depth=-1)
    callbacks.append(summary)

    # Add StochasticWeightAveraging callback
    if cfg.swa:
        swa_callback = StochasticWeightAveraging()
        callbacks.append(swa_callback)

    # Enable rich progress bar
    if not cfg.debug:
        # Cannot "breakpoint()" in the training loop when RichProgressBar is active
        callbacks.append(RichProgressBar())
    # Create trainer
    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        logger=logger_wandb,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        precision=cfg.precision,
        fast_dev_run=cfg.debug,
        profiler=cfg.profiler,
        default_root_dir=run_dir,
        enable_checkpointing=False,
        detect_anomaly=cfg.debug,
    )
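    # Checkpointing is disabled in the Trainer above; the final model is saved
    # explicitly via trainer.save_checkpoint() at the end of main().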
    if not cfg.load_and_eval:
        # Fit model
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    logger.info("Evaluating model...")

    if "synth" in cfg.dataset and not cfg.classification:
        plot_distribution(model=model.spn, dataset_name=cfg.dataset, logger_wandb=logger_wandb)

    # Evaluate the model on the train/val/test splits
    trainer.test(model=model, dataloaders=[train_loader, val_loader, test_loader], verbose=True)
    logger.info("Finished evaluation...")

    # Save checkpoint in the run directory so it can be reused across experiments
    chpt_path = os.path.join(run_dir, "model.pt")
    logger.info("Saving checkpoint: " + chpt_path)
    trainer.save_checkpoint(chpt_path)


def preprocess_cfg(cfg: DictConfig):
    """
    Preprocesses the config file.

    Replaces defaults if not set (such as data/results dir).

    Args:
        cfg: Config file.
    """
    home = os.getenv("HOME")

    # If data dir is not set, get from ENV, else take ~/data
    if "data_dir" not in cfg:
        cfg.data_dir = os.getenv("DATA_DIR", os.path.join(home, "data"))

    # If results dir is not set, get from ENV, else take ~/results
    if "results_dir" not in cfg:
        cfg.results_dir = os.getenv("RESULTS_DIR", os.path.join(home, "results"))

    # If FP16/FP32 is given, convert to int (else it's "bf16", keep string)
    if cfg.precision == "16" or cfg.precision == "32":
        cfg.precision = int(cfg.precision)

    if "profiler" not in cfg:
        cfg.profiler = None  # Accepted by PyTorch Lightning Trainer class

    if "tag" not in cfg:
        cfg.tag = None

    if "group_tag" not in cfg:
        cfg.group_tag = None

    if "seed" not in cfg:
        cfg.seed = int(time.time())

    # Convert dist string to enum
    cfg.dist = Dist[cfg.dist.upper()]
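

# Example invocation (option values are illustrative; the available keys and
# defaults are defined in ./conf/config.yaml):
#   python main_pl.py dataset=mnist epochs=10 batch_size=64 gpu=0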
@hydra.main(version_base=None, config_path="./conf", config_name="config")
def main_hydra(cfg: DictConfig):
    try:
        main(cfg)
    except Exception as e:
        logging.critical(e, exc_info=True)  # log exception info at CRITICAL log level
    finally:
        # Close wandb instance. Necessary for hydra multi-runs where main() is called multiple times
        wandb.finish()


if __name__ == "__main__":
    main_hydra()