Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/YeewahChan/qlib
Browse files Browse the repository at this point in the history
  • Loading branch information
YeewahChan committed Jun 20, 2024
2 parents 60f4ea8 + 1bc41a8 commit 4c7a1a1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions qlib/contrib/model/pytorch_alstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def metric_fn(self, pred, label):

if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
elif self.metric == "mse":
mask = ~torch.isnan(label)
weight = torch.ones_like(label)
return -self.mse(pred[mask], label[mask], weight[mask])

raise ValueError("unknown metric `%s`" % self.metric)

Expand Down
2 changes: 1 addition & 1 deletion qlib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None,

data = pd.DataFrame(obj)
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
# If the underlaying provides the data not in datetime format, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
Expand Down
2 changes: 1 addition & 1 deletion qlib/workflow/task/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _add_step(self, index, step):

def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datatime of segment
Shift the datetime of segment
If there are None (which indicates unbounded index) in the segment, this method will return None.
Expand Down

0 comments on commit 4c7a1a1

Please sign in to comment.