-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcifar_fs.py
132 lines (111 loc) · 5.31 KB
/
cifar_fs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import print_function
from torchtools import *
import torch.utils.data as data
import random
import os
import numpy as np
from PIL import Image as pil_image
import pickle
from itertools import islice
from torchvision import transforms
from tqdm import tqdm
import cv2
class CifarFsLoader(data.Dataset):
    """CIFAR-FS few-shot dataset loader.

    Reads ``CIFAR_FS_<partition>.pickle`` from ``root``, groups the images
    by class label, and serves episodic N-way K-shot task batches via
    :meth:`get_task_batch`.
    """

    def __init__(self, root, partition='train'):
        """
        Args:
            root: directory containing the ``CIFAR_FS_*.pickle`` files.
            partition: dataset split — 'train', 'val', or 'test'.
        """
        super(CifarFsLoader, self).__init__()
        # set dataset information
        self.root = root
        self.partition = partition
        self.data_size = [3, 32, 32]  # C, H, W

        # Channel-wise normalization constants (0-255 scale, rescaled to [0, 1]).
        mean_pix = [x / 255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
        std_pix = [x / 255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)

        # Train split gets random-crop augmentation; val/test are deterministic.
        if self.partition == 'train':
            self.transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                                 lambda x: np.asarray(x),
                                                 transforms.ToTensor(),
                                                 normalize])
        else:  # 'val' or 'test'
            self.transform = transforms.Compose([lambda x: np.asarray(x),
                                                 transforms.ToTensor(),
                                                 normalize])

        # load data
        self.data = self.load_dataset()

    def load_dataset(self):
        """Load the partition pickle and return ``{label: [PIL.Image, ...]}``.

        Falls back to a latin1-decoding unpickler for pickles written by
        Python 2 (whose byte strings fail to decode under the default).
        """
        dataset_path = os.path.join(self.root, 'CIFAR_FS_%s.pickle' % self.partition)
        try:
            with open(dataset_path, 'rb') as fo:
                data = pickle.load(fo)
        except UnicodeDecodeError:
            # Python-2 pickle: re-read with byte strings decoded as latin1.
            with open(dataset_path, 'rb') as f:
                u = pickle._Unpickler(f)
                u.encoding = 'latin1'
                data = u.load()

        data_c = {}
        for image_array, label in zip(data['data'], data['labels']):
            image = pil_image.fromarray(np.uint8(image_array))
            # Ensure every image is exactly (W, H) = (32, 32).
            image = image.resize((self.data_size[2], self.data_size[1]))
            data_c.setdefault(label, []).append(image)
        return data_c

    def get_task_batch(self,
                       num_tasks=5,
                       num_ways=20,
                       num_shots=1,
                       num_queries=1,
                       seed=None):
        """Sample a batch of few-shot episodes.

        Args:
            num_tasks: number of independent episodes in the batch.
            num_ways: classes sampled per episode.
            num_shots: support examples per class.
            num_queries: query examples per class.
            seed: optional RNG seed for reproducible sampling.

        Returns:
            ``[support_data, support_label, query_data, query_label]``:
            data tensors of shape (num_tasks, num_ways * num_shots_or_queries,
            C, H, W) and float label tensors of shape (num_tasks, slots) whose
            values are per-episode class indices in [0, num_ways).
        """
        if seed is not None:
            random.seed(seed)

        def _alloc(num_slots):
            # One (num_tasks, C, H, W) data buffer and (num_tasks,) label
            # buffer per example slot.
            datas = [np.zeros(shape=[num_tasks] + self.data_size, dtype='float32')
                     for _ in range(num_slots)]
            labels = [np.zeros(shape=[num_tasks], dtype='float32')
                      for _ in range(num_slots)]
            return datas, labels

        support_data, support_label = _alloc(num_ways * num_shots)
        query_data, query_label = _alloc(num_ways * num_queries)

        # get full class list in dataset
        full_class_list = list(self.data.keys())
        for t_idx in range(num_tasks):
            # define task by sampling classes (num_ways)
            task_class_list = random.sample(full_class_list, num_ways)
            for c_idx in range(num_ways):
                # Draw disjoint support + query examples for this class.
                samples = random.sample(self.data[task_class_list[c_idx]],
                                        num_shots + num_queries)
                for i_idx in range(num_shots):
                    support_data[i_idx + c_idx * num_shots][t_idx] = self.transform(samples[i_idx])
                    support_label[i_idx + c_idx * num_shots][t_idx] = c_idx
                for i_idx in range(num_queries):
                    query_data[i_idx + c_idx * num_queries][t_idx] = self.transform(samples[num_shots + i_idx])
                    query_label[i_idx + c_idx * num_queries][t_idx] = c_idx

        # Stack per-slot buffers into (num_tasks, num_slots, ...) tensors on
        # the configured device (tt.arg.device comes from torchtools).
        support_data = torch.stack([torch.from_numpy(d).float().to(tt.arg.device) for d in support_data], 1)
        support_label = torch.stack([torch.from_numpy(l).float().to(tt.arg.device) for l in support_label], 1)
        query_data = torch.stack([torch.from_numpy(d).float().to(tt.arg.device) for d in query_data], 1)
        query_label = torch.stack([torch.from_numpy(l).float().to(tt.arg.device) for l in query_label], 1)

        return [support_data, support_label, query_data, query_label]
if __name__ == "__main__":
    # Smoke test: load the default ('train') partition from a local copy.
    # Raw string prevents '\d' and '\C' being parsed as (invalid) escapes.
    c = CifarFsLoader(r'E:\data\CIFAR-FS')