forked from mmasana/FACIL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
memory_dataset.py
124 lines (106 loc) · 4.81 KB
/
memory_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class MemoryDataset(Dataset):
    """PyTorch dataset whose images are all held in memory (taken from ``data['x']``)."""

    def __init__(self, data, transform, class_indices=None):
        """Keep references to the labels, images, transform and class indices.

        ``data`` is indexed with the keys ``'y'`` (labels) and ``'x'`` (images).
        ``class_indices`` is only stored on the instance; nothing in this class reads it.
        """
        self.labels, self.images = data['y'], data['x']
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for the sample at ``index``."""
        # the stored entry is turned into a PIL image before the transform is applied
        picture = self.transform(Image.fromarray(self.images[index]))
        return picture, self.labels[index]
def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order.

    Args:
        trn_data: mapping with ``'x'`` (samples) and ``'y'`` (labels) for training.
            Both entries may be lists or numpy arrays; the mapping is not modified.
        tst_data: same layout as ``trn_data``, for testing; not modified either.
        num_tasks: number of tasks the classes are partitioned into.
        nc_first_task: number of classes of the first task, or ``None`` to spread the
            classes as evenly as possible over all tasks.
        validation: fraction of each class' training samples moved to the validation
            split; values ``<= 0.0`` leave the validation split empty.
        shuffle_classes: if true, the class order is shuffled with ``np.random.shuffle``.
        class_order: optional sequence of original labels. Samples whose label is not
            listed are dropped. When ``None``, ``range(number of unique training labels)``
            is used.

    Returns:
        ``(data, taskcla, class_order)`` where ``data[t]`` holds ``'name'``, ``'ncla'`` and
        the ``'trn'``/``'val'``/``'tst'`` splits (``'x'`` as numpy array, ``'y'`` as list of
        labels relative to the first class of task ``t``), ``data['ncla']`` is the total
        number of classes, ``taskcla`` is a list of ``(task, ncla)`` tuples and
        ``class_order`` is the (possibly shuffled) list that was used.

    Raises:
        AssertionError: if the requested partition is impossible or a task ends up with
            a different number of training classes than planned.
    """
    if class_order is None:
        num_classes = len(np.unique(trn_data['y']))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        # private list copy: the caller's sequence is never shuffled, and tuples/arrays work too
        class_order = list(class_order)
    if shuffle_classes:
        np.random.shuffle(class_order)

    # compute classes per task
    if nc_first_task is None:
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        # the remaining classes are divided among the other tasks, so there must be at least one of them
        assert num_tasks > 1, "nc_first_task needs at least two tasks"
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1

    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))

    # initialize data structure
    data = {}
    for tt in range(num_tasks):
        data[tt] = {
            'name': 'task-' + str(tt),
            'trn': {'x': [], 'y': []},
            'val': {'x': [], 'y': []},
            'tst': {'x': [], 'y': []},
        }

    # original label -> position in class_order (first occurrence wins, like list.index)
    class_position = {}
    for position, original_label in enumerate(class_order):
        class_position.setdefault(original_label, position)
    # position in class_order -> task that owns it
    class_task = [int((position >= cpertask_cumsum).sum()) for position in range(num_classes)]

    _fill_split(data, 'trn', trn_data, class_position, class_task, init_class)
    _fill_split(data, 'tst', tst_data, class_position, class_task, init_class)

    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"

    # validation
    if validation > 0.0:
        for tt in range(num_tasks):
            trn_x, trn_y = data[tt]['trn']['x'], data[tt]['trn']['y']
            val_x, val_y = data[tt]['val']['x'], data[tt]['val']['y']
            for cc in range(data[tt]['ncla']):
                cls_idx = list(np.where(np.asarray(trn_y) == cc)[0])
                rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
                # pop from the highest index down so earlier removals do not shift later ones
                rnd_img.sort(reverse=True)
                for idx in rnd_img:
                    val_x.append(trn_x.pop(idx))
                    val_y.append(trn_y.pop(idx))

    # convert them to numpy arrays
    for tt in range(num_tasks):
        for split in ['trn', 'val', 'tst']:
            data[tt][split]['x'] = np.asarray(data[tt][split]['x'])

    # other
    taskcla = [(tt, data[tt]['ncla']) for tt in range(num_tasks)]
    data['ncla'] = sum(ncla for _, ncla in taskcla)
    return data, taskcla, class_order


def _fill_split(data, split, split_data, class_position, class_task, init_class):
    """Append every sample of ``split_data`` to ``data[task][split]`` with a task-relative label."""
    # np.asarray gives list and ndarray labels one code path and hashable scalar items;
    # nothing is written back to split_data
    for this_image, this_label in zip(split_data['x'], np.asarray(split_data['y'])):
        position = class_position.get(this_label)
        if position is None:
            # label is not part of class_order: the sample is dropped
            continue
        this_task = class_task[position]
        data[this_task][split]['x'].append(this_image)
        data[this_task][split]['y'].append(position - init_class[this_task])