forked from carlini/nn_breaking_detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup_cifar.py
executable file
·98 lines (77 loc) · 3.04 KB
/
setup_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
## setup_cifar.py -- code to set up the CIFAR dataset
##
## Copyright (C) 2017, Nicholas Carlini <[email protected]>.
##
## This program is licensed under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
import tensorflow as tf
import numpy as np
import os
import pickle
import gzip
import pickle
import urllib.request
from resnet import ResnetBuilder
from keras.layers import Dropout
def load_batch(fpath, label_key="labels"):
    """Load one batch file of the pickled ("python version") CIFAR distribution.

    NOTE: this module defines ``load_batch`` a second time further down (for the
    binary ``.bin`` format). At import time that later definition replaces this
    one, so this loader is only reachable if the later one is renamed/removed.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file;
    only call this on batch files obtained from a trusted source.

    Args:
        fpath: path to a pickled dict whose keys (bytes or str) include
            ``"data"`` and ``label_key``.
        label_key: name of the label entry to read (default ``"labels"``).

    Returns:
        (images, labels):
            images -- float32 array of shape (N, 32, 32, 3), computed as
                pixel / 255 - 0.5 (i.e. in [-0.5, 0.5] assuming uint8 pixels).
            labels -- float64 one-hot array of shape (N, 10).
    """
    with open(fpath, "rb") as f:
        raw = pickle.load(f, encoding="bytes")

    # Build a new str-keyed dict instead of deleting/inserting keys while
    # iterating over the same dict, which breaks iteration in Python 3.
    d = {
        (k.decode("utf8") if isinstance(k, bytes) else k): v
        for k, v in raw.items()
    }

    data = np.asarray(d["data"])
    data = data.reshape(data.shape[0], 3, 32, 32)
    # Channel-first (N, 3, 32, 32) -> channel-last (N, 32, 32, 3); make a
    # C-contiguous float32 copy so the in-place ops below don't touch `data`.
    final = np.ascontiguousarray(data.transpose(0, 2, 3, 1), dtype=np.float32)
    final /= 255
    final -= .5

    labels = d[label_key]
    one_hot = np.zeros((len(labels), 10))
    one_hot[np.arange(len(labels)), labels] = 1
    # The original computed this one-hot array but returned the raw labels;
    # return the one-hot form, matching the binary-format loader.
    return final, one_hot
def load_batch(fpath, num_records=10000):
    """Load one batch file of the binary CIFAR-10 distribution.

    Each record is 1 label byte followed by 32*32*3 pixel bytes stored
    channel-first (3, 32, 32).

    Args:
        fpath: path to a ``*.bin`` batch file.
        num_records: number of records to read from the start of the file
            (default 10000, the size of a CIFAR-10 batch). ``None`` reads every
            complete record in the file.

    Returns:
        (images, labels):
            images -- float64 array of shape (N, 32, 32, 3), computed as
                pixel / 255 - 0.5, so values lie in [-0.5, 0.5].
            labels -- float64 one-hot array of shape (N, 10).

    Raises:
        ValueError: if the file holds fewer than ``num_records`` records.
    """
    record_size = 32 * 32 * 3 + 1
    with open(fpath, "rb") as fh:
        raw = fh.read()

    available = len(raw) // record_size
    if num_records is None:
        num_records = available
    elif available < num_records:
        raise ValueError(
            "%s holds %d complete records, expected %d"
            % (fpath, available, num_records)
        )

    # np.frombuffer replaces the deprecated binary mode of np.fromstring and
    # lets us decode all records at once instead of looping in Python.
    records = np.frombuffer(
        raw, dtype=np.uint8, count=num_records * record_size
    ).reshape(num_records, record_size)

    labels = np.identity(10)[records[:, 0]]
    images = records[:, 1:].reshape(num_records, 3, 32, 32).transpose(0, 2, 3, 1)
    # Keep the C-contiguous float64 layout the per-record version produced.
    images = np.ascontiguousarray(images / 255 - .5)
    return images, labels
class CIFAR:
    """CIFAR-10 (binary distribution) split into train/validation/test arrays.

    If ``<root>/cifar-10-batches-bin`` does not exist, the binary archive is
    downloaded from the University of Toronto and extracted under ``root``.

    Args:
        root: dataset directory (default ``"../Datasets/CIFAR-10"``).
        validation_size: number of leading training examples held out as the
            validation set (default 5000).

    Attributes set:
        train_data, train_labels: the five training batches minus the first
            ``validation_size`` examples (data cast to float32).
        validation_data, validation_labels: the first ``validation_size``
            training examples.
        test_data, test_labels: the contents of ``test_batch.bin`` as returned
            by ``load_batch``.
    """

    _URL = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"

    def __init__(self, root="../Datasets/CIFAR-10", validation_size=5000):
        batch_dir = os.path.join(root, "cifar-10-batches-bin")
        # Check for the extracted batches rather than just `root`, so an
        # interrupted download/extraction is retried instead of crashing later.
        if not os.path.isdir(batch_dir):
            self._download(root)

        data_parts = []
        label_parts = []
        for i in range(1, 6):
            images, labels = load_batch(
                os.path.join(batch_dir, "data_batch_" + str(i) + ".bin")
            )
            data_parts.append(images)
            label_parts.append(labels)
        train_data = np.concatenate(data_parts).astype(np.float32)
        train_labels = np.concatenate(label_parts)

        self.test_data, self.test_labels = load_batch(
            os.path.join(batch_dir, "test_batch.bin")
        )

        self.validation_data = train_data[:validation_size]
        self.validation_labels = train_labels[:validation_size]
        self.train_data = train_data[validation_size:]
        self.train_labels = train_labels[validation_size:]

    @classmethod
    def _download(cls, root):
        """Download the CIFAR-10 binary archive into `root` and extract it there."""
        import tarfile

        # makedirs also creates missing parents (os.mkdir failed if
        # ../Datasets did not exist) and tolerates an existing directory.
        os.makedirs(root, exist_ok=True)
        archive = os.path.join(root, "cifar-data.tar.gz")
        urllib.request.urlretrieve(cls._URL, archive)
        # Extract with the tarfile module instead of shelling out to `tar`:
        # no shell involved, and failures raise instead of being ignored.
        with tarfile.open(archive, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths / path traversal where supported.
                tar.extractall(root, filter="data")
            else:
                tar.extractall(root)
class CIFARModel:
    """ResNet-32 classifier over 32x32 images with 3 channels.

    Args:
        restore: optional path to a weights file, loaded via
            ``model.load_weights``.
        session: not used by this class; presumably accepted so callers that
            pass a TF session keep working -- TODO confirm against callers.
        Dropout: dropout layer class forwarded to
            ``ResnetBuilder.build_resnet_32`` (defaults to keras' ``Dropout``).
        num_labels: number of output classes (default 10).
    """

    def __init__(self, restore=None, session=None, Dropout=Dropout, num_labels=10):
        self.num_channels = 3
        self.image_size = 32
        self.num_labels = num_labels

        # NOTE(review): activation=False presumably makes the network emit
        # logits rather than softmax probabilities -- confirm in resnet.py.
        model = ResnetBuilder.build_resnet_32(
            (3, 32, 32), num_labels, activation=False, Dropout=Dropout
        )
        # Identity comparison: `!=` would dispatch to the argument's __ne__.
        if restore is not None:
            model.load_weights(restore)
        self.model = model

    def predict(self, data):
        """Call the underlying model on `data` and return its outputs."""
        return self.model(data)