dataloaders.py
from typing import Tuple

import torch
import torchvision
import torchvision.transforms as transforms

# Project-local wrapper; assumed here to behave like torch.utils.data.DataLoader.
from base import BaseDataLoader
# Training transform: flip/crop augmentation plus normalization with the commonly
# used CIFAR channel statistics; the test transform only normalizes.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])


def get_dataloader(
    data_name: str = 'cifar100',
    data_augmentation: bool = True,
    batch_size: int = 128,
    num_workers: int = 2
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Get train and test dataloaders for the requested dataset.

    Args:
        data_name (str, optional): Name of the dataset. Defaults to 'cifar100'.
        data_augmentation (bool, optional): Whether to apply data augmentation to the training set. Defaults to True.
        batch_size (int, optional): Number of samples in a batch. Defaults to 128.
        num_workers (int, optional): Number of workers to use for data loading. Defaults to 2.

    Returns:
        Tuple[DataLoader, DataLoader]: The train and test dataloaders.
    """
    train_dataset, test_dataset = None, None
    if data_name == 'cifar100':
        # Use the augmented transform for training only when requested.
        train_transform = transform_train if data_augmentation else transform_test
        train_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=True, download=True, transform=train_transform
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root='./data', train=False, download=True, transform=transform_test
        )
    # elif data_name == 'imagenet':
    #     train_dataset
    else:
        raise ValueError(f'Unsupported dataset: {data_name}')
    train_loader = BaseDataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = BaseDataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
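

# The commented-out val_loader print in the __main__ block below suggests a
# validation split may be wanted, but the file does not define one. The helper
# below is only a hedged sketch: `get_dataloader_with_val`, `val_ratio`, and
# `seed` are hypothetical names, and it relies on torch.utils.data.random_split
# to carve a validation set out of the CIFAR-100 training set.
def get_dataloader_with_val(
    batch_size: int = 128,
    num_workers: int = 2,
    val_ratio: float = 0.1,
    seed: int = 42,
):
    """Sketch: train/val/test loaders via a random split of the CIFAR-100 training set."""
    full_train = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test
    )
    # Reproducible split of the 50k training images. Note that the validation
    # subset inherits the training transform; using transform_test for validation
    # would require a second dataset instance and an index-based split.
    val_size = int(len(full_train) * val_ratio)
    train_subset, val_subset = torch.utils.data.random_split(
        full_train, [len(full_train) - val_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )
    train_loader = BaseDataLoader(
        dataset=train_subset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
    )
    val_loader = BaseDataLoader(
        dataset=val_subset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
    )
    test_loader = BaseDataLoader(
        dataset=test_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
    )
    return train_loader, val_loader, test_loader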


if __name__ == '__main__':
    train_loader, test_loader = get_dataloader()
    print(f'Train loader: {len(train_loader)} batches')
    print(f'Test loader: {len(test_loader)} batches')
    # print(f'Validation loader: {len(val_loader)} batches')
    print(f'Number of classes: {len(train_loader.dataset.classes)}')