-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
31 lines (23 loc) · 1.25 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.dice_score import multiclass_dice_coeff
@torch.inference_mode()
def evaluate(net, dataloader, device, amp, n_classes=2):
    """Return the average foreground Dice score of ``net`` over ``dataloader``.

    The network is switched to eval mode for the pass and afterwards restored
    to the training mode it had on entry -- also when an exception is raised.
    Autograd is disabled for the whole call via ``torch.inference_mode``.

    Args:
        net: segmentation model, called as ``net(image)``. Its output is
            argmax-ed over dim 1, so the class axis is expected there
            (assumed (N, n_classes, H, W) -- TODO confirm against the model).
        dataloader: sized iterable (``len()`` is used) yielding dicts with
            ``'image'`` and ``'mask'`` tensors.
        device: ``torch.device`` the batches are moved to.
        amp: whether to enable ``torch.autocast`` mixed precision.
        n_classes: number of label classes including background (class 0).
            Defaults to 2, the value previously hard-coded.

    Returns:
        Sum of the per-batch Dice scores divided by the number of batches
        (0 for an empty dataloader). Class 0 (background) is excluded.

    Raises:
        AssertionError: if a ground-truth mask holds a label outside
            ``[0, n_classes)``.
    """
    # Remember the caller's mode so we restore it instead of forcing train().
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0
    # 'mps' is mapped to the CPU autocast backend, as in the original code
    # (presumably because autocast lacked mps support in the targeted torch
    # version -- confirm before changing).
    autocast_device = device.type if device.type != 'mps' else 'cpu'
    try:
        with torch.autocast(autocast_device, enabled=amp):
            for batch in tqdm(dataloader, total=num_val_batches,
                              desc='Validation round', unit='batch', leave=False):
                image, mask_true = batch['image'], batch['mask']
                image = image.to(device=device, dtype=torch.float32,
                                 memory_format=torch.channels_last)
                mask_true = mask_true.to(device=device, dtype=torch.long)
                # Validate labels before spending a forward pass on the batch;
                # F.one_hot below also requires labels in [0, n_classes).
                assert mask_true.min() >= 0 and mask_true.max() < n_classes, 'True mask indices should be in [0, n_classes['
                mask_pred = net(image)
                # convert labels and predictions to one-hot, channels-first
                mask_true = F.one_hot(mask_true, n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background (channel 0)
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:],
                                                    reduce_batch_first=False)
    finally:
        # Restore the entry mode even if validation raised.
        net.train(was_training)
    return dice_score / max(num_val_batches, 1)