# Review notes (commentary only; patch tools skip text before the first "diff --git" header).
#
# qlib/data/dataset/__init__.py
#   Inside __init__, `self.nan_idx` is changed from -1 to `len(self.data_arr) - 1`.
#   Both values are meant to designate the last row, which the inline comment
#   describes as all-NaN. Per that comment, using -1 can trigger issue #1716;
#   the patch itself does not show the failing code path.
#
# tests/data_mid_layer_tests/test_dataset.py
#   Imports TSDataSampler and pandas, and adds TestTSDataSampler with two cases
#   built on a 5-date x 5-instrument random single-column frame:
#     * test_TSDataSampler  - step_len=2 over the full date range; asserts that
#       dataset[0] has length 2, that dataset[0][0] is NaN, and that
#       dataset[k][1] == dataset[k + 1][0] for k = 0, 1, 2.
#     * test_TSDataSampler2 - step_len=3 starting at the third date; asserts that
#       dataset[0] and dataset[1] contain no NaN in their 3 entries, and that
#       dataset[0][1] == dataset[1][0] and dataset[0][2] == dataset[1][1].
#
# NOTE(review): this copy of the patch had its line breaks and indentation
# collapsed. The layout below is reconstructed: indentation follows the Python
# nesting, and whitespace-only context lines were placed so that every hunk
# matches the line counts in its "@@" header. Verify against the upstream
# commit before applying.
diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
index aacd58389a..0b6c552a37 100644
--- a/qlib/data/dataset/__init__.py
+++ b/qlib/data/dataset/__init__.py
@@ -403,7 +403,7 @@ def __init__(
             np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
             axis=0,
         )
-        self.nan_idx = -1  # The last line is all NaN
+        self.nan_idx = len(self.data_arr) - 1  # The last line is all NaN; setting it to -1 can cause bug #1716
 
         # the data type will be changed
         # The index of usable data is between start_idx and end_idx
diff --git a/tests/data_mid_layer_tests/test_dataset.py b/tests/data_mid_layer_tests/test_dataset.py
index dc2ec812f1..9eb2083aa7 100755
--- a/tests/data_mid_layer_tests/test_dataset.py
+++ b/tests/data_mid_layer_tests/test_dataset.py
@@ -5,8 +5,9 @@
 import pytest
 import sys
 from qlib.tests import TestAutoData
-from qlib.data.dataset import TSDatasetH
+from qlib.data.dataset import TSDatasetH, TSDataSampler
 import numpy as np
+import pandas as pd
 import time
 from qlib.data.dataset.handler import DataHandlerLP
 
@@ -98,6 +99,54 @@ def testTSDataset(self):
         print(idx[i])
 
 
+class TestTSDataSampler(unittest.TestCase):
+    def test_TSDataSampler(self):
+        """
+        Test TSDataSampler for issue #1716
+        """
+        datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
+        instruments = ["000001", "000002", "000003", "000004", "000005"]
+        index = pd.MultiIndex.from_product(
+            [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
+        )
+        data = np.random.randn(len(datetime_list) * len(instruments))
+        test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
+        dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
+        print()
+        print("--------------dataset[0]--------------")
+        print(dataset[0])
+        print("--------------dataset[1]--------------")
+        print(dataset[1])
+        assert len(dataset[0]) == 2
+        self.assertTrue(np.isnan(dataset[0][0]))
+        self.assertEqual(dataset[0][1], dataset[1][0])
+        self.assertEqual(dataset[1][1], dataset[2][0])
+        self.assertEqual(dataset[2][1], dataset[3][0])
+
+    def test_TSDataSampler2(self):
+        """
+        Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
+        """
+        datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
+        instruments = ["000001", "000002", "000003", "000004", "000005"]
+        index = pd.MultiIndex.from_product(
+            [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
+        )
+        data = np.random.randn(len(datetime_list) * len(instruments))
+        test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
+        dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
+        print()
+        print("--------------dataset[0]--------------")
+        print(dataset[0])
+        print("--------------dataset[1]--------------")
+        print(dataset[1])
+        for i in range(3):
+            self.assertFalse(np.isnan(dataset[0][i]))
+            self.assertFalse(np.isnan(dataset[1][i]))
+        self.assertEqual(dataset[0][1], dataset[1][0])
+        self.assertEqual(dataset[0][2], dataset[1][1])
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=10)
 