-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
29 lines (26 loc) · 998 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# import torchmeta
# from torchmeta.utils.data import BatchMetaDataLoader
# TODO: incorporate torchmeta or meta-dataset
def dataloader(hparams):
    """Build the few-shot episode loader selected by ``hparams.dataset``.

    Supported values of ``hparams.dataset``:

    * ``"omniglot"``  -> ``support.omniglot_loaders.OmniglotNShot``, given the
      path ``'/tmp/omniglot-data'``.
    * ``"quickdraw"`` -> ``support.quickdraw_loaders.QuickdrawNShot``, given the
      path ``'./support/data/QuickDrawData.pkl'``.

    Both loaders are constructed with the same keyword arguments taken from
    ``hparams`` plus ``imgsz=28``.

    Args:
        hparams: Object exposing the attributes ``dataset``,
            ``meta_batch_size``, ``n_way``, ``k_support``, ``k_query`` and
            ``device``.

    Returns:
        The constructed loader instance.

    Raises:
        ValueError: If ``hparams.dataset`` is not a supported dataset name.
    """
    dataset = hparams.dataset
    if dataset == "omniglot":
        # Imported inside the branch so only the selected loader module is
        # imported.
        from support.omniglot_loaders import OmniglotNShot as loader_cls
        data_path = '/tmp/omniglot-data'
    elif dataset == "quickdraw":
        from support.quickdraw_loaders import QuickdrawNShot as loader_cls
        data_path = './support/data/QuickDrawData.pkl'
    else:
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as a confusing AttributeError/TypeError in the caller.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            "expected 'omniglot' or 'quickdraw'."
        )

    # Both loaders share the same constructor keywords.
    return loader_cls(
        data_path,
        batchsz=hparams.meta_batch_size,
        n_way=hparams.n_way,
        k_shot=hparams.k_support,
        k_query=hparams.k_query,
        imgsz=28,
        device=hparams.device,
    )