-
Notifications
You must be signed in to change notification settings - Fork 4
/
trainer.py
136 lines (106 loc) · 4.86 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import random
import easytorch.vision.imageutils as imgutils
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as tmf
from PIL import Image as IMG
from easytorch import ETTrainer, Prf1a, ETMeter
from easytorch.vision import (merge_patches)
from imagedataset2d import BinaryPatchDataset
from easytorch.vision.transforms import RandomGaussJitter
from models import UNet
# Shorthand for the platform path separator, used when assembling output file paths.
sep = os.sep
class BinarySemSegImgPatchDatasetCustomTransform(BinaryPatchDataset):
    """Binary segmentation patch dataset with custom train-time augmentation.

    Test mode converts patches straight to tensors; train mode adds color
    jitter, random autocontrast, and gaussian jitter before tensor conversion.
    """

    def get_transforms(self):
        """Return the torchvision transform pipeline for the current mode."""
        if self.mode == "test":
            return tmf.Compose([tmf.ToPILImage(), tmf.ToTensor()])
        return tmf.Compose([
            tmf.ToPILImage(),
            tmf.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            tmf.RandomAutocontrast(),
            RandomGaussJitter(0.3, 0.5),
            tmf.ToTensor(),
        ])

    def __getitem__(self, index):
        """Load one patch (green channel only), mirror-pad it for context,
        and apply paired random flips to image and label during training."""
        dname, file, row_from, row_to, col_from, col_to, cache_key = self.indices[index]

        cached = self.diskcache.get(cache_key)
        img = cached.array[:, :, 1]  # Only Green Channel
        gt = cached.ground_truth[row_from:row_to, col_from:col_to]

        # Expand the patch window and mirror-pad so border patches keep full context.
        p, q, r, s, pad = imgutils.expand_and_mirror_patch(
            img.shape,
            [row_from, row_to, col_from, col_to],
            self.dataspecs[dname]['expand_by'],
        )
        if len(img.shape) == 3:
            pad = [*pad, (0, 0)]
        img = np.pad(img[p:q, r:s], pad, 'reflect')

        # Train-time augmentation: flip image and ground truth together
        # so the pixel-wise labels stay aligned.
        if self.mode == 'train' and random.uniform(0, 1) <= 0.5:
            img, gt = np.flip(img, 0), np.flip(gt, 0)
        if self.mode == 'train' and random.uniform(0, 1) <= 0.5:
            img, gt = np.flip(img, 1), np.flip(gt, 1)

        img = self.transforms(img)
        gt = self.pil_to_tensor(gt)
        return {'indices': self.indices[index], 'input': img, 'label': gt.squeeze()}
class VesselSegTrainer(ETTrainer):
    """easytorch trainer for binary vessel segmentation using a UNet model.

    Implements model construction, the per-batch train/eval iteration with
    cross-entropy loss, merged full-image prediction saving, and the
    experiment-cache/metric configuration expected by easytorch.
    """

    def _init_nn_model(self):
        # Build the UNet; input channels, class count, and width scale come from CLI args.
        self.nn['model'] = UNet(self.args['num_channel'], self.args['num_class'], reduce_by=self.args['model_scale'])

    def iteration(self, batch, **kw):
        r"""
        :param batch:
        :return: dict with keys - loss(computation graph), averages, output, metrics, predictions
        """
        inputs = batch['input'].to(self.device['gpu']).float()
        labels = batch['label'].to(self.device['gpu']).long()
        out = self.nn['model'](inputs)
        # Optional class weighting for the loss:
        #  - random_class_weights: fresh random weights in [1, 100] every iteration
        #  - class_weights: fixed weights from args, converted once and cached
        wt = None
        if self.args.get('random_class_weights') is not None:
            wt = torch.randint(1, 101, (self.args['num_class'],), device=self.device['gpu']).float()
        elif self.args.get('class_weights') is not None:
            # NOTE(review): setdefault still evaluates the tensor-construction
            # expression on every call even after the value is cached — cheap
            # here, but worth knowing.
            wt = self.cache.setdefault('class_weights', torch.from_numpy(
                np.array(self.args.get('class_weights'))
            ).float().to(self.device['gpu']))
        loss = F.cross_entropy(out, labels, weight=wt)
        out = F.softmax(out, 1)
        # Hard predictions: argmax over the class dimension.
        _, pred = torch.max(out, 1)
        meter = self.new_meter()
        meter.averages.add(loss.item(), len(inputs))
        if self.args['num_class'] == 2:
            meter.metrics['prf1a'].add(pred, labels.float())
        else:
            # NOTE(review): new_meter() below only registers 'prf1a', so this
            # multiclass branch would raise KeyError unless a 'cfm' metric is
            # added elsewhere (e.g. in a subclass) — confirm before running
            # with num_class > 2.
            meter.metrics['cfm'].add(pred, labels.float())
        return {'loss': loss, 'output': out, 'meter': meter, 'predictions': pred, 'labels': labels}

    def save_predictions(self, dataset, its):
        # Whole-image reconstruction only makes sense when the dataloader held
        # patches of a single image; bail out otherwise.
        if not self.args.get('load_sparse'):
            return None
        """load_sparse option in default params loads patches of single image in one dataloader.
        This enables to merge them safely to form the whole image """
        # All indices belong to one image, so the first entry identifies it.
        dname, file, cache_key = dataset.indices[0][0], dataset.indices[0][1], dataset.indices[0][-1]
        dspec = dataset.dataspecs[dname]
        obj = dataset.diskcache.get(cache_key)
        """
        Auto gather all the predicted patches of one image and merge together by calling as follows."""
        img_shape = obj.array.shape[:2]
        # NOTE(review): its['output'] appears to be a callable that gathers the
        # accumulated batch outputs — confirm against easytorch's accumulator API.
        # Channel 1 is the foreground (vessel) probability, scaled to 0-255.
        patches = its['output']()[:, 1, :, :].cpu().numpy() * 255
        img = merge_patches(patches, img_shape, dspec['patch_shape'], dspec['patch_offset'])
        _dset_dir = self.cache['log_dir']
        # Pooled runs write each dataset's outputs into its own subdirectory.
        if self.args.get('pooled_run'):
            _dset_dir = f"{_dset_dir}{sep}{dspec['name']}"
            os.makedirs(_dset_dir, exist_ok=True)
        IMG.fromarray(img).save(_dset_dir + sep + file.split('.')[0] + '.png')
        # Also merge the hard (argmax) predictions and score them against the
        # full-image ground truth.
        patches = its['predictions']().cpu().numpy() * 255
        pred = merge_patches(patches, img_shape, dspec['patch_shape'], dspec['patch_offset'])
        sc = Prf1a()
        sc.add(torch.Tensor(pred), torch.Tensor(obj.ground_truth))
        return ETMeter(prf1a=sc)

    def init_experiment_cache(self):
        # Track F1 as the monitored metric; higher is better.
        self.cache.update(monitor_metric='f1', metric_direction='maximize')
        self.cache.update(log_header='Loss|Accuracy,F1,Precision,Recall')

    def new_meter(self):
        # Fresh meter per iteration carrying a precision/recall/F1/accuracy tracker.
        return ETMeter(
            prf1a=Prf1a()
        )