# datasets.py (229 lines, 187 loc, 8.25 KB)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.functional import conv1d
from utils import discretize_targets, build_histogram, gaussian_fn
class ReplayDataset(Dataset):
    """Dataset of (features, rewards) observations with optional label distribution smoothing (LDS).

    ``features`` / ``rewards`` always hold the original observations.
    ``training_features`` / ``training_rewards`` hold the set refreshed by
    ``update_dataset``: either the originals, or a resample of them drawn
    according to the LDS weights computed by ``update_weights``.
    """

    def __init__(self, features=None, rewards=None, device=None):
        """Store the observations and pick the device used for the LDS weights.

        Args:
            features (torch.Tensor, optional): observations, indexed along dim 0.
            rewards (torch.Tensor, optional): targets, indexed along dim 0.
            device (optional): anything accepted by ``torch.device``. Defaults
                to "cuda" when available, otherwise "cpu".
        """
        # Keep original in order to do LDS
        self.features = features
        self.rewards = rewards
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        # Actually used for training purposes, may be different from original with LDS.
        self.training_features = features
        self.training_rewards = rewards
        # Weights tensor for LDS (None means "no LDS")
        self.weights_per_obs = None

    def __len__(self):
        """Return the number of stored observations (0 while the dataset is empty)."""
        if self.rewards is None:
            return 0
        return len(self.rewards)

    def __getitem__(self, idx):
        """Return the original (not resampled) ``(features, rewards)`` pair at ``idx``."""
        return self.features[idx], self.rewards[idx]

    def set_(self, features_2d, rewards_2d):
        """Replace the original observations with the given tensors."""
        self.features = features_2d
        self.rewards = rewards_2d

    def add(self, features, reward):
        """Append observations to the original dataset.

        Does nothing when either argument is None. When the dataset is still
        empty (constructed without tensors), the given tensors become the
        dataset instead of being concatenated onto ``None``.

        Args:
            features (torch.Tensor): new observations, concatenated along dim 0.
            reward (torch.Tensor): matching targets, concatenated along dim 0.
        """
        if features is None or reward is None:
            return
        if self.features is None:
            self.features = features
        else:
            self.features = torch.cat((self.features, features))
        if self.rewards is None:
            self.rewards = reward
        else:
            self.rewards = torch.cat((self.rewards, reward))

    def update_weights(self, kern_size=5, kern_sigma=2, reweight="sqrt_inv"):
        """Compute per-observation sampling weights for label distribution smoothing.

        Implements label distribution smoothing (Delving into Deep Imbalanced
        Regression, https://arxiv.org/abs/2102.09554): the reward histogram is
        smoothed with a gaussian kernel and inverted, so that observations in
        sparsely populated bins receive larger weights.

        Args:
            kern_size (int, optional): Gaussian kernel size. Defaults to 5.
            kern_sigma (int, optional): Gaussian kernel sigma. Defaults to 2.
            reweight (str, optional): Type of reweighting done, must be either
                "sqrt_inv" (square root of the counts before smoothing) or True
                (raw counts). Defaults to "sqrt_inv".

        Returns:
            None. The weights, normalised to sum to 1, are stored in
            ``self.weights_per_obs``.

        Raises:
            ValueError: if ``reweight`` is neither "sqrt_inv" nor True.
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if reweight not in ("sqrt_inv", True):
            raise ValueError(
                f'reweight must be "sqrt_inv" or True, got {reweight!r}'
            )

        bin_size = 0.1
        factor = 10

        # Discretize the risks (labels used later)
        flat_labels = self.rewards.flatten()
        discrete_risks = discretize_targets(flat_labels, factor)
        hist, n_bins, list_bin_edges = build_histogram(flat_labels, factor, bin_size)

        weights = hist.hist.to(self.device)
        if reweight == "sqrt_inv":
            weights = torch.sqrt(weights)

        # Apply label distribution smoothing with gaussian filter.
        # conv1d needs the kernel on the same device as its input, so the
        # kernel follows the weights that were just moved to self.device.
        kernel = gaussian_fn(kern_size, kern_sigma)[None, None].to(weights.device)
        weights = conv1d(weights[None, None], kernel, padding=(kern_size // 2))
        weights = 1 / weights

        # Map each bin edge to its smoothed inverse weight
        weight_bins = {list_bin_edges[i]: weights[0][0][i] for i in range(n_bins)}

        # This isn't slow, looping is fast, dictionaries are hashmaps
        weights_per_obs = torch.tensor([weight_bins[risk] for risk in discrete_risks])
        self.weights_per_obs = weights_per_obs / weights_per_obs.sum()

    def update_dataset(self):
        """Refresh the training tensors, resampling with the LDS weights when available.

        Attempt at making a fully tabular dataset that can be batched via
        slicing instead of individual indexing.
        """
        # If we do not do LDS: keep original dataset as the training dataset
        if self.weights_per_obs is None:
            self.training_features = self.features
            self.training_rewards = self.rewards
            return

        # If we do LDS: sample (with replacement) according to the LDS weights
        # and make that our new training dataset.
        sampled_idx = self.weights_per_obs.multinomial(
            num_samples=len(self), replacement=True
        )
        self.training_features = self.features[sampled_idx]
        self.training_rewards = self.rewards[sampled_idx]
class ValidationReplayDataset(ReplayDataset):
    """Validation set that swaps observations in and out as new ones arrive.

    Behaviour depends on ``valtype``:

    * "extrema": the set holds two observations, the minimum (index 0) and
      the maximum (index 1) reward seen so far.
    * "bins": the set holds one observation per reward bin of width 0.1.

    ``update`` returns whatever should go to the training set instead.
    """

    def __init__(self, features=None, rewards=None, valtype=None):
        """Create the validation set, optionally filled with initial tensors.

        Args:
            features (torch.Tensor, optional): initial observations.
            rewards (torch.Tensor, optional): initial targets.
            valtype (str, optional): "extrema" or "bins"; any other value makes
                ``update`` pass every new observation straight through.
        """
        # Initialise the attributes the inherited methods rely on
        # (device, training tensors, LDS weights).
        super().__init__()
        self.valtype = valtype
        self.bins_features = {}
        self.bins_rewards = {}
        if features is not None and rewards is not None:
            self.set_(features, rewards)

    def build_mapping(self):
        """Fill the bin dictionaries from the current tensors.

        When several observations fall into the same bin, the last one wins.
        NOTE(review): existing bin entries are not cleared first, so bins from
        an earlier ``set_`` call survive unless overwritten - confirm intended.
        """
        for i in range(len(self.rewards)):
            reward = self.rewards[i].unsqueeze(0)
            vec = self.features[i].unsqueeze(0)
            # Truncate to one decimal to get the bin key
            bin_ = int(reward * 10) / 10
            self.bins_features[bin_] = vec
            self.bins_rewards[bin_] = reward

    def __len__(self):
        """Return the number of stored observations (0 while the set is empty)."""
        if self.rewards is None:
            return 0
        return len(self.rewards)

    def set_(self, features_2d, rewards_2d):
        """Replace the stored tensors and, for valtype "bins", rebuild the bin mapping.

        The tensors are stored as given; no device transfer happens here.
        """
        self.features = features_2d
        self.rewards = rewards_2d
        if self.valtype == "bins":
            self.build_mapping()

    def _swap_in(self, slot, features, reward):
        """Overwrite row ``slot`` with the new observation and return the evicted pair."""
        evicted = (
            self.features[slot].clone().unsqueeze(0),
            self.rewards[slot].clone().unsqueeze(0),
        )
        self.features[slot] = features[0]
        self.rewards[slot] = reward[0]
        return evicted

    def update(self, features, reward):
        """Offer a new observation to the validation set.

        Args:
            features (torch.Tensor): the new observation; row 0 is the one stored.
            reward (torch.Tensor): its reward; must hold a single value.

        Returns:
            tuple: the ``(features, reward)`` pair that should go to the
            training set - the new observation itself when it was not kept, or
            the observation it displaced. For valtype "bins" this is
            ``(None, None)`` when the bin was previously empty.
        """
        to_training = (features, reward)
        if self.valtype == "extrema":
            # If valtype is extrema, then the set is composed of only 2 observations,
            # the min and max of rewards.
            # Assumes sorted tensors based on self.rewards: index 0 = min, index 1 = max.
            new_reward = reward.item()
            if new_reward < self.rewards.min().item():
                to_training = self._swap_in(0, features, reward)
            # A new maximum replaces slot 1. This branch used to repeat the
            # "< min" test and was therefore unreachable.
            elif new_reward > self.rewards.max().item():
                to_training = self._swap_in(1, features, reward)
        elif self.valtype == "bins":
            # Update the bins with the new observation
            new_obs_bin = int(reward * 10) / 10
            # Send the observation that was already in the bin to the training set
            to_training = (
                self.bins_features.get(new_obs_bin, None),
                self.bins_rewards.get(new_obs_bin, None),
            )
            # Update the bin
            self.bins_features[new_obs_bin] = features
            self.bins_rewards[new_obs_bin] = reward
            # Rebuild the flat tensors from the per-bin entries
            self.rewards = torch.cat(list(self.bins_rewards.values()))
            self.features = torch.cat(list(self.bins_features.values()))
        return to_training
class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load. Must be positive.
        :param shuffle: if True, the stored tensors are replaced by randomly
            permuted copies whenever an iterator is created out of this
            object. The tensors passed in by the caller are not modified.
        :raises ValueError: if no tensor is given, the tensors differ in
            length along dim 0, or batch_size is not positive.
        :returns: A FastTensorDataLoader.
        """
        if not tensors:
            raise ValueError("FastTensorDataLoader needs at least one tensor")
        if any(t.shape[0] != tensors[0].shape[0] for t in tensors):
            raise ValueError("all tensors must have the same length along dim 0")
        # A non-positive batch size would either divide by zero below or make
        # __next__ never reach the end of the data.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.tensors = tensors
        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Calculate # batches (the last one may be smaller than batch_size)
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

        # Start position, so that next() works even before iter() is called
        self.i = 0

    def __iter__(self):
        """Shuffles the dataset if needed and resets the index

        Returns:
            self
        """
        if self.shuffle:
            # One shared permutation keeps the rows of all tensors aligned
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        """Picks and returns a batch

        Raises:
            StopIteration: if end of dataset reached

        Returns:
            tuple of torch.Tensor: one slice of at most batch_size rows per
            stored tensor
        """
        if self.i >= self.dataset_len:
            raise StopIteration
        start, stop = self.i, self.i + self.batch_size
        batch = tuple(t[start:stop] for t in self.tensors)
        self.i = stop
        return batch

    def __len__(self):
        """Return the number of batches per pass over the data."""
        return self.n_batches