-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcifar5.py
73 lines (51 loc) · 2.4 KB
/
cifar5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from copy import copy
import torch as th
from torchvision import datasets
from torchvision import transforms as trafo
def create_cifar5(root="../data/cifar10", download=True):
    """Build the CIFAR-5 in-distribution / out-of-distribution split of CIFAR-10.

    CIFAR-10 classes 0-4 form the in-distribution data. Classes 5-9 of the
    *training* split are returned as out-of-distribution (OOD) data, with
    their labels shifted down by 5 so they also lie in 0-4. Test-split
    samples of classes 5-9 are dropped.

    All three datasets use ``ToTensor`` followed by ``Normalize`` with
    statistics computed on the reduced (classes 0-4) training set.

    Args:
        root: Directory CIFAR-10 is stored in / downloaded to.
        download: Forwarded to ``torchvision.datasets.CIFAR10``.

    Returns:
        ``(train_data_reduced, test_data_reduced, ood_data_reduced)``. These are
        shallow copies of the CIFAR-10 datasets. Their ``data`` arrays are
        filtered, and their ``targets`` are ``torch.LongTensor``s.
    """
    num_known = 5  # classes [0, num_known) are in-distribution
    # Normalization computed on the reduced training set
    tr = trafo.Compose([
        trafo.ToTensor(),
        trafo.Normalize((0.4905, 0.4854, 0.4514), (0.2454, 0.2415, 0.2620)),
    ])
    # Pass the final transform at construction instead of reassigning
    # ``.transform`` on every copy afterwards; the copies inherit it.
    train_data_orig = datasets.CIFAR10(root, train=True, download=download, transform=tr)
    test_data_orig = datasets.CIFAR10(root, train=False, download=download, transform=tr)

    # Prepare training and out of distribution data
    train_labels = th.as_tensor(train_data_orig.targets, dtype=th.long)
    known = train_labels < num_known
    train_data_reduced = _select_samples(train_data_orig, train_labels, known)
    # Shift OOD labels (5-9) down so they start at 0.
    ood_data_reduced = _select_samples(
        train_data_orig, train_labels, ~known, label_offset=num_known
    )

    # Prepare reduced test data (OOD classes of the test split are discarded)
    test_labels = th.as_tensor(test_data_orig.targets, dtype=th.long)
    test_data_reduced = _select_samples(test_data_orig, test_labels, test_labels < num_known)

    return train_data_reduced, test_data_reduced, ood_data_reduced


def _select_samples(dataset, labels, keep, label_offset=0):
    """Return a shallow copy of *dataset* holding only the samples where *keep* is True.

    Args:
        dataset: Dataset exposing a NumPy-indexable ``data`` attribute.
        labels: The dataset's targets as a LongTensor, aligned with ``data``.
        keep: Boolean tensor (same length as ``labels``) selecting samples.
        label_offset: Value subtracted from the kept labels.

    Returns:
        A shallow copy whose ``data`` is ``dataset.data[keep]`` and whose
        ``targets`` is the LongTensor ``labels[keep] - label_offset``.
    """
    subset = copy(dataset)
    subset.data = dataset.data[keep.numpy()]
    subset.targets = labels[keep] - label_offset
    return subset
def create_cifar10(root="../data/cifar10", download=True):
    """Load the full 10-class CIFAR-10 train and test splits.

    Both splits use ``ToTensor`` followed by ``Normalize`` with the
    per-channel mean/std hard-coded below.

    Args:
        root: Directory CIFAR-10 is stored in / downloaded to.
        download: Forwarded to ``torchvision.datasets.CIFAR10``.

    Returns:
        ``(train_data, test_data, 0)``. No OOD dataset exists for this
        configuration. The trailing ``0`` is a placeholder, presumably so
        callers can unpack the result like ``create_cifar5``'s 3-tuple.
    """
    # One shared (stateless) pipeline instead of two identical copies.
    tr = trafo.Compose([
        trafo.ToTensor(),
        trafo.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),
    ])
    train_data = datasets.CIFAR10(root, train=True, download=download, transform=tr)
    test_data = datasets.CIFAR10(root, train=False, download=download, transform=tr)
    print("WARNING: There is no OOD Loader here")
    return train_data, test_data, 0