train_independent.py
#
# Copyright (C) 2022 Apple Inc. All rights reserved.
#
from typing import Dict
from argparse import ArgumentParser

import yaml

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn

from trainers import BackboneTrainer
from dataset import SubImageCIFAR100
from utils.net_utils import LabelSmoothing
from utils.getters import get_model, get_optimizer
from utils.schedulers import get_policy
from accelerate import Accelerator


def main(config: Dict) -> None:
    """Run training.

    :param config: A dictionary with all configurations to run training.
    :return:
    """
    torch.backends.cudnn.benchmark = True

    model = get_model(config.get('arch_params'))
    accelerator = Accelerator()
    device = accelerator.device
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model)
    model.to(device)

    trainer = BackboneTrainer()
    optimizer = get_optimizer(model, **config.get('optimizer_params'))
    data = SubImageCIFAR100(**config.get('dataset_params'))
    lr_policy = get_policy(optimizer, **config.get('lr_policy_params'))

    if config.get('label_smoothing') is None:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = LabelSmoothing(smoothing=config.get('label_smoothing'))

    train_loader = data.train_loader
    val_loader = data.val_loader

    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )

    # Training loop
    for epoch in range(config.get('epochs')):
        lr_policy(epoch, iteration=None)
        train_acc1, train_acc5, train_loss = trainer.train(
            train_loader=train_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            accelerator=accelerator,
        )
        print(
            "Train: epoch = {}, Loss = {}, Top 1 = {}, Top 5 = {}".format(
                epoch, train_loss, train_acc1, train_acc5
            )
        )
        # Save a checkpoint at the end of every epoch.
        if (epoch + 1) % 1 == 0:
            torch.save({
                'epoch': epoch,
                # Unwrap the DataParallel/Accelerate wrappers so the saved
                # weights load into a plain model (model.module would fail
                # when CUDA is unavailable and the model was never wrapped).
                'model_state_dict': accelerator.unwrap_model(model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, config.get('output_model_path'))
            model = accelerator.prepare(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='Path to config file for this pipeline.')
    args = parser.parse_args()

    with open(args.config) as f:
        read_config = yaml.safe_load(f)

    main(read_config)
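

# Example usage (illustrative sketch): the YAML file passed via --config is
# expected to supply the keys read in main() above. The key names come from
# the config.get(...) calls; the file path and the values shown below are
# hypothetical placeholders, not values defined by this script.
#
#   # configs/independent_cifar100.yaml  (hypothetical path)
#   arch_params: {...}          # forwarded to get_model
#   optimizer_params: {...}     # forwarded to get_optimizer
#   dataset_params: {...}       # forwarded to SubImageCIFAR100
#   lr_policy_params: {...}     # forwarded to get_policy
#   label_smoothing: 0.1        # optional; omit to use plain CrossEntropyLoss
#   epochs: 100
#   output_model_path: checkpoints/independent.pt
#
# Run with: python train_independent.py --config configs/independent_cifar100.yaml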