diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 008d789402..3fb7cb9e19 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -160,6 +160,10 @@ def metric_fn(self, pred, label): if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) + elif self.metric == "mse": + mask = ~torch.isnan(label) + weight = torch.ones_like(label) + return -self.mse(pred[mask], label[mask], weight[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/data/data.py b/qlib/data/data.py index 1b1353ee4e..aba75c0b1a 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -616,7 +616,7 @@ def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, data = pd.DataFrame(obj) if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")): - # If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format + # If the underlaying provides the data not in datetime format, we'll convert it into datetime format _calendar = Cal.calendar(freq=freq) data.index = _calendar[data.index.values.astype(int)] data.index.names = ["datetime"] diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 19837b3c79..4b4a7c06b8 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -242,7 +242,7 @@ def _add_step(self, index, step): def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: """ - Shift the datatime of segment + Shift the datetime of segment If there are None (which indicates unbounded index) in the segment, this method will return None.