-
Notifications
You must be signed in to change notification settings - Fork 2
/
custom_dataset_data_loader.py
49 lines (43 loc) · 1.98 KB
/
custom_dataset_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch.utils.data
from data.base_data_loader import BaseDataLoader
import os
def CreateDataset(dataroots, dataset_mode='2afc', morpho=False, load_size=64):
    """Instantiate and initialize the dataset selected by ``dataset_mode``.

    Args:
        dataroots: Forwarded unchanged as the first argument of the
            dataset's ``initialize`` method.
        dataset_mode: ``'2afc'`` or ``'jnd'``.
        morpho: When truthy, ``'2afc'`` uses ``TwoAFCDatasetMOPRHO`` instead
            of ``TwoAFCDataset``. ``'jnd'`` uses ``JNDDataset`` either way.
        load_size: Forwarded to ``initialize`` as the ``load_size`` keyword.

    Returns:
        The dataset object after ``initialize`` has been called on it.

    Raises:
        ValueError: If ``dataset_mode`` is neither ``'2afc'`` nor ``'jnd'``.
    """
    # Reject unknown modes before importing anything. The message is built
    # from the ``dataset_mode`` parameter: this is a module-level function,
    # so there is no ``self`` to read the mode from.
    if dataset_mode not in ('2afc', 'jnd'):
        raise ValueError("Dataset Mode [%s] not recognized." % (dataset_mode,))

    # Dataset classes are imported lazily so only the selected module loads.
    if dataset_mode == 'jnd':
        # The same JND dataset class is used with or without ``morpho``.
        from data.dataset.jnd_dataset import JNDDataset
        dataset = JNDDataset()
    elif morpho:
        from data.dataset.twoafc_dataset import TwoAFCDatasetMOPRHO
        dataset = TwoAFCDatasetMOPRHO()
    else:
        from data.dataset.twoafc_dataset import TwoAFCDataset
        dataset = TwoAFCDataset()

    dataset.initialize(dataroots, load_size=load_size)
    return dataset
class CustomDatasetDataLoader(BaseDataLoader):
    """Loader that builds a dataset via ``CreateDataset`` and wraps it in a
    ``torch.utils.data.DataLoader``."""

    def name(self):
        """Return the string identifying this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, datafolders, dataroot='./dataset', dataset_mode='2afc', morpho=False, load_size=64, batch_size=1, serial_batches=True, nThreads=1):
        """Create the dataset and the torch DataLoader that iterates over it.

        Args:
            datafolders: A folder name, or a list of folder names, each of
                which is joined onto ``dataroot``.
            dataroot: Base directory the folder names are joined onto.
            dataset_mode: Forwarded to ``CreateDataset``.
            morpho: Forwarded to ``CreateDataset``.
            load_size: Forwarded to ``CreateDataset``.
            batch_size: Batch size given to the DataLoader.
            serial_batches: When true, shuffling is disabled.
            nThreads: Number of DataLoader workers (converted with ``int``).
        """
        BaseDataLoader.initialize(self)

        # Anything that is not already a list is treated as a single folder.
        folder_names = datafolders if isinstance(datafolders, list) else [datafolders]

        joined_roots = []
        for folder_name in folder_names:
            joined_roots.append(os.path.join(dataroot, folder_name))

        self.dataset = CreateDataset(
            joined_roots,
            dataset_mode=dataset_mode,
            morpho=morpho,
            load_size=load_size,
        )

        loader_options = {
            'batch_size': batch_size,
            'shuffle': not serial_batches,
            'num_workers': int(nThreads),
        }
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_options)

    def load_data(self):
        """Return the DataLoader built by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the length of the underlying dataset."""
        return len(self.dataset)