From 1db716d23e1e6e64b654517eddf33e274c434101 Mon Sep 17 00:00:00 2001 From: YeewahChan Date: Mon, 3 Jun 2024 11:48:13 +0800 Subject: [PATCH] Fix TSDataSampler Slicing Bug #1716 --- qlib/data/dataset/__init__.py | 7 ++- tests/data_mid_layer_tests/test_dataset.py | 54 +++++++++++++++++++++- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index aacd58389a..2c3c12c0a7 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -627,7 +627,12 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. - data = self.data_arr[indices[0] : indices[-1] + 1] + if indices[0] < 0 : + data = self.data_arr[0 : indices[-1] + 1] + # Prepend nan values to the data to match the step_len + data = np.concatenate([np.full((self.step_len - len(data), *data.shape[1:]), np.nan), data]) + else: + data = self.data_arr[indices[0] : indices[-1] + 1] else: data = self.data_arr[indices] if isinstance(idx, mtit): diff --git a/tests/data_mid_layer_tests/test_dataset.py b/tests/data_mid_layer_tests/test_dataset.py index dc2ec812f1..40085e6a17 100755 --- a/tests/data_mid_layer_tests/test_dataset.py +++ b/tests/data_mid_layer_tests/test_dataset.py @@ -5,8 +5,9 @@ import pytest import sys from qlib.tests import TestAutoData -from qlib.data.dataset import TSDatasetH +from qlib.data.dataset import TSDatasetH, TSDataSampler import numpy as np +import pandas as pd import time from qlib.data.dataset.handler import DataHandlerLP @@ -97,7 +98,56 @@ def testTSDataset(self): print(data.shape) print(idx[i]) - +class TestTSDataSampler(unittest.TestCase): + def test_TSDataSampler(self): + """ + Test TSDataSampler for issue #1716 + """ + datetime_list = [ + '2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31' + ] + instruments = ['000001', '000002', '000003', '000004', '000005'] + index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments], + names=['datetime', 'instrument']) + data = np.random.randn(len(datetime_list) * len(instruments)) + test_df = pd.DataFrame(data=data, index=index, columns=['factor']) + dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2) + print() + print("--------------dataset[0]--------------") + print(dataset[0]) + print("--------------dataset[1]--------------") + print(dataset[1]) + assert len(dataset[0]) == 2 + self.assertTrue(np.isnan(dataset[0][0])) + self.assertEqual(dataset[0][1], dataset[1][0]) + self.assertEqual(dataset[1][1], dataset[2][0]) + self.assertEqual(dataset[2][1], dataset[3][0]) + + def test_TSDataSampler2(self): + """ + Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front + """ + datetime_list = [ + '2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31' + ] + instruments = ['000001', '000002', '000003', '000004', '000005'] + index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments], + names=['datetime', 'instrument']) + data = np.random.randn(len(datetime_list) * len(instruments)) + test_df = pd.DataFrame(data=data, index=index, columns=['factor']) + dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3) + print() + print("--------------dataset[0]--------------") + print(dataset[0]) + print("--------------dataset[1]--------------") + print(dataset[1]) + for i in range(3): + self.assertFalse(np.isnan(dataset[0][i])) + self.assertFalse(np.isnan(dataset[1][i])) + #断言dataset[0][1]等于dataset[1][0] + self.assertEqual(dataset[0][1], dataset[1][0]) + #断言dataset[0][2]等于dataset[1][1] + self.assertEqual(dataset[0][2], dataset[1][1]) if __name__ == "__main__": unittest.main(verbosity=10)