#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 23 17:51:56 2020
@author: melike
"""
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import os.path as osp
import time
import argparse
from model import APNet
from dataset import RSDataset
import constants as C
"""
Usage:
python train.py --name reykjavik \
--split horizontal
"""
# Command-line argument parsing is currently disabled; the dataset name and split are hard-coded in __main__ below.
# parser = argparse.ArgumentParser(description='Train an APNet model')  # Parse command-line arguments
# parser.add_argument('--name', required=True)
# parser.add_argument('--split', required=True)
# args = parser.parse_args()
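# Sketch (assumption, not active in this script): if the argparse block above were re-enabled,
# the CLI arguments from the Usage note could replace the hard-coded dataset settings, e.g.
#   train_set = RSDataset(name=args.name, mode='train', split=args.split)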
train_transforms = {  # Transformations to be applied on the train set to augment data.
    'original': None,
    'hor': transforms.Compose([transforms.RandomHorizontalFlip(p=1)]),
    'ver': transforms.Compose([transforms.RandomVerticalFlip(p=1)]),
    'mirror': transforms.Compose([transforms.RandomHorizontalFlip(p=1),
                                  transforms.RandomVerticalFlip(p=1)]),
    'rot45': transforms.Compose([transforms.RandomRotation(degrees=[45, 45])]),
    'rot135': transforms.Compose([transforms.RandomRotation(degrees=[135, 135])])
}
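# Sketch (assumption: that RSDataset accepts a `transform` argument is not confirmed by this file):
# an augmented copy of the training set could be built with one of the variants above, e.g.
#   aug_set = RSDataset(name='pavia_full', mode='train', split='original',
#                       transform=train_transforms['hor'])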
def train(model, criterion, optimizer, model_name, train_loader, max_epochs, device, loss_file):  # Saves best model and last epoch model.
    is_better = True
    best_loss = float('inf')
    model.train()
    for epoch in range(max_epochs):
        batch_loss = 0                                      # Accumulated loss over the epoch
        t_start = time.time()
        for batch_samples, batch_labels in train_loader:
            batch_samples, batch_labels = batch_samples.to(device), batch_labels.to(device)
            optimizer.zero_grad()                           # Set grads to zero
            output = model(batch_samples)                   # Feed input to model
            loss = criterion(output, batch_labels)          # Calculate loss
            loss.backward()                                 # Calculate grads via backprop
            optimizer.step()                                # Update weights
            batch_loss += loss.item()
        delta = time.time() - t_start
        is_better = batch_loss < best_loss
        if is_better:                                       # Keep a checkpoint of the lowest-loss model so far
            best_loss = batch_loss
            torch.save(model.state_dict(), osp.join(C.MODEL_DIR, model_name + 'best.pth'))
        msg = "Epoch #{}\tLoss: {:.4f}\t Time: {:.2f} seconds".format(epoch, batch_loss, delta)
        if epoch % 10 == 0:
            print(msg)
        loss_file.write(msg + "\n")
    torch.save(model.state_dict(), osp.join(C.MODEL_DIR, model_name + 'last_epoch.pth'))  # Save the final-epoch model
if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()                    # Use GPU if available
    device = torch.device("cuda:0" if use_cuda else "cpu")
    params = {'batch_size': 50,                             # Training parameters
              'shuffle': True,
              'num_workers': 4}
    max_epochs = 100
    train_set = RSDataset(name='pavia_full', mode='train', split='original')
    train_loader = DataLoader(train_set, **params)
    model = APNet(*(train_set[0][0].shape), num_classes=train_set.num_classes).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    model_name = train_set.get_model_name()
    # Open a loss-log file for train() to write per-epoch losses (the file name here is illustrative).
    with open(osp.join(C.MODEL_DIR, model_name + 'loss.txt'), 'w') as loss_file:
        train(model, criterion, optimizer, model_name, train_loader, max_epochs, device, loss_file)
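    # Example (sketch): the best checkpoint saved by train() can later be reloaded with the
    # standard PyTorch API for evaluation or further training:
    #   model.load_state_dict(torch.load(osp.join(C.MODEL_DIR, model_name + 'best.pth')))
    #   model.eval()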