diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py
index 113f9802d7..6c4525ce36 100644
--- a/qlib/utils/index_data.py
+++ b/qlib/utils/index_data.py
@@ -108,6 +108,15 @@ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]):
             self.index_map = self.idx_list = np.arange(idx_list)
             self._is_sorted = True
         else:
+            # An empty idx_list has nothing to validate (and no idx_list[0] to inspect)
+            if len(idx_list) > 0:
+                first = idx_list[0]
+                # Check if all elements in idx_list are of the same type
+                if not all(isinstance(x, type(first)) for x in idx_list):
+                    raise TypeError("All elements in idx_list must be of the same type")
+                # Check if all elements in idx_list are of the same datetime64 precision
+                if isinstance(first, np.datetime64) and not all(x.dtype == first.dtype for x in idx_list):
+                    raise TypeError("All elements in idx_list must be of the same datetime64 precision")
             self.idx_list = np.array(idx_list)
             # NOTE: only the first appearance is indexed
             self.index_map = dict(zip(self.idx_list, range(len(self))))
@@ -131,7 +140,12 @@ def _convert_type(self, item):
         if self.idx_list.dtype.type is np.datetime64:
             if isinstance(item, pd.Timestamp):
                 # This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
-                return item.to_numpy()
+                return item.to_numpy().astype(self.idx_list.dtype)
+            elif isinstance(item, np.datetime64):
+                # This happens often when creating index based on np.datetime64 and query with another precision
+                return item.astype(self.idx_list.dtype)
+            # NOTE: It is hard to anticipate every case up front.
+            # We only cover some of the common cases to make querying more user-friendly
         return item
 
     def index(self, item) -> int:
diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py
index 2db644f8a6..b3045a5c7f 100644
--- a/tests/misc/test_index_data.py
+++ b/tests/misc/test_index_data.py
@@ -94,6 +94,24 @@ def test_corner_cases(self):
         print(sd)
         self.assertTrue(sd.iloc[0] == 2)
 
+        # test different precisions of time data
+        timeindex = [
+            np.datetime64("2024-06-22T00:00:00.000000000"),
+            np.datetime64("2024-06-21T00:00:00.000000000"),
+            np.datetime64("2024-06-20T00:00:00.000000000"),
+        ]
+        sd = idd.SingleData([1, 2, 3], index=timeindex)
+        self.assertTrue(
+            sd.index.index(np.datetime64("2024-06-21T00:00:00.000000000"))
+            == sd.index.index(np.datetime64("2024-06-21T00:00:00"))
+        )
+        self.assertTrue(sd.index.index(pd.Timestamp("2024-06-21 00:00")) == 1)
+
+        # Bad case: the datetime64 precisions of the input are not aligned
+        timeindex[1] = np.datetime64("2024-06-21T00:00:00.00")
+        with self.assertRaises(TypeError):
+            sd = idd.SingleData([1, 2, 3], index=timeindex)
+
     def test_ops(self):
         sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
         sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])