train.py
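"""Training and evaluation loops for a graph regression model.

The training loss is computed on targets normalised as (target - mean) / mad;
both functions de-normalise predictions before accumulating the mean absolute
error, so the reported MAE is in the original target units.
"""
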
import time
import torch
from tqdm import tqdm


@torch.no_grad()  # gradients are not needed during evaluation
def evaluate(model, loader, criterion, device, mean, mad):
    """Evaluate the model on the validation/test set.

    Args:
        model (torch.nn.Module): The model to evaluate.
        loader (torch.utils.data.DataLoader): The validation set loader.
        criterion (torch.nn.Module): The loss function.
        device (torch.device): The device to use.
        mean (float): The mean of the training set targets.
        mad (float): The mean absolute deviation of the training set targets.

    Returns:
        float: The mean absolute error on the validation set.
    """
    mae = 0.0
    model.eval()
    for batch in tqdm(loader):
        batch = batch.to(device)
        target = torch.squeeze(batch.y[:, 1])
        pred = model(batch)
        # De-normalise the prediction before measuring the error; the criterion
        # is expected to sum over the batch so that the division by the dataset
        # size below yields a per-sample MAE.
        loss = criterion(pred * mad + mean, target)
        mae += loss.item()
    return mae / len(loader.dataset)


def train(model, loader, criterion, optimizer, device, mean, mad):
    """Train the model for one epoch on the training set.

    Args:
        model (torch.nn.Module): The model to train.
        loader (torch.utils.data.DataLoader): The training set loader.
        criterion (torch.nn.Module): The loss function.
        optimizer (torch.optim.Optimizer): The optimizer.
        device (torch.device): The device to use.
        mean (float): The mean of the training set targets.
        mad (float): The mean absolute deviation of the training set targets.

    Returns:
        float: The mean absolute error on the training set.
    """
    mae = 0.0
    model.train()
    for batch in tqdm(loader):
        batch = batch.to(device)
        target = torch.squeeze(batch.y[:, 1]).to(device)
        # Forward pass.
        pred = model(batch)
        # The loss is computed on normalised targets, while the reported MAE is
        # accumulated in the original units by de-normalising the prediction.
        loss = criterion(pred, (target - mean) / mad)
        mae += criterion(pred * mad + mean, target).item()
        # Reset stale gradients, backpropagate, clip, and take an optimizer step.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    return mae / len(loader.dataset)
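

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original script): it wires train()
    # and evaluate() together for graph-property regression on QM9. The dataset
    # split, PlaceholderGNN, target index, batch size, learning rate, and epoch
    # count are illustrative assumptions, not the repo's actual configuration.
    from torch_geometric.datasets import QM9
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn import global_add_pool

    class PlaceholderGNN(torch.nn.Module):
        """Stand-in model so the sketch runs end to end; use the repo's model instead."""

        def __init__(self, in_dim=11, hidden=64):
            super().__init__()
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, 1),
            )

        def forward(self, batch):
            # Per-node MLP followed by a sum over each graph -> one value per graph.
            return global_add_pool(self.mlp(batch.x), batch.batch).squeeze(-1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = QM9(root="data/QM9")
    train_set, val_set = dataset[:10000], dataset[10000:11000]
    train_loader = DataLoader(train_set, batch_size=96, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=96)

    # Normalisation statistics of the training targets (property index 1, to
    # match batch.y[:, 1] used in train() and evaluate()).
    targets = torch.cat([data.y[:, 1] for data in train_set])
    mean = targets.mean().item()
    mad = (targets - mean).abs().mean().item()

    model = PlaceholderGNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    # reduction='sum' makes the accumulated loss a summed absolute error, so the
    # division by len(loader.dataset) in train()/evaluate() yields a per-sample MAE.
    criterion = torch.nn.L1Loss(reduction="sum")

    for epoch in range(10):
        train_mae = train(model, train_loader, criterion, optimizer, device, mean, mad)
        val_mae = evaluate(model, val_loader, criterion, device, mean, mad)
        print(f"epoch {epoch:2d} | train MAE {train_mae:.4f} | val MAE {val_mae:.4f}")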