diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 6d985d2b20..0c797a09fe 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -403,7 +403,7 @@ def __init__( np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0, ) - self.nan_idx = -1 # The last line is all NaN + self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716 # the data type will be changed # The index of usable data is between start_idx and end_idx @@ -627,12 +627,7 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. - if indices[0] == -1 : - data = self.data_arr[0 : indices[-1] + 1] - # Prepend nan values to the data to match the step_len - data = np.concatenate([np.full((self.step_len - len(data), *data.shape[1:]), np.nan), data]) - else: - data = self.data_arr[indices[0] : indices[-1] + 1] + data = self.data_arr[indices[0] : indices[-1] + 1] else: data = self.data_arr[indices] if isinstance(idx, mtit):