-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_interfaces.py
55 lines (36 loc) · 1.3 KB
/
train_interfaces.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import config
import torch
import numpy as np
import torch.nn as nn
from DLBio.pt_training import ITrainInterface
# Cut-off that apply_thres compares predictions against when binarizing them.
EVAL_THRES = 0.5
class BinarySegmentation(ITrainInterface):
    """Train interface that pairs a model with a cross-entropy loss.

    Besides the loss, every training step also evaluates the metric
    functions stored in ``self.metrics`` on the same prediction.
    """

    def __init__(self, model):
        """Store ``model`` and create the loss and the metric table."""
        self.model = model
        self.xent_loss = nn.CrossEntropyLoss()
        # metric name -> callable(prediction, target) returning a number
        self.metrics = {'acc': accuracy, 'dice': dice_score}

    def train_step(self, sample):
        """Run the model on one batch and compute loss and metrics.

        ``sample`` is indexed with 'x' (model input) and 'y' (target); both
        entries are moved to the CUDA device before the forward pass.

        Returns a tuple ``(loss, metrics)`` where ``metrics`` maps each
        metric name to the value computed for this batch.
        """
        inputs = sample['x'].cuda()
        labels = sample['y'].cuda()

        logits = self.model(inputs)
        loss = self.xent_loss(logits, labels)

        metric_values = {}
        for name, metric_fn in self.metrics.items():
            metric_values[name] = metric_fn(logits, labels)

        return loss, metric_values
def accuracy(y_pred, y_true):
    """Return the fraction of entries whose predicted class equals ``y_true``.

    The predicted class is the index of the largest value of ``y_pred``
    along dimension 1.
    """
    # max over a dimension yields (values, indices); only the indices matter
    predicted_class = y_pred.max(dim=1)[1]
    hits = predicted_class == y_true
    return hits.float().mean().item()
def dice_score(y_pred, y_true):
    """Return the Dice score of the binarized class-1 prediction vs ``y_true``.

    ``y_pred`` must have exactly two entries along dimension 1 and the
    largest value in ``y_true`` must not exceed 1. A softmax is taken over
    dimension 1, the entry at index 1 is kept and thresholded at 0.5, and
    the resulting mask is compared with ``y_true``.
    """
    assert y_pred.shape[1] == 2
    assert y_true.max() <= 1.

    class1_prob = torch.softmax(y_pred, dim=1)[:, 1]
    mask_true = y_true.float()
    mask_pred = (class1_prob > .5).float()

    overlap = (mask_true * mask_pred).sum()
    # clamping the signed differences to [0, 1] keeps only one-sided errors
    spurious = (mask_pred - mask_true).clamp(min=0, max=1.).sum()
    missed = (mask_true - mask_pred).clamp(min=0, max=1.).sum()

    # the small constant keeps the denominator non-zero when all sums are 0
    denominator = spurious + missed + 2. * overlap + 1e-9
    score = 2. * overlap / denominator
    return score.item()
def apply_thres(pred):
    """Binarize ``pred`` against EVAL_THRES and return it as float32.

    Entries strictly greater than the threshold become 1.0, all others 0.0.
    NOTE(review): relies on ``pred`` supporting ``>`` and ``.astype`` —
    presumably a NumPy array; confirm against callers.
    """
    above = pred > EVAL_THRES
    return above.astype('float32')