main.py
from __future__ import print_function
import os
import sys
import time
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
from datasets import list_dataset
from datasets.ava_dataset import Ava
from core.optimization import *
from cfg import parser
from core.utils import *
from core.region_loss import RegionLoss, RegionLoss_Ava
from core.model import YOWO, get_fine_tuning_parameters
####### Load configuration arguments
# ---------------------------------------------------------------
args = parser.parse_args()
cfg = parser.load_config(args)
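# Typical invocation (the exact flag name depends on cfg/parser.py; shown here as an assumption):
#   python main.py --cfg cfg/ava.yaml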
####### Check backup directory, create if necessary
# ---------------------------------------------------------------
if not os.path.exists(cfg.BACKUP_DIR):
    os.makedirs(cfg.BACKUP_DIR)
####### Create model
# ---------------------------------------------------------------
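# YOWO fuses a 2D backbone (keyframe) with a 3D backbone (clip); the model is
# moved to GPU and wrapped in DataParallel so multiple GPUs can be used.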
model = YOWO(cfg)
model = model.cuda()
model = nn.DataParallel(model, device_ids=None) # in multi-gpu case
# print(model)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging('Total number of trainable parameters: {}'.format(pytorch_total_params))
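# Seed the PyTorch RNGs with the current time, so every run starts from a
# different random state; fix `seed` to a constant for reproducible runs.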
seed = int(time.time())
torch.manual_seed(seed)
use_cuda = True
if use_cuda:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # TODO: add to config e.g. 0,1,2,3
    torch.cuda.manual_seed(seed)
####### Create optimizer
# ---------------------------------------------------------------
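# get_fine_tuning_parameters selects the parameter groups to optimize (e.g. only
# the layers being fine-tuned); Adam uses the learning rate and weight decay
# from the config.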
parameters = get_fine_tuning_parameters(model, cfg)
optimizer = torch.optim.Adam(parameters, lr=cfg.TRAIN.LEARNING_RATE, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
best_score = 0 # initialize best score
# optimizer = optim.SGD(parameters, lr=cfg.TRAIN.LEARNING_RATE/batch_size, momentum=cfg.SOLVER.MOMENTUM, dampening=0, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
####### Load resume path if necessary
# ---------------------------------------------------------------
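# Resuming restores the epoch counter, best score so far, model weights and
# optimizer state from the checkpoint, then frees the checkpoint dict.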
if cfg.TRAIN.RESUME_PATH:
    print("===================================================================")
    print('loading checkpoint {}'.format(cfg.TRAIN.RESUME_PATH))
    checkpoint = torch.load(cfg.TRAIN.RESUME_PATH)
    cfg.TRAIN.BEGIN_EPOCH = checkpoint['epoch'] + 1
    best_score = checkpoint['score']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("Loaded model score: ", checkpoint['score'])
    print("===================================================================")
    del checkpoint
####### Data loader, training scheme and loss function are different for AVA and UCF24/JHMDB21 datasets
# ---------------------------------------------------------------
dataset = cfg.TRAIN.DATASET
assert dataset == 'ucf24' or dataset == 'jhmdb21' or dataset == 'ava', 'invalid dataset'
if dataset == 'ava':
    train_dataset = Ava(cfg, split='train', only_detection=False)
    test_dataset = Ava(cfg, split='val', only_detection=False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True,
                                               num_workers=cfg.DATA_LOADER.NUM_WORKERS, drop_last=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=False,
                                              num_workers=cfg.DATA_LOADER.NUM_WORKERS, drop_last=False, pin_memory=True)

    loss_module = RegionLoss_Ava(cfg).cuda()

    train = getattr(sys.modules[__name__], 'train_ava')
    test = getattr(sys.modules[__name__], 'test_ava')

elif dataset in ['ucf24', 'jhmdb21']:
    train_dataset = list_dataset.UCF_JHMDB_Dataset(cfg.LISTDATA.BASE_PTH, cfg.LISTDATA.TRAIN_FILE, dataset=dataset,
                                                   shape=(cfg.DATA.TRAIN_CROP_SIZE, cfg.DATA.TRAIN_CROP_SIZE),
                                                   transform=transforms.Compose([transforms.ToTensor()]),
                                                   train=True, clip_duration=cfg.DATA.NUM_FRAMES, sampling_rate=cfg.DATA.SAMPLING_RATE)
    test_dataset = list_dataset.UCF_JHMDB_Dataset(cfg.LISTDATA.BASE_PTH, cfg.LISTDATA.TEST_FILE, dataset=dataset,
                                                  shape=(cfg.DATA.TRAIN_CROP_SIZE, cfg.DATA.TRAIN_CROP_SIZE),
                                                  transform=transforms.Compose([transforms.ToTensor()]),
                                                  train=False, clip_duration=cfg.DATA.NUM_FRAMES, sampling_rate=cfg.DATA.SAMPLING_RATE)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True,
                                               num_workers=cfg.DATA_LOADER.NUM_WORKERS, drop_last=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=False,
                                               num_workers=cfg.DATA_LOADER.NUM_WORKERS, drop_last=False, pin_memory=True)

    loss_module = RegionLoss(cfg).cuda()

    train = getattr(sys.modules[__name__], 'train_ucf24_jhmdb21')
    test = getattr(sys.modules[__name__], 'test_ucf24_jhmdb21')
####### Training and Testing Schedule
# ---------------------------------------------------------------
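# In evaluation mode only the test split is run once; otherwise the model is
# trained and tested every epoch, and a checkpoint is written after each test,
# flagged as best whenever the score improves.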
if cfg.TRAIN.EVALUATE:
    logging('evaluating ...')
    test(cfg, 0, model, test_loader)
else:
    for epoch in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH + 1):
        # Adjust learning rate
        lr_new = adjust_learning_rate(optimizer, epoch, cfg)

        # Train and test model
        logging('training at epoch %d, lr %f' % (epoch, lr_new))
        train(cfg, epoch, model, train_loader, loss_module, optimizer)
        logging('testing at epoch %d' % (epoch))
        score = test(cfg, epoch, model, test_loader)

        # Save the model to backup directory
        is_best = score > best_score
        if is_best:
            print("New best score is achieved: ", score)
            print("Previous score was: ", best_score)
            best_score = score

        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'score': score
        }
        save_checkpoint(state, is_best, cfg.BACKUP_DIR, cfg.TRAIN.DATASET, cfg.DATA.NUM_FRAMES)
        logging('Weights are saved to backup directory: %s' % (cfg.BACKUP_DIR))