Skip to content

Commit

Permalink
Fix TSDataSampler Slicing Bug with simplyer implmentation#1716
Browse files Browse the repository at this point in the history
 with Simplified Implementation
  • Loading branch information
YeewahChan committed Jun 16, 2024
1 parent 31af793 commit f9810d6
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716

# the data type will be changed
# The index of usable data is between start_idx and end_idx
Expand Down Expand Up @@ -627,12 +627,7 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)

if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up.
if indices[0] == -1 :
data = self.data_arr[0 : indices[-1] + 1]
# Prepend nan values to the data to match the step_len
data = np.concatenate([np.full((self.step_len - len(data), *data.shape[1:]), np.nan), data])
else:
data = self.data_arr[indices[0] : indices[-1] + 1]
data = self.data_arr[indices[0] : indices[-1] + 1]
else:
data = self.data_arr[indices]
if isinstance(idx, mtit):
Expand Down

0 comments on commit f9810d6

Please sign in to comment.