Skip to content

Commit 666e1ff

Browse files
lwwang1995 authored and you-n-g committed
Update settings.
1 parent 70fb760 commit 666e1ff

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

examples/benchmarks/SFM/workflow_config_sfm_Alpha158.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ task:
5757
num_layers: 2
5858
dropout: 0.0
5959
n_epochs: 200
60-
lr: 1e-1
60+
lr: 5e-2
6161
early_stop: 10
6262
batch_size: 800
6363
metric: loss

qlib/contrib/model/pytorch_gats_ts.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@
3232

3333

3434
class DailyBatchSampler(Sampler):
35+
3536
def __init__(self, data_source):
37+
3638
self.data_source = data_source
3739
self.data = self.data_source.data.loc[self.data_source.get_index()]
38-
self.daily_count = self.data.groupby(level=0).size().values
39-
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
40+
self.daily_count = self.data.groupby(level=0).size().values[1:]
41+
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:]
4042

4143
def __iter__(self):
4244
for idx, count in zip(self.daily_index, self.daily_count):
43-
yield slice(idx, idx + count)
45+
yield np.arange(idx, idx + count)
4446

4547
def __len__(self):
4648
return len(self.data_source)
@@ -202,6 +204,8 @@ def train_epoch(self, data_loader):
202204
self.GAT_model.train()
203205

204206
for data in data_loader:
207+
208+
data = data.squeeze()
205209
feature = data[:, :, 0:-1].to(self.device)
206210
label = data[:, -1, -1].to(self.device)
207211

@@ -222,6 +226,7 @@ def test_epoch(self, data_loader):
222226

223227
for data in data_loader:
224228

229+
data = data.squeeze()
225230
feature = data[:, :, 0:-1].to(self.device)
226231
# feature[torch.isnan(feature)] = 0
227232
label = data[:, -1, -1].to(self.device)
@@ -335,6 +340,7 @@ def predict(self, dataset):
335340

336341
for data in test_loader:
337342

343+
data = data.squeeze()
338344
feature = data[:, :, 0:-1].to(self.device)
339345

340346
with torch.no_grad():

0 commit comments

Comments (0)