32
32
33
33
34
34
class DailyBatchSampler (Sampler ):
35
+
35
36
def __init__ (self , data_source ):
37
+
36
38
self .data_source = data_source
37
39
self .data = self .data_source .data .loc [self .data_source .get_index ()]
38
- self .daily_count = self .data .groupby (level = 0 ).size ().values
39
- self .daily_index = np .roll (np .cumsum (self .daily_count ), 1 )
40
+ self .daily_count = self .data .groupby (level = 0 ).size ().values [ 1 :]
41
+ self .daily_index = np .roll (np .cumsum (self .daily_count ), 1 )[ 1 :]
40
42
41
43
def __iter__ (self ):
42
44
for idx , count in zip (self .daily_index , self .daily_count ):
43
- yield slice (idx , idx + count )
45
+ yield np . arange (idx , idx + count )
44
46
45
47
def __len__ (self ):
46
48
return len (self .data_source )
@@ -202,6 +204,8 @@ def train_epoch(self, data_loader):
202
204
self .GAT_model .train ()
203
205
204
206
for data in data_loader :
207
+
208
+ data = data .squeeze ()
205
209
feature = data [:, :, 0 :- 1 ].to (self .device )
206
210
label = data [:, - 1 , - 1 ].to (self .device )
207
211
@@ -222,6 +226,7 @@ def test_epoch(self, data_loader):
222
226
223
227
for data in data_loader :
224
228
229
+ data = data .squeeze ()
225
230
feature = data [:, :, 0 :- 1 ].to (self .device )
226
231
# feature[torch.isnan(feature)] = 0
227
232
label = data [:, - 1 , - 1 ].to (self .device )
@@ -335,6 +340,7 @@ def predict(self, dataset):
335
340
336
341
for data in test_loader :
337
342
343
+ data = data .squeeze ()
338
344
feature = data [:, :, 0 :- 1 ].to (self .device )
339
345
340
346
with torch .no_grad ():
0 commit comments