-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
76 lines (65 loc) · 2.69 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torch
import torchvision
import torchvision.transforms as transforms
# Base directory for all datasets. CIFAR files are read from here; ImageNet is
# read from the ImageFolder trees ROOT/train and ROOT/val.
ROOT = './data'
def _dataset_transforms(dataset: str, img_size):
    """Return ``(train_transform, val_transform)`` for a supported dataset name.

    Raises:
        NotImplementedError: If ``dataset`` is not 'cifar10', 'cifar100' or
            'imagenet'.
    """
    if dataset in ('cifar10', 'cifar100'):
        # NOTE(review): verify these normalization statistics against the
        # training data; kept unchanged so existing checkpoints stay comparable.
        normalization = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalization,
        ])
        val_transform = transforms.Compose([
            transforms.ToTensor(),
            normalization,
        ])
    elif dataset == 'imagenet':
        normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalization,
        ])
        # NOTE(review): many ImageNet eval pipelines resize to a larger size
        # (e.g. img_size / 0.875) before center-cropping; kept as-is so reported
        # accuracies remain comparable with earlier runs.
        val_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            normalization,
        ])
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")
    return train_transform, val_transform


def _build_split(dataset: str, train_split: bool, transform):
    """Instantiate the training or validation split of ``dataset`` under ROOT."""
    if dataset == 'cifar10':
        return torchvision.datasets.CIFAR10(ROOT, train=train_split, transform=transform)
    if dataset == 'cifar100':
        return torchvision.datasets.CIFAR100(ROOT, train=train_split, transform=transform)
    if dataset == 'imagenet':
        split_dir = os.path.join(ROOT, 'train' if train_split else 'val')
        return torchvision.datasets.ImageFolder(split_dir, transform=transform)
    raise NotImplementedError(f"Unsupported dataset: {dataset!r}")


def get_dataloader(dataset: str, img_size, args, train=True):
    """Build data loaders for CIFAR-10, CIFAR-100 or an ImageNet folder under ROOT.

    Args:
        dataset: 'cifar10', 'cifar100' or 'imagenet'. ImageNet is read from
            ``ROOT/train`` and ``ROOT/val``.
        img_size: Crop size used by the ImageNet transforms; CIFAR always uses
            32-pixel random crops and ignores it.
        args: Object providing ``batch_size`` and ``workers``; an optional
            truthy ``ddp`` attribute wraps the training set in a
            ``DistributedSampler``.
        train: If True, also build the training loader.

    Returns:
        ``(train_loader, val_loader, train_sampler)`` when ``train`` is True,
        where ``train_sampler`` is None unless ``args.ddp`` is truthy;
        otherwise only ``val_loader``.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    train_transform, val_transform = _dataset_transforms(dataset, img_size)

    val_set = _build_split(dataset, False, val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_set, args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if not train:
        # Only the validation loader is returned, so skip constructing the
        # training dataset (for ImageNet that would index the whole train dir).
        return val_loader

    train_set = _build_split(dataset, True, train_transform)
    train_sampler = None
    if getattr(args, 'ddp', False):
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    # DataLoader's `shuffle` and `sampler` options are mutually exclusive; the
    # distributed sampler handles shuffling itself. The sampler is returned,
    # presumably so the caller can call `set_epoch` on it each epoch.
    train_loader = torch.utils.data.DataLoader(
        train_set, args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    return train_loader, val_loader, train_sampler
def get_num_classes(dataset: str):
    """Look up how many target classes a supported dataset has.

    Args:
        dataset: 'cifar10', 'cifar100' or 'imagenet'.

    Returns:
        The class count for ``dataset``.

    Raises:
        NotImplementedError: For any other dataset name.
    """
    known = (
        ('cifar10', 10),
        ('cifar100', 100),
        ('imagenet', 1000),
    )
    for name, n_classes in known:
        if dataset == name:
            return n_classes
    raise NotImplementedError