-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_database.py
73 lines (60 loc) · 2.41 KB
/
prepare_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import pickle
import re
from datetime import datetime, timezone

import numpy as np
from PIL import Image
from pymongo import MongoClient
def insert_attack_types(db, benchmark):
    """Register the known attack types in the ``db.attack`` collection.

    Each attack type is upserted keyed on its ``type`` field, so calling this
    function more than once (e.g. once per benchmark) leaves exactly one
    ``{"type": ..., "category": ...}`` document per attack type.

    Args:
        db: pymongo database handle.
        benchmark: unused; attack types are not tied to a benchmark. Kept
            so existing callers keep working.
    """
    attack_types = [
        ("lambda", "backdoor"),
        ("rectangle", "backdoor"),
        ("random rectangle", "backdoor"),
        ("noise", "backdoor"),
        ("spread-out", "backdoor"),
    ]
    for attack_type, category in attack_types:
        # Upsert rather than insert so repeated calls do not create duplicates.
        db.attack.update_one(
            {"type": attack_type},
            {"$set": {"category": category}},
            upsert=True,
        )
def insert_classnames(db, benchmark):
    """Store the class names of *benchmark* in ``db.class_names``.

    For ``"MNIST"`` the labels are the integers 0-9. Every other benchmark
    reads its labels from ``<benchmark>/label_names.npy``. One document
    ``{"benchmark", "class_id", "class_label"}`` is stored per class, where
    ``class_id`` is the label's position and ``benchmark`` is the string form
    of the benchmark document's ``_id``.

    Args:
        db: pymongo database handle.
        benchmark: benchmark name. It must already exist in ``db.benchmarks``
            and, unless it is ``"MNIST"``, also names the directory that
            holds ``label_names.npy``.

    Raises:
        ValueError: if *benchmark* is not registered in ``db.benchmarks``.
    """
    benchmark_obj = db.benchmarks.find_one({"benchmark": benchmark})
    if benchmark_obj is None:
        raise ValueError(f"benchmark {benchmark!r} not found in db.benchmarks")
    benchmark_id = str(benchmark_obj["_id"])

    if benchmark == "MNIST":
        label_names = list(range(10))
    else:
        # .tolist() turns NumPy scalars (np.str_, np.int64, ...) into native
        # Python objects; BSON cannot encode NumPy integer scalars.
        label_names = np.load(os.path.join(benchmark, "label_names.npy")).tolist()

    documents = [
        {"benchmark": benchmark_id, "class_id": class_id, "class_label": label}
        for class_id, label in enumerate(label_names)
    ]
    # insert_many raises on an empty list; an empty label file stores nothing.
    if documents:
        db.class_names.insert_many(documents)
def _natural_sort_key(name):
    """Sort key that orders embedded digit runs numerically ("2.png" < "10.png")."""
    # re.split with a capturing group puts digit runs at odd indices.
    parts = [
        int(token) if index % 2 else token
        for index, token in enumerate(re.split(r"(\d+)", name))
    ]
    # The raw name breaks ties such as "01.png" vs "1.png" deterministically.
    return parts, name


def insert_benchmark_images(db, benchmark):
    """Store every file in ``<benchmark>/images`` in ``db.images``.

    Image files are paired by position with the labels in
    ``<benchmark>/y_test.npy``. Each stored document holds the pickled PIL
    image, its integer class, the string form of the benchmark document's
    ``_id``, and a UTC ``created_at`` timestamp.

    Args:
        db: pymongo database handle.
        benchmark: benchmark name. It must already exist in ``db.benchmarks``
            and is also the directory that holds ``images/`` and
            ``y_test.npy``.

    Raises:
        ValueError: if *benchmark* is not registered in ``db.benchmarks``, or
            if the number of image files differs from the number of labels.
            Both checks run before anything is inserted.
    """
    benchmark_obj = db.benchmarks.find_one({"benchmark": benchmark})
    if benchmark_obj is None:
        raise ValueError(f"benchmark {benchmark!r} not found in db.benchmarks")
    benchmark_id = str(benchmark_obj["_id"])

    folder = os.path.join(benchmark, "images")
    # os.listdir() order is arbitrary and filesystem-dependent, so the files
    # must be sorted before being paired with labels by position.
    # NOTE(review): natural sort assumes file names encode the sample index
    # (e.g. "0.png" ... "9999.png"); confirm against the export step.
    image_names = sorted(os.listdir(folder), key=_natural_sort_key)
    labels = np.load(os.path.join(benchmark, "y_test.npy"))
    if len(image_names) != len(labels):
        raise ValueError(
            f"{benchmark!r}: {len(image_names)} image files but "
            f"{len(labels)} labels in y_test.npy"
        )

    for name, label in zip(image_names, labels):
        # Pickle while the file is still open; the with-block then closes the
        # handle Image.open() holds, so thousands of images do not leak
        # file descriptors.
        with Image.open(os.path.join(folder, name)) as im:
            payload = pickle.dumps(im)
        db.images.insert_one({
            "image": payload,
            "class": int(label),
            "benchmark": benchmark_id,
            "created_at": datetime.now(timezone.utc),
        })
def insert_benchmarks(db, benchmarks):
    """Insert one ``{"benchmark": name}`` document per name into ``db.benchmarks``.

    Args:
        db: pymongo database handle.
        benchmarks: iterable of benchmark names, inserted in iteration order.
    """
    documents = [{"benchmark": benchmark} for benchmark in benchmarks]
    # One batched round-trip; insert_many raises on an empty list, so skip it
    # when there is nothing to insert.
    if documents:
        db.benchmarks.insert_many(documents)
if __name__ == '__main__':
    # Seed the local "models" database. The client is used as a context
    # manager so its connections are closed when the script finishes.
    with MongoClient('localhost', 27017) as client:
        db = client.models

        # Add benchmarks to the database
        benchmarks = ['CIFAR10', 'GTSRB', 'MNIST', 'Fashion_MNIST']
        insert_benchmarks(db, benchmarks)

        # Add validation images for each benchmark to the database
        for benchmark in benchmarks:
            insert_benchmark_images(db, benchmark)

        # Add class names to the database
        for benchmark in benchmarks:
            insert_classnames(db, benchmark)

        # Add attack types to the database. insert_attack_types does not use
        # its benchmark argument, so a single call is enough; calling it per
        # benchmark would store each attack type once per benchmark.
        insert_attack_types(db, benchmark=None)