-
Notifications
You must be signed in to change notification settings - Fork 2
/
make_final_result.py
111 lines (90 loc) · 3.26 KB
/
make_final_result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import random
from pathlib import Path
from collections import defaultdict
import torch
import numpy as np
import argparse
import timm
import segmentation_models_pytorch as smp
import openslide
from inference.pni_segmentation import SegInferer
def random_seed(seed_value, use_cuda):
    """Seed the Python, NumPy and PyTorch random generators for reproducibility.

    Args:
        seed_value: Integer seed applied to every generator.
        use_cuda: When true, also seed the CUDA generators and switch cuDNN to
            deterministic mode with benchmark autotuning disabled.
    """
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    random.seed(seed_value)
    if not use_cuda:
        return
    # Seed the current GPU and then every visible GPU.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed_value)
    # Trade cuDNN autotuning speed for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def str2bool(v):
    """Interpret a command-line value as a boolean (for argparse ``type=``).

    Booleans pass through untouched. Strings are matched case-insensitively
    against common yes/no spellings.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def parse_args():
    """Parse the command-line options for this script.

    ``--root_dir``, ``--result_dir`` and ``--organ`` have no sensible default.
    They are therefore required, so that a missing path fails immediately with
    a clear argparse message instead of crashing later. ``--organ`` is limited
    to the values listed in its help text.

    Returns:
        argparse.Namespace with ``root_dir``, ``result_dir``, ``batch_size``,
        ``organ``, ``seed`` and ``use_gpu``.
    """
    parser = argparse.ArgumentParser(
        description="Train Model Organ Specific for Probability Map"
    )
    parser.add_argument(
        "--root_dir", type=str, required=True, help="Whole Slide Images Directory"
    )
    parser.add_argument(
        "--result_dir", type=str, required=True, help="Model & Result Directory"
    )
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument(
        "--organ",
        type=str,
        required=True,
        choices=("col", "pros", "pan"),
        help="col, pros, pan",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    return parser.parse_args()
def _load_seg_model(model_file, device):
    """Build the level-1 U-Net and load its trained state_dict onto ``device``.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    model = smp.Unet(
        encoder_name="timm-efficientnet-b0",
        # Every parameter is overwritten by the checkpoint below (strict load),
        # so downloading the pretrained noisy-student weights first is wasted work.
        encoder_weights=None,
        in_channels=3,
        classes=2,
    )
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state_dict = torch.load(model_file, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def _list_files(directory, extension, keep=None):
    """Return sorted full paths of files in ``directory`` whose last suffix is ``extension``.

    ``keep`` is an optional predicate on the bare file name for extra filtering.
    """
    paths = []
    for fn in sorted(os.listdir(directory)):
        if fn.split(".")[-1] != extension:
            continue
        if keep is not None and not keep(fn):
            continue
        paths.append(os.path.join(directory, fn))
    return paths


def main():
    """Turn per-slide probability maps into final segmentation results.

    For the organ selected by ``--organ``, this:
      1. Loads ``<result_dir>/seg/<organ>/level_1/checkpoint.pt`` into a U-Net.
      2. Collects the ``.svs`` slides in ``--root_dir`` whose name contains the
         title-cased organ (e.g. ``"Col"``).
      3. Collects the ``.npy`` maps in ``<result_dir>/probmap/<organ>/``.
      4. Pairs slides with maps in sorted order and thresholds each map at 0.5.
      5. Runs ``SegInferer.read_wsi_seg(slide, mask)`` on each pair.
      6. Saves each output to ``<result_dir>/final_result/<organ>/<slide name>.npy``.
    """
    import warnings

    args = parse_args()
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Only touch CUDA seeding / cuDNN flags when CUDA is actually in use.
    random_seed(args.seed, device.type == "cuda")

    model_file = os.path.join(
        args.result_dir, "seg", args.organ, "level_1", "checkpoint.pt"
    )
    model = _load_seg_model(model_file, device)
    inferer = SegInferer(args, model, device)

    # Select slides whose file name contains the organ tag.
    # NOTE(review): assumes slide names embed the title-cased organ
    # (e.g. "Col_01.svs"); confirm against the dataset layout.
    organ_tag = args.organ.title()
    wsi_fns = _list_files(args.root_dir, "svs", keep=lambda fn: organ_tag in fn)
    hmap_dir = os.path.join(args.result_dir, "probmap", args.organ)
    hmaps = _list_files(hmap_dir, "npy")

    # Slides and maps are paired purely by sorted order; a count mismatch means
    # some pairs may be misaligned and the surplus files are skipped by zip().
    if len(wsi_fns) != len(hmaps):
        warnings.warn(
            f"Found {len(wsi_fns)} '{organ_tag}' slides but {len(hmaps)} "
            f"probability maps in {hmap_dir}; sorted-order pairing may be wrong."
        )

    out_dir = Path(args.result_dir) / "final_result" / args.organ
    for wsi_fn, hmap in zip(wsi_fns, hmaps):
        name = os.path.basename(wsi_fn).split(".")[0]
        # Binarise the probability map: 1 where p >= 0.5, else 0.
        overlay = (np.load(hmap) >= 0.5).astype(np.uint8)
        slide = openslide.OpenSlide(wsi_fn)
        try:
            with torch.no_grad():
                result = inferer.read_wsi_seg(slide, overlay)
        finally:
            # Release the native slide handle even if inference fails.
            slide.close()
        ppath = out_dir / f"{name}.npy"
        ppath.parent.mkdir(parents=True, exist_ok=True)
        np.save(str(ppath), result)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()