-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtianchiAL.py
97 lines (88 loc) · 5.11 KB
/
tianchiAL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from pathlib import Path
from PIL import Image
import os
from torchvision import transforms
# Default dataset root directory. Image paths listed in the label CSVs are joined
# onto this directory, and the CSVs themselves are expected under '<root>/labels/'.
# get_tianchiAL_datasets() overwrites this module-level value whenever it is called
# with a truthy `dataroot`, so later calls without `dataroot` reuse that new root.
tianchiAL_img_path = '/media/zcf/Elements/dataset/tianchi2018al/guangdong_round2_train_20181011_new' #---1.diff from mobile.py
# 1.mobile -->PVELAD
class tianchiAL(Dataset):  # ---2.diff from mobile.py Mobile-->PVELAD
    """Image dataset for the Tianchi 2018 "AL" images, built from label CSV file(s).

    Each row of a label CSV describes one image. Column 0 is the image path
    relative to ``img_root``. The label comes either from column 1 (two-column
    CSV) or from one-hot flag columns (more than two columns). In the one-hot
    case an implicit ``normal`` class is prepended as class 0.

    Indexing returns ``(image, target, unique_index)``.
    """

    def __init__(self, label_path, img_root, transform=None):
        """Load image names and integer/class labels from label CSV file(s).

        Args:
            label_path (str | Path | list | tuple): Path to a label CSV, or a
                sequence of such paths whose rows are concatenated in order.
                Each CSV is read with its first row as the header.
                - Exactly two columns: column 1 is used as the label as-is.
                - More than two columns: columns 1.. are one-hot flags. A row
                  with no flag set becomes class 0 ("normal"), and flag column
                  k (1-based among the flag columns) becomes class k.
            img_root (str | Path): Directory that the image paths are joined onto.
            transform (callable, optional): Applied to each PIL RGB image. When
                None, images are resized to (1920, 2560) (h, w), converted to a
                tensor, and normalized with mean 0.5 / std 0.5 per channel.

        Raises:
            TypeError: if ``label_path`` is not a str/Path or a list/tuple of them.
            ValueError: if a CSV has fewer than two columns, or a one-hot row
                does not encode exactly one class.
        """
        frame = self._read_label_files(label_path)
        if frame.shape[1] < 2:
            raise ValueError(
                'label CSV must have at least 2 columns (image, label), '
                f'got {frame.shape[1]}')
        # In principle the CSV has two columns, [image name, label]; one-hot
        # labels are converted back to class indices.
        if frame.shape[1] > 2:
            names, labels = self._onehot_to_index(frame)
        else:
            # Positional access: read_csv names columns from the header row,
            # so integer label lookups such as frame[1] would raise KeyError.
            names, labels = frame.iloc[:, 0], frame.iloc[:, 1]

        self.img_root = img_root
        self.transform = transform
        self.target_transform = None
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((1920, 2560)),  # (h, w) -- diff from mobile.py
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])
        # Sorted distinct label values present in this split.
        self.unique_classes = np.sort(labels.unique())
        # Parallel lists: relative image paths and their labels.
        self.data = names.tolist()
        self.targets = labels.tolist()
        # Per-sample index returned alongside each item.
        self.uq_idxs = np.arange(len(self.data))

    @staticmethod
    def _read_label_files(label_path):
        """Read one label CSV, or concatenate several CSVs row-wise in order."""
        if isinstance(label_path, (str, Path)):
            return pd.read_csv(label_path)
        if isinstance(label_path, (list, tuple)):
            return pd.concat([pd.read_csv(p) for p in label_path], ignore_index=True)
        raise TypeError(
            'label_path must be a str, Path, or a list/tuple of them, '
            f'got {type(label_path).__name__}')

    @staticmethod
    def _onehot_to_index(frame):
        """Convert one-hot flag columns to class indices, with 0 meaning 'normal'.

        Returns:
            (names, labels): column 0 of ``frame`` and a Series of class indices.
        """
        onehot = frame.iloc[:, 1:].to_numpy(dtype=float)
        # A row with no flag set is the implicit 'normal' class, prepended as column 0.
        normal = 1 - onehot.sum(axis=1, keepdims=True)
        onehot = np.concatenate((normal, onehot), axis=1)
        print('标签顺序为:', ['normal'] + frame.columns.tolist()[1:])
        # After prepending 'normal', a valid row is all 0/1 with exactly one 1.
        # Multi-hot or non-binary rows would otherwise yield a label count that
        # differs from the image count and misalign images with labels.
        bad_rows = np.flatnonzero(~np.isin(onehot, (0, 1)).all(axis=1))
        if bad_rows.size:
            raise ValueError(
                'one-hot labels must contain only 0/1 with at most one flag set '
                f'per row; offending row indices: {bad_rows[:10].tolist()}')
        labels = pd.Series(onehot.argmax(axis=1))
        return frame.iloc[:, 0], labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target, unique_index)`` for sample ``index``."""
        img_path = os.path.join(self.img_root, self.data[index])
        target = self.targets[index]
        # Close the file once the RGB copy exists; Image.open opens lazily and
        # would otherwise keep the handle open.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, self.uq_idxs[index]
def get_tianchiAL_datasets(train_transform, test_transform, train_classes=range(4),
                           open_set_classes=range(4, 10), balance_open_set_eval=False,
                           split_train_val=True, seed=0, dataroot=None):
    """Build the train / val / known-test / unknown-test tianchiAL datasets.

    Args:
        train_transform: Transform for the 'train' split (None = class default).
        test_transform: Transform for 'val', 'test_known' and 'test_unknown'.
        train_classes, open_set_classes, balance_open_set_eval, split_train_val:
            Not used by this function (presumably kept so the signature matches
            sibling dataset getters -- TODO confirm).
        seed: Passed to ``np.random.seed`` (global NumPy RNG side effect).
        dataroot: If truthy, replaces the module-level ``tianchiAL_img_path``
            for this call and all later calls.

    Returns:
        dict mapping 'train', 'val', 'test_known', 'test_unknown' to datasets,
        each read from '<root>/labels/<csv>' with images resolved under <root>.
    """
    global tianchiAL_img_path
    np.random.seed(seed)
    if dataroot:
        tianchiAL_img_path = dataroot
    root = tianchiAL_img_path

    # (split key, label CSV relative to root, transform); the CSVs are the
    # single-label variants -- diff 3 from mobile.py.
    split_specs = (
        ('train', 'labels/singleLabel_base_train.csv', train_transform),
        ('val', 'labels/singleLabel_base_val.csv', test_transform),
        ('test_known', 'labels/singleLabel_base_test.csv', test_transform),
        ('test_unknown', 'labels/singleLabel_novel.csv', test_transform),
    )
    # Datasets are constructed in the order listed above.
    all_datasets = {
        split_key: tianchiAL(label_path=os.path.join(root, csv_rel),
                             img_root=root,
                             transform=split_tf)
        for split_key, csv_rel, split_tf in split_specs
    }
    print('tianchiAL')
    return all_datasets
if __name__ == '__main__':
    # Smoke test: build every split from the default root and print each size.
    splits = get_tianchiAL_datasets(None, None)
    print([len(split_ds) for split_ds in splits.values()])
    debug = 0  # no-op; presumably a line to hang a debugger breakpoint on