-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
104 lines (84 loc) · 3.23 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
MxNet compatible dataloader
"""
from mxnet.gluon.data import DataLoader, Sampler
import math
import numpy as np
from mxnet import nd
from sklearn.model_selection import StratifiedKFold
import dgl
class SubsetRandomSampler(Sampler):
    """Yield the elements of ``indices`` in a freshly shuffled order.

    A new permutation is drawn from NumPy's global random state every time
    the sampler is iterated, so each pass visits the subset in a different
    order (reproducible only if the global state has been seeded).

    Args:
        indices: sequence of dataset indices this sampler is restricted to.
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        subset = self.indices
        order = np.random.permutation(len(subset))
        shuffled = [subset[pos] for pos in order]
        return iter(shuffled)
def collate(samples):
    """Default batchify function: merge ``(graph, label)`` pairs into one batch.

    Every node feature of every graph is converted with ``nd.array`` before
    the graphs are merged with ``dgl.batch``; edge features are not touched.

    Args:
        samples: list of ``(graph, label)`` pairs.

    Returns:
        ``(batched_graph, labels)``: the ``dgl.batch`` result, and the labels
        each reshaped to ``(1,)`` and concatenated along axis 0.
    """
    graphs, labels = map(list, zip(*samples))

    for graph in graphs:
        schemes = graph.node_attr_schemes()
        for feat_name in schemes.keys():
            graph.ndata[feat_name] = nd.array(graph.ndata[feat_name])
    # Edge features are intentionally left as-is.

    batched = dgl.batch(graphs)

    flat_labels = [nd.reshape(lbl, (1,)) for lbl in labels]
    stacked = nd.concat(*flat_labels, dim=0)
    return batched, stacked
class GraphDataLoader:
    """Split a graph-classification dataset into train/validation loaders.

    Supported split strategies (``split_name``):

    * ``'fold10'`` -- stratified 10-fold split; fold ``fold_idx`` becomes the
      validation set.
    * ``'rand'`` -- the first ``floor(split_ratio * N)`` (optionally shuffled)
      indices are used for training, the rest for validation.

    Each subset is wrapped in an MXNet ``DataLoader`` whose
    ``SubsetRandomSampler`` visits it in random order.

    Args:
        dataset: dataset whose items are ``(graph, label)`` pairs.
        batch_size: number of samples per batch.
        collate_fn: batchify function handed to ``DataLoader``.
        seed: random seed used when computing the split.
        shuffle: whether the split shuffles samples before partitioning.
        split_name: ``'fold10'`` or ``'rand'``.
        fold_idx: validation fold for ``'fold10'`` (0-9).
        split_ratio: training fraction for ``'rand'``.

    Raises:
        NotImplementedError: if ``split_name`` is not recognised.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 collate_fn=collate,
                 seed=0,
                 shuffle=True,
                 split_name='fold10',
                 fold_idx=0,
                 split_ratio=0.7):
        self.shuffle = shuffle
        self.seed = seed

        labels = [label for _, label in dataset]
        if split_name == 'fold10':
            train_idx, valid_idx = self._split_fold10(
                labels, fold_idx, seed, shuffle)
        elif split_name == 'rand':
            train_idx, valid_idx = self._split_rand(
                labels, split_ratio, seed, shuffle)
        else:
            raise NotImplementedError(
                "unknown split_name %r (expected 'fold10' or 'rand')"
                % (split_name,))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        self.train_loader = DataLoader(
            dataset, sampler=train_sampler,
            batch_size=batch_size, batchify_fn=collate_fn)
        self.valid_loader = DataLoader(
            dataset, sampler=valid_sampler,
            batch_size=batch_size, batchify_fn=collate_fn)

    def train_valid_loader(self):
        """Return the ``(train_loader, valid_loader)`` pair."""
        return self.train_loader, self.valid_loader

    def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
        """Return ``(train_idx, valid_idx)`` for one fold of a stratified 10-fold split.

        Args:
            labels: per-sample labels; each must provide ``asnumpy()``
                (presumably MXNet NDArrays -- confirm against the dataset).
            fold_idx: which of the 10 folds (0-9) becomes the validation set.
            seed: ``random_state`` for ``StratifiedKFold``; only used when
                ``shuffle`` is true.
            shuffle: whether ``StratifiedKFold`` shuffles before splitting.

        Returns:
            The ``(train, test)`` index arrays produced by
            ``StratifiedKFold.split`` for fold ``fold_idx``.

        Raises:
            AssertionError: if ``fold_idx`` is outside ``[0, 9]``.
        """
        # Explicit check rather than `assert`, so it survives `python -O`;
        # AssertionError is kept for callers that relied on the old type.
        if not 0 <= fold_idx < 10:
            raise AssertionError("fold_idx must be from 0 to 9.")

        # scikit-learn >= 0.24 raises ValueError if random_state is set while
        # shuffle is False, so only pass the seed when it has an effect.
        skf = StratifiedKFold(
            n_splits=10, shuffle=shuffle,
            random_state=seed if shuffle else None)
        # Stratification depends only on y; X is a placeholder of matching length.
        y = [label.asnumpy() for label in labels]
        folds = list(skf.split(np.zeros(len(labels)), y))
        train_idx, valid_idx = folds[fold_idx]

        print("train_set : test_set = %d : %d"
              % (len(train_idx), len(valid_idx)))
        return train_idx, valid_idx

    def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
        """Return ``(train_idx, valid_idx)`` for a random train/validation split.

        The first ``floor(split_ratio * len(labels))`` indices -- shuffled
        first when ``shuffle`` is true -- form the training set; the remaining
        indices form the validation set.

        Note: this seeds NumPy's *global* random state with ``seed``, which
        also fixes the permutations later drawn by ``SubsetRandomSampler``.

        Args:
            labels: per-sample labels (only their count is used).
            split_ratio: fraction of samples assigned to training.
            seed: seed for NumPy's global random state.
            shuffle: whether to shuffle indices before splitting.

        Returns:
            Two lists of integer indices.
        """
        num_entries = len(labels)
        indices = list(range(num_entries))
        # Seed globally (not a local generator) so downstream sampler order
        # stays exactly as reproducible as before.
        np.random.seed(seed)
        if shuffle:
            np.random.shuffle(indices)
        split = math.floor(split_ratio * num_entries)
        train_idx, valid_idx = indices[:split], indices[split:]

        print("train_set : test_set = %d : %d"
              % (len(train_idx), len(valid_idx)))
        return train_idx, valid_idx