unet_evaluation.py
import logging
import sys
import tempfile

import torch
from ignite.engine import Engine

import monai
from monai.data import decollate_batch
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.transforms import (
    AsDiscrete,
    Compose,
    SaveImage,
)
from data import setup_data


def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create the same UNet architecture that was used for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=5,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
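    # NOTE: these hyper-parameters (in/out channels, channels, strides, num_res_units)
    # must match the network built in the training script; otherwise CheckpointLoader
    # below cannot map the saved state_dict onto this model.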
    # ROI size and batch size for sliding window inference
    roi_size = (64, 64, 64)
    sw_batch_size = 4

    # post transform: argmax over the 5 output channels to get a discrete label mask
    post_trans = Compose([AsDiscrete(argmax=True)])

    # save predicted masks under ./tempdir (the literal folder name, independent of
    # the temporary directory passed in as `tempdir`)
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")
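    # With MONAI's default SaveImage settings, each prediction is typically written to
    # its own subfolder, e.g. ./tempdir/<image_name>/<image_name>_seg.nii.gz
    # (exact layout may vary across MONAI versions).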
    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
            # run sliding window inference over the whole volume
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            # decollate the batch and apply the argmax post transform per item
            seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
            for seg_prob in seg_probs:
                save_image(seg_prob)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)
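    # the Engine stores whatever the processor returns in engine.state.output;
    # MeanDice with its default output_transform reads that (y_pred, y) pair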
    # evaluation metric: mean Dice over the validation set
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss, so only metrics are printed
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no loss to print, so disable per-iteration output
    )
    val_stats_handler.attach(evaluator)

    # load the model trained by "unet_training"
    CheckpointLoader(load_path="./runs_dict/net_checkpoint_600.pt", load_dict={"net": net}).attach(evaluator)
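    # CheckpointLoader restores the checkpoint into `net` when the evaluator run starts,
    # so the trained weights are in place before the first iteration; the path assumes
    # the training run saved its checkpoints under ./runs_dict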
    # sliding window inference on one image per iteration
    val_loader = setup_data(data_dir='/data/to_huairuo/ski10train/val_nii', train=False)
    state = evaluator.run(val_loader)
    print(state)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
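
# Usage sketch: with data.py providing setup_data, the validation volumes at the
# hard-coded path above, and the checkpoint at ./runs_dict/net_checkpoint_600.pt,
# the script can be run directly:
#   python unet_evaluation.py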