Skip to content

Commit

Permalink
Fix TSDataSampler Slicing Bug #1716
Browse files Browse the repository at this point in the history
  • Loading branch information
YeewahChan committed Jun 3, 2024
1 parent 907c888 commit 1db716d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
7 changes: 6 additions & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,12 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)

if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up.
data = self.data_arr[indices[0] : indices[-1] + 1]
if indices[0] < 0 :
data = self.data_arr[0 : indices[-1] + 1]
# Prepend nan values to the data to match the step_len
data = np.concatenate([np.full((self.step_len - len(data), *data.shape[1:]), np.nan), data])
else:
data = self.data_arr[indices[0] : indices[-1] + 1]
else:
data = self.data_arr[indices]
if isinstance(idx, mtit):
Expand Down
54 changes: 52 additions & 2 deletions tests/data_mid_layer_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pytest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
from qlib.data.dataset import TSDatasetH, TSDataSampler
import numpy as np
import pandas as pd
import time
from qlib.data.dataset.handler import DataHandlerLP

Expand Down Expand Up @@ -97,7 +98,56 @@ def testTSDataset(self):
print(data.shape)
print(idx[i])


class TestTSDataSampler(unittest.TestCase):
def test_TSDataSampler(self):
"""
Test TSDataSampler for issue #1716
"""
datetime_list = [
'2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31'
]
instruments = ['000001', '000002', '000003', '000004', '000005']
index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments],
names=['datetime', 'instrument'])
data = np.random.randn(len(datetime_list) * len(instruments))
test_df = pd.DataFrame(data=data, index=index, columns=['factor'])
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
print()
print("--------------dataset[0]--------------")
print(dataset[0])
print("--------------dataset[1]--------------")
print(dataset[1])
assert len(dataset[0]) == 2
self.assertTrue(np.isnan(dataset[0][0]))
self.assertEqual(dataset[0][1], dataset[1][0])
self.assertEqual(dataset[1][1], dataset[2][0])
self.assertEqual(dataset[2][1], dataset[3][0])

def test_TSDataSampler2(self):
"""
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
"""
datetime_list = [
'2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31'
]
instruments = ['000001', '000002', '000003', '000004', '000005']
index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments],
names=['datetime', 'instrument'])
data = np.random.randn(len(datetime_list) * len(instruments))
test_df = pd.DataFrame(data=data, index=index, columns=['factor'])
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
print()
print("--------------dataset[0]--------------")
print(dataset[0])
print("--------------dataset[1]--------------")
print(dataset[1])
for i in range(3):
self.assertFalse(np.isnan(dataset[0][i]))
self.assertFalse(np.isnan(dataset[1][i]))
#断言dataset[0][1]等于dataset[1][0]
self.assertEqual(dataset[0][1], dataset[1][0])
#断言dataset[0][2]等于dataset[1][1]
self.assertEqual(dataset[0][2], dataset[1][1])
if __name__ == "__main__":
unittest.main(verbosity=10)

Expand Down

0 comments on commit 1db716d

Please sign in to comment.