-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathaircraft.py
118 lines (98 loc) · 4.67 KB
/
aircraft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive
class Aircraft(VisionDataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from the ``trainval``
            split, otherwise creates from the ``test`` split.
        class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Attributes set by ``__init__``:
        class_type: the requested label granularity.
        split: ``'trainval'`` or ``'test'``, derived from ``train``.
        classes_file: path of the annotation file for this class type and split.
        samples: list of ``(image_path, class_index)`` tuples.
        classes: sorted unique class names (as returned by ``np.unique``).
        class_to_idx: mapping from class name to its index in ``classes``.
        loader: callable used to load an image from its path.
    """

    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')
    img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')

    def __init__(self, root, train=True, class_type='variant', transform=None,
                 target_transform=None, download=False):
        super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
        split = 'trainval' if train else 'test'
        # `split` is one of two literals, so this can only fail if `splits`
        # has been overridden to exclude them.
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))

        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        image_ids, targets, classes, class_to_idx = self.find_classes()
        samples = self.make_dataset(image_ids, targets)

        self.loader = default_loader
        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (after ``transform``, if set) and ``target`` is the class index
            (after ``target_transform``, if set).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def _check_exists(self):
        """Return True if both the image folder and the annotation file exist."""
        return (os.path.exists(os.path.join(self.root, self.img_folder))
                and os.path.exists(self.classes_file))

    def download(self):
        """Download and extract the dataset archive unless it is already present."""
        if self._check_exists():
            return

        # prepare to download data to ROOT/fgvc-aircraft-2013b.tar.gz
        print('Downloading %s...' % self.url)
        tar_name = self.url.rpartition('/')[-1]
        download_url(self.url, root=self.root, filename=tar_name)
        tar_path = os.path.join(self.root, tar_name)
        print('Extracting %s...' % tar_path)
        extract_archive(tar_path)
        print('Done!')

    def find_classes(self):
        """Parse the annotation file into image IDs and integer class targets.

        Each non-blank line is ``<image_id> <class name>``; the class name may
        itself contain spaces, so only the first space is used as separator.

        Returns:
            tuple: ``(image_ids, targets, classes, class_to_idx)`` where
            ``image_ids`` is a list of ID strings, ``targets`` the matching
            list of class indices, ``classes`` the sorted unique class names
            and ``class_to_idx`` the name-to-index mapping.
        """
        # read classes file, separating out image IDs and class names
        image_ids = []
        targets = []
        with open(self.classes_file, 'r') as f:
            for line in f:
                # Strip the line terminator so class names do not end in '\n'
                # (and a final line without a newline is not a separate class).
                line = line.strip()
                if not line:
                    continue  # ignore blank lines instead of emitting an empty sample
                image_id, _, class_name = line.partition(' ')
                image_ids.append(image_id)
                targets.append(class_name)

        # index class names
        classes = np.unique(targets)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        targets = [class_to_idx[c] for c in targets]

        return image_ids, targets, classes, class_to_idx

    def make_dataset(self, image_ids, targets):
        """Pair every image ID with its target.

        Args:
            image_ids: sequence of image ID strings (file names without ``.jpg``).
            targets: sequence of class indices, same length as ``image_ids``.

        Returns:
            list: ``(image_path, target)`` tuples, with paths under
            ``root/img_folder``.
        """
        assert (len(image_ids) == len(targets))
        image_dir = os.path.join(self.root, self.img_folder)
        return [
            (os.path.join(image_dir, '%s.jpg' % image_id), target)
            for image_id, target in zip(image_ids, targets)
        ]
if __name__ == '__main__':
    # Manual smoke test: build both splits from an already-downloaded copy.
    dataset_root = './aircraft'
    train_dataset = Aircraft(dataset_root, train=True, download=False)
    test_dataset = Aircraft(dataset_root, train=False, download=False)