-
Notifications
You must be signed in to change notification settings - Fork 63
/
augmentations.py
82 lines (61 loc) · 2.9 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import numpy as np
def embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset=False):
    """Embed categorical and continuous features, substituting mask embeddings.

    Categorical indices are shifted by ``model.categories_offset`` and looked
    up in ``model.embeds``; continuous column ``i`` is encoded by
    ``model.simple_MLP[i]``. Wherever ``cat_mask`` / ``con_mask`` equals 0 the
    corresponding encoding is replaced by the embedding produced by
    ``model.mask_embeds_cat`` / ``model.mask_embeds_cont``.

    Args:
        x_categ: integer tensor of categorical indices, presumably
            ``(batch, n_categ)`` — TODO confirm against caller.
        x_cont: tensor of continuous values, unpacked as ``(batch, n_cont)``.
        cat_mask: tensor indexed like ``x_categ``; 0 marks a masked entry.
        con_mask: tensor indexed like ``x_cont``; 0 marks a masked entry.
        model: object providing ``categories_offset``, ``embeds``,
            ``cont_embeddings``, ``num_continuous``, ``simple_MLP``, ``dim``,
            ``cat_mask_offset``, ``con_mask_offset``, ``mask_embeds_cat``,
            ``mask_embeds_cont`` and (if ``vision_dset``) ``pos_encodings``.
        vision_dset: if True, add ``model.pos_encodings`` of each categorical
            column position to the categorical encodings.

    Returns:
        ``(x_categ_shifted, x_categ_enc, x_cont_enc)``: the offset-shifted
        categorical indices and the two encodings.

    Raises:
        ValueError: if ``model.cont_embeddings`` is not ``'MLP'`` (a subclass
            of ``Exception``, so existing broad handlers still catch it).
    """
    device = x_cont.device
    x_categ = x_categ + model.categories_offset.type_as(x_categ)
    x_categ_enc = model.embeds(x_categ)
    n_rows, n_cont = x_cont.shape

    if model.cont_embeddings != 'MLP':
        raise ValueError('This case should not work!')

    # Allocate directly on the input's device (default dtype, as before) so
    # the per-column writes below do not trigger host<->device copies.
    x_cont_enc = torch.empty(n_rows, n_cont, model.dim, device=device)
    for i in range(model.num_continuous):
        x_cont_enc[:, i, :] = model.simple_MLP[i](x_cont[:, i])
    # NOTE(review): if model.num_continuous < n_cont, the trailing columns stay
    # uninitialised (torch.empty); presumably the two are always equal — confirm.

    # Per-column offsets are added to the mask values before lookup,
    # presumably so each column selects its own rows of the mask-embedding table.
    cat_mask_temp = model.mask_embeds_cat(cat_mask + model.cat_mask_offset.type_as(cat_mask))
    con_mask_temp = model.mask_embeds_cont(con_mask + model.con_mask_offset.type_as(con_mask))

    cat_masked = cat_mask == 0
    con_masked = con_mask == 0
    x_categ_enc[cat_masked] = cat_mask_temp[cat_masked]
    x_cont_enc[con_masked] = con_mask_temp[con_masked]

    if vision_dset:
        # Column positions 0..n_categ-1 repeated for every row (same values as
        # np.tile before), built on the target device without a host copy.
        pos = torch.arange(x_categ.shape[-1], device=device).repeat(x_categ.shape[0], 1)
        x_categ_enc += model.pos_encodings(pos)

    return x_categ, x_categ_enc, x_cont_enc
