diff --git a/tests/data_mid_layer_tests/test_dataset.py b/tests/data_mid_layer_tests/test_dataset.py index 055e603f3a..9eb2083aa7 100755 --- a/tests/data_mid_layer_tests/test_dataset.py +++ b/tests/data_mid_layer_tests/test_dataset.py @@ -98,19 +98,19 @@ def testTSDataset(self): print(data.shape) print(idx[i]) + class TestTSDataSampler(unittest.TestCase): def test_TSDataSampler(self): """ Test TSDataSampler for issue #1716 """ - datetime_list = [ - '2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31' - ] - instruments = ['000001', '000002', '000003', '000004', '000005'] - index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments], - names=['datetime', 'instrument']) + datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"] + instruments = ["000001", "000002", "000003", "000004", "000005"] + index = pd.MultiIndex.from_product( + [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"] + ) data = np.random.randn(len(datetime_list) * len(instruments)) - test_df = pd.DataFrame(data=data, index=index, columns=['factor']) + test_df = pd.DataFrame(data=data, index=index, columns=["factor"]) dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2) print() print("--------------dataset[0]--------------") @@ -127,14 +127,13 @@ def test_TSDataSampler2(self): """ Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front """ - datetime_list = [ - '2000-01-31', '2000-02-29', '2000-03-31', '2000-04-30', '2000-05-31' - ] - instruments = ['000001', '000002', '000003', '000004', '000005'] - index = pd.MultiIndex.from_product([pd.to_datetime(datetime_list), instruments], - names=['datetime', 'instrument']) + datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"] + instruments = ["000001", "000002", "000003", "000004", "000005"] + index = pd.MultiIndex.from_product( + [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"] + ) data = np.random.randn(len(datetime_list) * len(instruments)) - test_df = pd.DataFrame(data=data, index=index, columns=['factor']) + test_df = pd.DataFrame(data=data, index=index, columns=["factor"]) dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3) print() print("--------------dataset[0]--------------") @@ -146,6 +145,8 @@ def test_TSDataSampler2(self): self.assertFalse(np.isnan(dataset[1][i])) self.assertEqual(dataset[0][1], dataset[1][0]) self.assertEqual(dataset[0][2], dataset[1][1]) + + if __name__ == "__main__": unittest.main(verbosity=10)