-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
222 lines (179 loc) · 9.56 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
File: dataset.py
Implement the PyTorch Dataset to navigate through the images and masks in
the folder structure of the dataset.
TODO:
- Improve docstring for PlantDataset
- Implement a fail-safe in case the mask_idx is not found in the mask directory
- in the csv file the name of the column has a type in it, it should be 'id original img' instead of 'id orginal img'
"""
import os
import re
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from utils import extract_ids_from_name, extract_tag_from_name
class PlantDataset(Dataset):
    """
    Dataset that pairs every image of a folder with its annotation masks.

    The expected folder structure is the following:

        root
        ├── folder
        │   ├── <inventary_name>   (csv inventory)
        │   ├── original
        │   │   ├── 0_0.jpg
        │   ├── original_labeled
        │   │   ├── 0
        │   │   │   ├── task-0-annotation-0-0-0.jpg
        │   │   │   ├── task-0-annotation-0-0-1.jpg
        │   │   │   ├── task-0-annotation-0-0-2.jpg
        │   ├── original_labeled_imageJ_mask   (only read when alternative_masks=True)

    The inventory csv must provide the columns 'id orginal img',
    'id after split' and 'id'; the value of 'id' is the name of the
    sub-folder of 'original_labeled' that holds the masks of an image.

    label2id must be a dictionary with the class names and their
    corresponding integer representation. For example:

        label2id = {'normal': 0,
                    'normal_cut': 1,
                    'noise': 2}

    The label names provided in label2id must be equal to the tags that
    `extract_tag_from_name` extracts from each mask file name, otherwise a
    KeyError is raised while the dataset is being indexed.
    """

    def __init__(self,
                 root,
                 folder,
                 inventary_name,
                 transform=None,
                 label2id=None,
                 alternative_masks=False):
        """
        Index the images, masks and labels found under ``root/folder``.

        Args:
            root: base directory of the dataset.
            folder: sub-directory of ``root`` that contains the 'original'
                and 'original_labeled' directories and the inventory csv.
            inventary_name: file name of the inventory csv; it is read from
                ``root/folder/inventary_name``.
            transform: optional callable applied to the RGB PIL image in
                ``__getitem__``.
            label2id: mapping from class name to integer id. Required.
            alternative_masks: when True, the files in
                'original_labeled_imageJ_mask' are indexed as well and
                returned by ``__getitem__`` (mainly for evaluation).

        Raises:
            ValueError: if ``label2id`` is not provided.
            IndexError: if an image has no matching row in the inventory.
            KeyError: if a mask tag is not a key of ``label2id``.
        """
        # Explicit check instead of `assert`: asserts are stripped when
        # Python runs with -O, which would let a missing mapping through.
        if label2id is None:
            raise ValueError('label2id parameter must be provided '
                             '(e.g. {"normal": 0, "normal_cut": 1, "noise": 2})')

        self.root = root
        self.folder = folder
        self.transform = transform
        self.alternative_masks = alternative_masks

        # label-to-id and id-to-label dictionaries
        self._label2id = label2id
        self._id2label = {idx: name for name, idx in self._label2id.items()}

        # read inventary csv file and store it in a pandas dataframe
        self.idx_table = pd.read_csv(os.path.join(self.root, self.folder, inventary_name))

        # self.masks is a nested list: position i holds the file names of the
        # masks associated to self.images[i]. self.labels has the same layout
        # but stores the integer label of each mask. self.mask_idx_folder[i]
        # is the name of the mask sub-folder of image i.
        self.images = []
        self.masks = []
        self.labels = []
        self.mask_idx_folder = []

        # sort the image files by original image id and then by split id
        files_in_folder = os.listdir(os.path.join(self.root, self.folder, 'original'))
        files_in_folder.sort(key=lambda x: (int(re.findall(r'\d+', x.split('_')[0])[0]),
                                            int(re.findall(r'\d+', x.split('_')[1])[0])))

        # Alternative masks are provided by the dataset (computed by another
        # method) and are mainly used for evaluation purposes.
        # NOTE(review): they are paired with the images by position after
        # sorting, which presumes both folders hold matching files - confirm.
        if self.alternative_masks:
            self.benchmark_masks = []
            alternative_masks_folder = os.listdir(
                os.path.join(self.root, self.folder, 'original_labeled_imageJ_mask'))
            alternative_masks_folder.sort(key=lambda x: (int(re.findall(r'\d+', x.split('_')[0])[0]),
                                                         int(re.findall(r'\d+', x.split('_')[2])[0])))

        labeled_dir = os.path.join(self.root, self.folder, 'original_labeled')
        for i, img in enumerate(files_in_folder):
            self.images.append(img)
            if self.alternative_masks:
                self.benchmark_masks.append(alternative_masks_folder[i])

            # Get the id of the image's mask folder from the inventary.
            # TODO: in the csv file the name of the column has a typo in it,
            # it should be 'id original img' instead of 'id orginal img'
            id_original_img, id_after_split = extract_ids_from_name(img)
            matches = self.idx_table.loc[
                (self.idx_table['id orginal img'] == id_original_img) &
                (self.idx_table['id after split'] == id_after_split),
                'id']
            if matches.empty:
                raise IndexError(f"no inventory row found for image '{img}' "
                                 f"(id orginal img={id_original_img!r}, "
                                 f"id after split={id_after_split!r})")
            mask_idx = matches.iloc[0]
            self.mask_idx_folder.append(str(mask_idx))

            # TODO: implement a fail-safe in case the mask_idx folder is not
            # found in the mask directory (os.listdir raises FileNotFoundError).
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # mask/label order reproducible across runs and platforms.
            masks = sorted(os.listdir(os.path.join(labeled_dir, str(mask_idx))))
            self.masks.append(masks)
            self.labels.append([self._label2id[extract_tag_from_name(m)] for m in masks])

    def __len__(self):
        """ Return the length of the dataset (# of images) """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx`` as a dictionary.

        Keys:
            'image': the RGB image, after ``transform`` when one was given.
            'masks': list with the paths of the mask files of the image.
            'labels': list with the integer label of each mask.
            'alternative_masks': PIL image of the alternative mask; only
                present when the dataset was built with alternative_masks=True.
        """
        image_path = os.path.join(self.root, self.folder, 'original', self.images[idx])
        # convert() returns a new, fully loaded image, so the source file
        # can be closed as soon as the conversion is done
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        mask_dir = os.path.join(self.root, self.folder, 'original_labeled', self.mask_idx_folder[idx])
        sample = {
            'image': image,
            'masks': [os.path.join(mask_dir, m) for m in self.masks[idx]],
            'labels': self.labels[idx],
        }
        if self.alternative_masks:
            alternative_mask_path = os.path.join(self.root, self.folder,
                                                 'original_labeled_imageJ_mask',
                                                 self.benchmark_masks[idx])
            sample['alternative_masks'] = Image.open(alternative_mask_path)
        return sample

    def get_number_of_masks(self):
        """ Return the number of masks per image """
        return [len(m) for m in self.masks]

    def get_masks_per_labels(self):
        """
        Return the number of masks per label.

        The label of each mask is taken from its file name with a regular
        expression built from the keys of label2id: every label except the
        last one is matched together with a trailing '-', the last one is
        matched on its own. The first match found in the name is used.

        Returns:
            The tuple produced by ``np.unique(..., return_counts=True)``:
            the sorted label names and the number of masks of each one.
        """
        pattern = re.compile('-|'.join(self._label2id))
        names = [pattern.findall(m)[0].replace('-', '')
                 for image_masks in self.masks
                 for m in image_masks]
        return np.unique(names, return_counts=True)
def get_target(masks, labels, tfms, size=(250, 250), num_classes=4):
    """
    Build a per-class instance map from a list of mask files.

    Every mask is resized, binarized (pixel > 0) and multiplied by an
    instance id; the masks of each class are then summed into the channel
    of that class. Instance ids are 1-based positions in `masks`, so they
    follow the order of the inputs and are not grouped by class. The
    integer 0 is reserved for background / no class.

    Assumption: masks of the same class do not overlap. Where they do, the
    ids of the overlapping instances are added together and an extra step
    would be needed to decide which instance owns the pixel.

    Args:
        masks: list with the paths of the mask image files.
        labels: list with the integer class id of each mask, parallel to
            `masks`. Each id is used as a channel index of the output.
        tfms: callable that turns a PIL image into a tensor.
            NOTE(review): presumably returns a (1, H, W) tensor such as
            torchvision's ToTensor on a single-channel mask - confirm.
        size: (height, width) of the output.
        num_classes: number of channels of the output.

    Returns:
        A float tensor of shape (1, num_classes, size[0], size[1]); all
        zeros when `masks` is empty.
    """
    # one channel (dim 1) per class
    # NOTE: the last channel is meant for absence / no detection
    out = torch.zeros((1, num_classes, size[0], size[1]))

    # without masks there is nothing to draw
    if len(masks) == 0:
        return out

    def load_binary_mask(path):
        # PIL's resize expects (width, height) while `size` is (height, width)
        with Image.open(path) as mask_image:
            resized = mask_image.resize((size[1], size[0]))
        return torch.where(tfms(resized) > 0.0, 1.0, 0.0)

    # process the masks class by class, tag each one with its instance id
    # and collapse them into the channel of the class
    for label in set(labels):
        class_masks = torch.cat([load_binary_mask(m) * (i + 1)
                                 for i, m in enumerate(masks)
                                 if labels[i] == label])
        out[0, label, :, :] = class_masks.sum(dim=0)
    return out
def get_binary_target(masks, labels, tfms, size=(250, 250), target_id=0):
    """
    Build a binary target from the masks of a single class.

    Pixels covered by at least one mask of class `target_id` are set to 1,
    everything else (background or any other class, e.g. noise) is 0.

    By default the class to detect is the one represented by the integer 0
    in the label2id dictionary of the dataset. For example, to detect the
    class "normal":
        - dataset._label2id -> {'normal': 0, 'normal_cut': 1, 'noise': 2}

    Args:
        masks: list with the paths of the mask image files.
        labels: list with the integer class id of each mask, parallel to
            `masks`.
        tfms: callable that turns a PIL image into a tensor.
            NOTE(review): presumably returns a (C, H, W) tensor such as
            torchvision's ToTensor - confirm.
        size: (height, width) of the output.
        target_id: integer id of the class to detect (default 0).

    Returns:
        A float tensor of shape (1, 1, size[0], size[1]); all zeros when
        there are no masks or none of them belongs to `target_id`.
    """
    out = torch.zeros((1, 1, size[0], size[1]))

    # no masks at all, or no mask of the class to detect
    if len(masks) == 0 or target_id not in labels:
        return out

    target_masks = []
    for idx, label in enumerate(labels):
        if label != target_id:
            continue
        # PIL's resize expects (width, height) while `size` is (height, width)
        with Image.open(masks[idx]) as mask_image:
            resized = mask_image.resize((size[1], size[0]))
        target_masks.append(torch.where(tfms(resized) > 0.0, 1.0, 0.0))

    # union of all the masks of the target class
    out[0, 0, :, :] = torch.where(torch.cat(target_masks).sum(dim=0) > 0.0, 1.0, 0.0)
    return out