-
Notifications
You must be signed in to change notification settings - Fork 0
/
seq_dataset.py
106 lines (91 loc) · 3.54 KB
/
seq_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import json
import pydicom
import numpy as np
import torch
from typing import Callable, Optional, Tuple
from torch import Tensor
from torch.utils.data import Dataset
class SyntaxDataset(Dataset):
    """Dataset of angiography DICOM video clips labelled with SYNTAX scores.

    Records come from a JSON metadata file under ``root``. In training mode the
    records are partitioned into positives (``syntax_{artery} > 0``) and
    negatives (``syntax_{artery} == 0``); ``__getitem__`` alternates them
    (even index -> positive, odd index -> randomly drawn negative), so
    ``__len__`` is doubled to keep a 1:1 balance per epoch. In inference mode
    every record is kept, including those without videos for the artery.
    """

    def __init__(
        self,
        root: str,                 # dataset directory (videos + metadata live here)
        meta: str,                 # metadata JSON filename, relative to `root`
        train: bool,               # training mode: balanced sampling + random temporal crop
        length: int,               # number of frames per clip
        label: str,                # record field used as the label/target value
        artery: str,               # "left" or "right": selects videos_/syntax_{artery}
        inference: bool = False,   # keep records without videos and use all their videos
        transform: Optional[Callable] = None,  # optional per-clip transform
        num_videos: int = 4,       # clips sampled per record outside inference mode
    ) -> None:
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.artery = artery
        self.inference = inference
        self.transform = transform
        self.num_videos = num_videos
        with open(os.path.join(root, meta)) as f:
            dataset = json.load(f)
        if not self.inference:
            # Records with no video for the requested artery are unusable for
            # training/validation; inference keeps them (handled in __getitem__).
            dataset = [rec for rec in dataset if len(rec[f"videos_{artery}"]) > 0]
        # Every record starts with a neutral sampling weight (single pass,
        # before any split, so positives and negatives are covered alike).
        for rec in dataset:
            rec["weight"] = 1.0
        if self.train:
            self.dataset = [rec for rec in dataset if rec[f"syntax_{artery}"] > 0]
            self.negative_dataset = [rec for rec in dataset if rec[f"syntax_{artery}"] == 0]
            # Explicit check instead of `assert` (asserts are stripped under -O):
            # the two partitions must cover the dataset exactly, i.e. no record
            # may carry a negative syntax score.
            if len(self.dataset) + len(self.negative_dataset) != len(dataset):
                raise ValueError(f"syntax_{artery} contains values below zero")
        else:
            self.dataset = dataset
            self.negative_dataset = None

    def __len__(self) -> int:
        # Training doubles the length so __getitem__ can alternate
        # positive (even idx) and randomly drawn negative (odd idx) records.
        coef = 2 if self.negative_dataset else 1
        return coef * len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, float, str]:
        """Return ``(videos, label, target, weight, study_id)`` for one record.

        ``videos`` is a stack of ``num_videos`` clips (or all clips in
        inference mode), each with shape ``(length, H, W, 3)``; in inference
        mode a record with no videos yields the sentinel ``0`` instead of a
        tensor — callers must special-case it.
        """
        if self.negative_dataset:
            if idx % 2 == 0:
                rec = self.dataset[idx // 2]
            else:
                # Odd indices draw a random negative with replacement.
                neg_idx = int(torch.randint(low=0, high=len(self.negative_dataset), size=(1,)))
                rec = self.negative_dataset[neg_idx]
        else:
            rec = self.dataset[idx]
        weight = rec["weight"]
        sid = rec["study_id"]
        # Binary label: any positive score counts as positive.
        label = torch.tensor([int(rec[self.label] > 0)], dtype=torch.float32)
        # Regression target on a log scale; log1p is exact for score == 0.
        target = torch.tensor([np.log1p(rec[self.label])], dtype=torch.float32)
        nv = len(rec[f"videos_{self.artery}"])
        if self.inference:
            if nv == 0:
                # Sentinel for records without videos (kept only in inference).
                return 0, label, target, weight, sid
            seq = range(nv)  # use every available video exactly once
        else:
            # Sample `num_videos` clip indices with replacement.
            seq = [int(v) for v in torch.randint(low=0, high=nv, size=(self.num_videos,))]
        videos = []
        for vi in seq:
            video_rec = rec[f"videos_{self.artery}"][vi]
            full_path = os.path.join(self.root, video_rec["path"])
            # pixel_array is (time, H, W) or (time, W, H) — frame-major either way.
            video = pydicom.dcmread(full_path).pixel_array
            # Tile short videos until at least `length` frames are available.
            while len(video) < self.length:
                video = np.concatenate([video, video])
            t = len(video)
            if self.train:
                # Random temporal crop as augmentation.
                begin = int(torch.randint(low=0, high=t - self.length + 1, size=(1,)))
            else:
                # Deterministic center crop for evaluation.
                begin = (t - self.length) // 2
            video = video[begin:begin + self.length, :, :]
            # Replicate grayscale into 3 channels: (length, H, W, 3).
            video = torch.tensor(np.stack([video, video, video], axis=-1))
            if self.transform is not None:
                video = self.transform(video)
            videos.append(video)
        return torch.stack(videos, dim=0), label, target, weight, sid