train.py
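
"""Training entry point for the MUVO world model.

Builds the DataModule and WorldModelTrainer from the parsed config, attaches
checkpointing, git-state, and learning-rate callbacks, then runs
`Trainer.fit` followed by `Trainer.test`. The commented-out blocks are
optional ClearML integration for experiment tracking and for fetching
datasets/pretrained models remotely.
"""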
import os
import socket
import time
from weakref import proxy

import git
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks import ModelSummary, LearningRateMonitor
from clearml import Task, Dataset, Model

from muvo.config import get_parser, get_cfg
from muvo.data.dataset import DataModule
from muvo.trainer import WorldModelTrainer


class SaveGitDiffHashCallback(pl.Callback):
    """Record the current git commit hash and working-tree diff on the trainer
    and embed them in every checkpoint, so a run can be reproduced exactly."""

    def setup(self, trainer, pl_module, stage):
        repo = git.Repo()
        trainer.git_hash = repo.head.object.hexsha
        trainer.git_diff = repo.git.diff(repo.head.commit.tree)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint['world_size'] = trainer.world_size
        checkpoint['git_hash'] = trainer.git_hash
        checkpoint['git_diff'] = trainer.git_diff


class MyModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint variant that tolerates unpicklable hyperparameters:
    if `torch.save` raises an AttributeError, the `hyper_parameters` entry is
    dropped and the checkpoint is saved again without it.

    Note: the checkpoint is written under its basename in the current working
    directory, not under the full `filepath` Lightning passes in.
    """

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        filename = filepath.split('/')[-1]
        _checkpoint = trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
        try:
            torch.save(_checkpoint, filename)
        except AttributeError as err:
            key = "hyper_parameters"
            _checkpoint.pop(key, None)
            print(f"Warning: `{key}` dropped from checkpoint; an attribute is not picklable: {err}")
            torch.save(_checkpoint, filename)

        self._last_global_step_saved = trainer.global_step

        # Notify loggers that a checkpoint was written.
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))


def main():
    args = get_parser().parse_args()
    cfg = get_cfg(args)

    # Optional ClearML experiment tracking and remote dataset retrieval.
    # task = Task.init(project_name=cfg.CML_PROJECT, task_name=cfg.CML_TASK, task_type=cfg.CML_TYPE, tags=cfg.TAG)
    # task.connect(cfg)
    # cml_logger = task.get_logger()
    #
    # dataset_root = Dataset.get(dataset_project=cfg.CML_PROJECT,
    #                            dataset_name=cfg.CML_DATASET,
    #                            ).get_local_copy()
    # data = DataModule(cfg, dataset_root=dataset_root)
    data = DataModule(cfg)

    # Fetch pretrained weights from ClearML if configured (the model id is
    # left blank here and must be filled in); otherwise train from scratch.
    input_model = Model(model_id='').get_local_copy() if cfg.PRETRAINED.CML_MODEL else None
    # input_model = cfg.PRETRAINED.PATH

    model = WorldModelTrainer(cfg.convert_to_dict(), pretrained_path=input_model)
    # model = WorldModelTrainer.load_from_checkpoint(checkpoint_path=input_model)
    # model.get_cml_logger(cml_logger)

    # Make the log directory unique per run: timestamp + hostname + experiment tag.
    save_dir = os.path.join(
        cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
    )
    logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)

    callbacks = [
        ModelSummary(),
        SaveGitDiffHashCallback(),
        LearningRateMonitor(),
        MyModelCheckpoint(
            save_dir, every_n_train_steps=cfg.VAL_CHECK_INTERVAL,
        ),
    ]

    # Lightning interprets float 0.0/1.0 as a fraction of the validation set,
    # so cast those two values; larger integers mean an absolute batch count.
    if cfg.LIMIT_VAL_BATCHES in [0, 1]:
        limit_val_batches = float(cfg.LIMIT_VAL_BATCHES)
    else:
        limit_val_batches = cfg.LIMIT_VAL_BATCHES

    replace_sampler_ddp = not cfg.SAMPLER.ENABLED
    trainer = pl.Trainer(
        # devices=cfg.GPUS,
        accelerator='auto',
        # strategy='ddp',
        precision=cfg.PRECISION,
        # sync_batchnorm=True,
        max_epochs=None,
        max_steps=cfg.STEPS,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=cfg.LOGGING_INTERVAL,
        val_check_interval=cfg.VAL_CHECK_INTERVAL * cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
        check_val_every_n_epoch=None,
        # limit_val_batches=limit_val_batches,
        limit_val_batches=3,  # NOTE: hard-coded override; `limit_val_batches` computed above is unused.
        # use_distributed_sampler=replace_sampler_ddp,
        accumulate_grad_batches=cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
        num_sanity_val_steps=2,
        profiler='simple',
    )

    trainer.fit(model, datamodule=data)
    trainer.test(model, dataloaders=data.test_dataloader())


if __name__ == '__main__':
    main()
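

# Example invocation (a sketch; the exact flags depend on `muvo.config.get_parser`,
# and the config path below is a placeholder, not a file shipped with this script):
#
#   python train.py --config-file muvo/configs/your_experiment.yml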