def mixup_data(x1, x2, lam=1.0, y=None, use_cuda=True):
    """Return mixup-interpolated inputs and, optionally, the paired targets.

    Every sample is blended with a partner drawn from the same batch by one
    random permutation ``perm``: ``lam * x + (1 - lam) * x[perm]``. The same
    permutation is applied to ``x1``, ``x2`` and ``y``.

    Args:
        x1: tensor whose first dimension is the batch.
        x2: tensor with the same batch size as ``x1``.
        lam: weight given to the original samples.
        y: optional targets; when given, ``y`` and ``y[perm]`` are returned too.
        use_cuda: kept for backward compatibility only. The permutation is
            now created on ``x1.device``; previously ``True`` (the default)
            forced a CUDA index, which failed for CPU tensors.

    Returns:
        ``(mixed_x1, mixed_x2)``, or ``(mixed_x1, mixed_x2, y_a, y_b)`` when
        ``y`` is given.
    """
    index = torch.randperm(x1.size(0), device=x1.device)
    # Index along the batch dimension only, so inputs of any rank >= 1 work.
    mixed_x1 = lam * x1 + (1 - lam) * x1[index]
    mixed_x2 = lam * x2 + (1 - lam) * x2[index.to(x2.device)]
    if y is None:
        return mixed_x1, mixed_x2
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x1, mixed_x2, y_a, y_b
def add_noise(x_categ, x_cont, noise_params=None):
    """Corrupt a batch of categorical and continuous features.

    ``noise_params['noise_type']`` may be a list of names (as in the default)
    or a single name as a string:

    * ``'cutmix'``: each entry is, with probability ``lambda``, replaced by
      the same entry of another row chosen by one random batch permutation.
    * ``'missing'``: each entry is, with probability ``lambda``, multiplied
      by 0.

    ``'cutmix'`` is checked first, so if both are requested only cutmix runs.

    Args:
        x_categ: categorical tensor whose first dimension is the batch.
        x_cont: continuous tensor with the same batch size.
        noise_params: dict with keys ``'noise_type'`` and ``'lambda'``.
            Defaults to ``{'noise_type': ['cutmix'], 'lambda': 0.1}``.

    Returns:
        ``(x_categ_corrupted, x_cont_corrupted)``. The inputs are not modified.

    Raises:
        NotImplementedError: if no supported noise type is requested.
            Previously this printed a message and returned ``None``, which made
            callers fail later when unpacking the result.
    """
    if noise_params is None:
        # Sentinel instead of a shared mutable dict default.
        noise_params = {'noise_type': ['cutmix'], 'lambda': 0.1}
    lam = noise_params['lambda']
    noise_type = noise_params['noise_type']
    device = x_categ.device
    batch_size = x_categ.size()[0]

    if 'cutmix' in noise_type:
        index = torch.randperm(batch_size)
        # numpy RNG as before; an entry equal to 0 (probability lam) is overwritten.
        cat_corr = torch.from_numpy(np.random.choice(2, x_categ.shape, p=[lam, 1 - lam])).to(device)
        con_corr = torch.from_numpy(np.random.choice(2, x_cont.shape, p=[lam, 1 - lam])).to(device)
        x1, x2 = x_categ[index, :], x_cont[index, :]
        x_categ_corr, x_cont_corr = x_categ.clone().detach(), x_cont.clone().detach()
        x_categ_corr[cat_corr == 0] = x1[cat_corr == 0]
        x_cont_corr[con_corr == 0] = x2[con_corr == 0]
        return x_categ_corr, x_cont_corr

    # A plain `== 'missing'` never matched the list form (e.g. ['missing']),
    # so list inputs used to fall through to the unimplemented branch.
    if isinstance(noise_type, str):
        wants_missing = noise_type == 'missing'
    else:
        wants_missing = 'missing' in noise_type

    if wants_missing:
        x_categ_mask = np.random.choice(2, x_categ.shape, p=[lam, 1 - lam])
        x_cont_mask = np.random.choice(2, x_cont.shape, p=[lam, 1 - lam])
        x_categ_mask = torch.from_numpy(x_categ_mask).to(device)
        x_cont_mask = torch.from_numpy(x_cont_mask).to(device)
        return torch.mul(x_categ, x_categ_mask), torch.mul(x_cont, x_cont_mask)

    raise NotImplementedError(f"yet to write this: unsupported noise_type {noise_type!r}")