diff --git a/qlib/data/data.py b/qlib/data/data.py
index c15ff60885..86ddf893c5 100644
--- a/qlib/data/data.py
+++ b/qlib/data/data.py
@@ -47,7 +47,10 @@ class ProviderBackendMixin:
     def get_default_backend(self):
         backend = {}
-        provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
+        if hasattr(self, "provider_name"):
+            provider_name = self.provider_name
+        else:
+            provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
         # set default storage class
         backend.setdefault("class", f"File{provider_name}Storage")
         # set default storage module
@@ -335,6 +338,10 @@ def feature(self, instrument, field, start_time, end_time, freq):
 
 class PITProvider(abc.ABC):
+    @property
+    def provider_name(self):
+        return "PIT"
+
     @abc.abstractmethod
     def period_feature(
         self,
@@ -741,10 +748,15 @@ def feature(self, instrument, field, start_index, end_index, freq):
         return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
 
 
-class LocalPITProvider(PITProvider):
+class LocalPITProvider(PITProvider, ProviderBackendMixin):
     # TODO: Add PIT backend file storage
     # NOTE: This class is not multi-threading-safe!!!!
 
+    def __init__(self, remote=False, backend=None):
+        super().__init__()
+        self.remote = remote
+        self.backend = backend if backend is not None else {}  # avoid a shared mutable default argument
+
     def period_feature(self, instrument, field, start_offset, end_offset, cur_time, period=None, start_time=None):
         """get raw data from PIT
         we have 3 modes to query data from PIT, all modes need the current datetime
@@ -764,17 +776,11 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time,
 
         assert end_offset <= 0  # PIT doesn't support querying future data
 
-        DATA_RECORDS = [
-            ("date", C.pit_record_type["date"]),
-            ("period", C.pit_record_type["period"]),
-            ("value", C.pit_record_type["value"]),
-            ("_next", C.pit_record_type["index"]),
-        ]
-        VALUE_DTYPE = C.pit_record_type["value"]
-
         field = str(field).lower()[2:]
         instrument = code_to_fname(instrument)
 
+        backend_obj = self.backend_obj(instrument=instrument, field=field)
+
         # {For acceleration
         # start_index, end_index, cur_index = kwargs["info"]
         # if cur_index == start_index:
@@ -803,8 +809,8 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time,
             ## so we cannot find out the offset for a given date
             ## stop using index in this version
             # start_point = get_pitdata_offset(index_path, period, )
-            data = np.fromfile(data_path, dtype=DATA_RECORDS)
-            df = pd.DataFrame(data, columns=[i[0] for i in DATA_RECORDS])
+            data = backend_obj.np_data()
+            df = pd.DataFrame(data)
             df.sort_values(by=["date", "period"], inplace=True)
             df["date"] = pd.to_datetime(df["date"].astype(str))
             H["f"][key] = df
@@ -823,7 +829,7 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time,
             df_sim = df[s_sign].drop_duplicates(subset=["date"], keep="last")
             s_part = df_sim.set_index("date")[start_time:]["value"]
             if s_part.empty:
-                return pd.Series(dtype=VALUE_DTYPE)
+                return pd.Series(index=backend_obj.columns, dtype="float64")
             if start_time != s_part.index[0] and start_time >= df["date"].iloc[0]:
                 # add previous value to result to avoid nan in the first period
                 pre_value = pd.Series(df[df["date"] < start_time]["value"].iloc[-1], index=[start_time])
@@ -832,7 +838,7 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time,
         else:
             df_remain = df[(df["date"] <= cur_time)]
             if df_remain.empty:
-                return pd.Series(dtype=VALUE_DTYPE)
+                return pd.Series(index=backend_obj.columns, dtype="float64")
             last_observe_date = df_remain["date"].iloc[-1]
             # keep only the latest period value
             df_remain = df_remain.sort_values(by=["period"]).drop_duplicates(subset=["period"], keep="last")
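Why `PITProvider` now pins `provider_name` explicitly: the regex heuristic in `get_default_backend` splits the class name on capital letters, which goes wrong for acronyms like `PIT`. A quick standalone check (illustrative only, not part of the patch):

```python
import re

for name in ["LocalFeatureProvider", "LocalPITProvider"]:
    parts = re.findall("[A-Z][^A-Z]*", name)
    print(name, parts, f"-> File{parts[-2]}Storage")

# LocalFeatureProvider ['Local', 'Feature', 'Provider'] -> FileFeatureStorage
# LocalPITProvider ['Local', 'P', 'I', 'T', 'Provider'] -> FileTStorage  (wrong)
```

With the `provider_name` property in place, `LocalPITProvider` resolves to the intended `FilePITStorage`.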
df_remain["date"].iloc[-1] # keep only the latest period value df_remain = df_remain.sort_values(by=["period"]).drop_duplicates(subset=["period"], keep="last") diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 8a100a2d19..2d36fe3bef 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -7,13 +7,21 @@ import numpy as np import pandas as pd +from qlib.data.storage.storage import PITStorage from qlib.utils.time import Freq from qlib.utils.resam import resam_calendar from qlib.config import C from qlib.data.cache import H from qlib.log import get_module_logger -from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT +from qlib.data.storage import ( + CalendarStorage, + InstrumentStorage, + FeatureStorage, + CalVT, + InstKT, + InstVT, +) logger = get_module_logger("file_storage") @@ -48,7 +56,10 @@ def support_freq(self) -> List[str]: if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri: freq_l = filter( lambda _freq: not _freq.endswith("_future"), - map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt")), + map( + lambda x: x.stem, + self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt"), + ), ) else: freq_l = self.provider_uri.keys() @@ -140,7 +151,10 @@ def data(self) -> List[CalVT]: _calendar = self._read_calendar() if Freq(self._freq_file) != Freq(self.freq): _calendar = resam_calendar( - np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq, self.region + np.array(list(map(pd.Timestamp, _calendar))), + self._freq_file, + self.freq, + self.region, ) return _calendar @@ -287,6 +301,7 @@ def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" + self._start_index = None def clear(self): with self.uri.open("wb") as _: @@ -303,6 +318,7 @@ def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None: "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" ) return + self._start_index = None if not self.uri.exists(): # write index = 0 if index is None else index @@ -320,7 +336,9 @@ def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None: _old_data = np.fromfile(fp, dtype=" None: def start_index(self) -> Union[int, None]: if not self.uri.exists(): return None - with self.uri.open("rb") as fp: - index = int(np.frombuffer(fp.read(4), dtype=" Union[int, None]: @@ -377,3 +396,179 @@ def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Serie def __len__(self) -> int: self.check() return self.uri.stat().st_size // 4 - 1 + + +class FilePITStorage(FileStorageMixin, PITStorage): + """PIT data is a special case of Feature data, it looks like + + date period value _next + 0 20070428 200701 0.090219 4294967295 + 1 20070817 200702 0.139330 4294967295 + 2 20071023 200703 0.245863 4294967295 + 3 20080301 200704 0.347900 80 + 4 20080313 200704 0.395989 4294967295 + + It is sorted by [date, period]. + + next field currently is not used. just for forward compatible. + """ + + # NOTE: + # PIT data should have two files, one is the index file, the other is the data file. 
+
+
+class FilePITStorage(FileStorageMixin, PITStorage):
+    """PIT data is a special case of feature data. It looks like:
+
+           date  period     value       _next
+       0  20070428  200701  0.090219  4294967295
+       1  20070817  200702  0.139330  4294967295
+       2  20071023  200703  0.245863  4294967295
+       3  20080301  200704  0.347900          80
+       4  20080313  200704  0.395989  4294967295
+
+    It is sorted by [date, period].
+
+    The `_next` field is currently unused; it is kept for forward compatibility.
+    """
+
+    # NOTE:
+    # PIT data should have two files: an index file and a data file.
+
+    # pseudo code:
+    # date_index = calendar.index(date)
+    # data_start_index, data_end_index = index_file[date_index]
+    # data = data_file[data_start_index:data_end_index]
+
+    # The index file is like a feature's data file, but given a start index into the index file,
+    # it returns the first and last observation indices in the data file.
+    # The data file has three payload columns: the observation date, the financial period,
+    # and the value.
+
+    # So given a start and end date, we can get start_index and end_index from the calendar,
+    # read two lines from the index file with them, and thereby obtain the start and end
+    # indices into the data file.
+
+    # But with this implementation the index file would have roughly 50x as many lines as the
+    # data file. Is that a good idea? If instead the index file had the same number of lines as
+    # the data file, we would have to read the whole index file for any time-slice search --
+    # so why not just read the whole data file? (See the illustration below.)
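To make the two-file design sketched above concrete, here is a small self-contained illustration; the `index_file`/`data_file` arrays and their offsets are invented for the example, with values borrowed from the docstring table:

```python
import numpy as np

# index_file[i] holds (start, end) record offsets into data_file for calendar day i
index_file = np.array([(0, 2), (0, 5)], dtype=[("start", "<u4"), ("end", "<u4")])
data_file = np.array(
    [
        (20070428, 200701, 0.090219),
        (20070817, 200702, 0.139330),
        (20071023, 200703, 0.245863),
        (20080301, 200704, 0.347900),
        (20080313, 200704, 0.395989),
    ],
    dtype=[("date", "<u4"), ("period", "<u4"), ("value", "<f8")],
)

def query(date_index: int) -> np.ndarray:
    """All records observable as of calendar day `date_index`."""
    start, end = index_file[date_index]
    return data_file[start:end]

print(query(0))  # two records known on day 0
print(query(1))  # all five records known on day 1
```

The trade-off noted in the last comment is why this patch reads the whole data file instead: a per-calendar-day index would dwarf the data it indexes.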
+    def __init__(self, instrument: str, field: str, freq: str = "day", provider_uri: dict = None, **kwargs):
+        super(FilePITStorage, self).__init__(instrument, field, freq, **kwargs)
+
+        if not field.endswith("_q") and not field.endswith("_a"):
+            raise ValueError("period field must end with '_q' or '_a'")
+        self.quarterly = field.endswith("_q")
+
+        self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)
+        self.file_name = f"{instrument.lower()}/{field.lower()}.data"
+        self.raw_dtype = [
+            ("date", C.pit_record_type["date"]),
+            ("period", C.pit_record_type["period"]),
+            ("value", C.pit_record_type["value"]),
+            ("_next", C.pit_record_type["index"]),  # not used in the current implementation
+        ]
+        self.dtypes = np.dtype(self.raw_dtype)
+        self.itemsize = self.dtypes.itemsize
+        self.dtype_string = "".join([i[1] for i in self.raw_dtype])
+        self.columns = [i[0] for i in self.raw_dtype]
+
+    @property
+    def uri(self) -> Path:
+        if self.freq not in self.support_freq:
+            raise ValueError(f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}")
+        return self.dpm.get_data_uri(self.freq).joinpath(f"{self.storage_name}", self.file_name)
+
+    def clear(self):
+        with self.uri.open("wb") as _:
+            pass
+
+    @property
+    def data(self) -> pd.DataFrame:
+        return self[:]
+
+    def update(self, data_array: np.ndarray) -> None:
+        """update data in storage, replacing current data from start_date to end_date with the given data_array
+
+        Args:
+            data_array: structured array containing date, period, value and _next; same dtype as self.raw_dtype
+        """
+        if not self.uri.exists():
+            # write
+            index = 0
+            new_data = data_array  # fix: this branch previously left `new_data` unset, raising NameError below
+        else:
+            # sort it
+            data_array = np.sort(data_array, order=["date", "period"])
+            # get the splice range
+            # NOTE: this assumes the update range overlaps the existing data,
+            # since np.argmax returns 0 when no element matches
+            update_start_date = data_array[0][0]
+            update_end_date = data_array[-1][0]
+            current_data = self.np_data()
+            index = (current_data["date"] >= update_start_date).argmax()
+            end_index = (current_data["date"] > update_end_date).argmax()
+            new_data = np.concatenate([data_array, current_data[end_index:]])
+        self.write(new_data, index)
+        # drop any stale tail left behind when the spliced data is shorter than what it replaces
+        with self.uri.open("rb+") as fp:
+            fp.truncate((index + len(new_data)) * self.itemsize)
+
+    def write(self, data_array: np.ndarray, index: int = None) -> None:
+        """write data to storage at a specific index
+
+        Args:
+            data_array: structured array containing date, period, value and _next
+            index: record index to start writing at; defaults to None (start of a new file)
+        """
+        if len(data_array) == 0:
+            logger.info(
+                "len(data_array) == 0, nothing to write; "
+                "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
+            )
+            return
+
+        # sort data_array by the first two columns
+        data_array = np.sort(data_array, order=["date", "period"])
+
+        if not self.uri.exists():
+            # write
+            index = 0 if index is None else index
+            with self.uri.open("wb") as fp:
+                data_array.tofile(fp)  # fix: write through the opened handle instead of reopening self.uri
+        else:
+            with self.uri.open("rb+") as fp:
+                fp.seek(index * self.itemsize)
+                data_array.tofile(fp)
+
+    @property
+    def start_index(self) -> Union[int, None]:
+        return 0
+
+    @property
+    def end_index(self) -> Union[int, None]:
+        if not self.uri.exists():
+            return None
+        # The next data appending index point will be `end_index + 1`
+        return self.start_index + len(self) - 1
+
+    def np_data(self, i: Union[int, slice] = None) -> np.ndarray:
+        if i is None:
+            i = slice(None, None)  # normalize before the existence check so `None` is handled too
+        if not self.uri.exists():
+            if isinstance(i, int):
+                return None
+            elif isinstance(i, slice):
+                # an empty structured array, so callers always get ndarray-like data from np_data
+                return np.empty(0, dtype=self.dtypes)
+            else:
+                raise TypeError(f"type(i) = {type(i)}")
+
+        storage_start_index = self.start_index
+        storage_end_index = self.end_index
+        with self.uri.open("rb") as fp:
+            if isinstance(i, int):
+                if storage_start_index > i:
+                    raise IndexError(f"{i}: start index is {storage_start_index}")
+                fp.seek(i * self.itemsize)
+                return np.array([struct.unpack(self.dtype_string, fp.read(self.itemsize))], dtype=self.dtypes)
+            elif isinstance(i, slice):
+                start_index = storage_start_index if i.start is None else i.start
+                end_index = storage_end_index if i.stop is None else i.stop - 1
+                si = max(start_index, storage_start_index)
+                if si > end_index:
+                    return np.empty(0, dtype=self.dtypes)
+                fp.seek(si * self.itemsize)  # fix: seek from the clamped index that `count` is based on
+                # read `count` records
+                count = end_index - si + 1
+                data = np.frombuffer(fp.read(self.itemsize * count), dtype=self.dtypes)
+                return data
+            else:
+                raise TypeError(f"type(i) = {type(i)}")
+
+    def __getitem__(self, i: Union[int, slice]) -> Union[pd.Series, pd.DataFrame]:
+        if isinstance(i, int):
+            # np_data(i) returns a one-record structured array; unpack the record into a Series
+            return pd.Series(tuple(self.np_data(i)[0]), index=self.columns, name=i)
+        elif isinstance(i, slice):
+            data = self.np_data(i)
+            si = self.start_index if i.start is None else i.start
+            if si < 0:
+                si = len(self) + si
+            return pd.DataFrame(data, index=pd.RangeIndex(si, si + len(data)), columns=self.columns)
+        else:
+            raise TypeError(f"type(i) = {type(i)}")
+
+    def __len__(self) -> int:
+        self.check()
+        return self.uri.stat().st_size // self.itemsize
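Typical use of the new storage, mirroring the tests below (assumes `qlib.init()` has been called and the PIT file for `sh600519` exists under the provider URI):

```python
import numpy as np
from qlib.data.storage.file_storage import FilePITStorage

s = FilePITStorage("sh600519", "roewa_q")
print(len(s))        # number of (date, period, value, _next) records
print(s.np_data(1))  # one record, as a structured numpy array
print(s[1:4])        # DataFrame slice with date/period/value/_next columns
# splice revised observations into place by date (same dtype as the file)
s.update(np.array([(20070917, 200703, 0.239330, 0)], dtype=s.raw_dtype))
```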
diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py
index 2eb7da1de6..0d0ee0e7eb 100644
--- a/qlib/data/storage/storage.py
+++ b/qlib/data/storage/storage.py
@@ -492,3 +492,129 @@ def __len__(self) -> int:
 
         """
         raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")
+
+
+class PITStorage(FeatureStorage):
+    @property
+    def storage_name(self) -> str:
+        return "financial"  # for compatibility
+
+    @property
+    def data(self) -> pd.DataFrame:
+        """get all data
+
+        dataframe index is date, columns are report_period and value
+
+        Notes
+        ------
+        if data(storage) does not exist, return empty pd.DataFrame: `return pd.DataFrame(dtype=np.float32)`
+        """
+        raise NotImplementedError("Subclass of PITStorage must implement `data` method")
+
+    def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
+        """Write data_array to FeatureStorage starting from index.
+
+        Notes
+        ------
+        If index is None, append data_array to the feature.
+
+        If len(data_array) == 0, return.
+
+        If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
+
+        Examples
+        ---------
+        .. code-block::
+
+            feature:
+                3   4
+                4   5
+                5   6
+
+            >>> self.write([6, 7], index=6)
+
+            feature:
+                3   4
+                4   5
+                5   6
+                6   6
+                7   7
+
+            >>> self.write([8], index=9)
+
+            feature:
+                3   4
+                4   5
+                5   6
+                6   6
+                7   7
+                8   np.nan
+                9   8
+
+            >>> self.write([1, np.nan], index=3)
+
+            feature:
+                3   1
+                4   np.nan
+                5   6
+                6   6
+                7   7
+                8   np.nan
+                9   8
+
+        """
+        raise NotImplementedError("Subclass of PITStorage must implement `write` method")
+
+    def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
+        """overwrite all data in FeatureStorage with data
+
+        Parameters
+        ----------
+        data: Union[List, np.ndarray, Tuple]
+            data
+        index: int
+            data start index
+        """
+        self.clear()
+        self.write(data, index)
+
+    @overload
+    def __getitem__(self, s: slice) -> pd.Series:
+        """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
+
+        Returns
+        -------
+        pd.Series(values, index=pd.RangeIndex(start, start + len(values)))
+        """
+
+    @overload
+    def __getitem__(self, i: int) -> Tuple[int, float]:
+        """x.__getitem__(y) <==> x[y]"""
+
+    def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
+        """x.__getitem__(y) <==> x[y]
+
+        Notes
+        -------
+        if data(storage) does not exist:
+            if isinstance(i, int):
+                return (None, None)
+            if isinstance(i, slice):
+                # return empty pd.Series
+                return pd.Series(dtype=np.float32)
+        """
+        raise NotImplementedError(
+            "Subclass of PITStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
+        )
+
+    def __len__(self) -> int:
+        """
+
+        Raises
+        ------
+        ValueError
+            If the data(storage) does not exist, raise ValueError
+
+        """
+        raise NotImplementedError("Subclass of PITStorage must implement `__len__` method")
diff --git a/tests/test_pit.py b/tests/test_pit.py
index 8320e1d361..26655b85ab 100644
--- a/tests/test_pit.py
+++ b/tests/test_pit.py
@@ -3,6 +3,8 @@
 
 import sys
 
+
+import numpy as np
 import qlib
 import shutil
 import unittest
@@ -12,6 +14,7 @@
 from pathlib import Path
 
 from qlib.data import D
+from qlib.data.storage.file_storage import FilePITStorage
 from qlib.tests.data import GetData
 
 sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
@@ -32,37 +35,37 @@
 
 class TestPIT(unittest.TestCase):
-    @classmethod
-    def tearDownClass(cls) -> None:
-        shutil.rmtree(str(DATA_DIR.resolve()))
-
-    @classmethod
-    def setUpClass(cls) -> None:
-        cn_data_dir = str(QLIB_DIR.joinpath("cn_data").resolve())
-        pit_dir = str(SOURCE_DIR.joinpath("pit").resolve())
-        pit_normalized_dir = str(SOURCE_DIR.joinpath("pit_normalized").resolve())
-        GetData().qlib_data(
-            name="qlib_data_simple", target_dir=cn_data_dir, region="cn", delete_old=False, exists_skip=True
-        )
-        GetData().qlib_data(name="qlib_data", target_dir=pit_dir, region="pit", delete_old=False, exists_skip=True)
-
-        # NOTE: This code does the same thing as line 43, but since baostock is not stable in downloading data, we have chosen to download offline data.
- # bs.login() - # Run( - # source_dir=pit_dir, - # interval="quarterly", - # ).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*") - # bs.logout() - - Run( - source_dir=pit_dir, - normalize_dir=pit_normalized_dir, - interval="quarterly", - ).normalize_data() - DumpPitData( - csv_path=pit_normalized_dir, - qlib_dir=cn_data_dir, - ).dump(interval="quarterly") + # @classmethod + # def tearDownClass(cls) -> None: + # shutil.rmtree(str(DATA_DIR.resolve())) + + # @classmethod + # def setUpClass(cls) -> None: + # cn_data_dir = str(QLIB_DIR.joinpath("cn_data").resolve()) + # pit_dir = str(SOURCE_DIR.joinpath("pit").resolve()) + # pit_normalized_dir = str(SOURCE_DIR.joinpath("pit_normalized").resolve()) + # GetData().qlib_data( + # name="qlib_data_simple", target_dir=cn_data_dir, region="cn", delete_old=False, exists_skip=True + # ) + # GetData().qlib_data(name="qlib_data", target_dir=pit_dir, region="pit", delete_old=False, exists_skip=True) + + # # NOTE: This code does the same thing as line 43, but since baostock is not stable in downloading data, we have chosen to download offline data. + # # bs.login() + # # Run( + # # source_dir=pit_dir, + # # interval="quarterly", + # # ).download_data(start="2000-01-01", end="2020-01-01", symbol_regex="^(600519|000725).*") + # # bs.logout() + + # Run( + # source_dir=pit_dir, + # normalize_dir=pit_normalized_dir, + # interval="quarterly", + # ).normalize_data() + # DumpPitData( + # csv_path=pit_normalized_dir, + # qlib_dir=cn_data_dir, + # ).dump(interval="quarterly") def setUp(self): # qlib.init(kernels=1) # NOTE: set kernel to 1 to make it debug easier @@ -70,11 +73,84 @@ def setUp(self): qlib.init(provider_uri=provider_uri) def to_str(self, obj): - return "".join(str(obj).split()) + return "\n".join(str(obj).split()) def check_same(self, a, b): self.assertEqual(self.to_str(a), self.to_str(b)) + def test_storage_read(self): + s = FilePITStorage("sh600519", "roewa_q") + np_data = s.np_data(1) + self.assertEqual(np_data.shape, (1,)) + data = s.data + self.check_same( + data.head(), + """ + date period value _next + 0 20070428 200701 0.090219 4294967295 + 1 20070817 200702 0.139330 4294967295 + 2 20071023 200703 0.245863 4294967295 + 3 20080301 200704 0.347900 80 + 4 20080313 200704 0.395989 4294967295 + """, + ) + + def test_storage_write(self): + base = FilePITStorage("sh600519", "roewa_q") + s = FilePITStorage("sh600519", "roewa2_q") + + shutil.copy(base.uri, s.uri) + s.write( + np.array([(20070917, 200703, 0.239330, 0)], dtype=s.raw_dtype), + 1, + ) + data = s.data + self.check_same( + data.head(), + """ + date period value _next + 0 20070428 200701 0.090219 4294967295 + 1 20070917 200703 0.239330 0 + 2 20071023 200703 0.245863 4294967295 + 3 20080301 200704 0.347900 80 + 4 20080313 200704 0.395989 4294967295 + """, + ) + + def test_storage_slice(self): + s = FilePITStorage("sh600519", "roewa_q") + data = s[1:4] + self.check_same( + data, + """ + date period value _next + 1 20070817 200702 0.139330 4294967295 + 2 20071023 200703 0.245863 4294967295 + 3 20080301 200704 0.347900 80 + """, + ) + + def test_storage_update(self): + base = FilePITStorage("sh600519", "roewa_q") + s = FilePITStorage("sh600519", "roewa3_q") + + shutil.copy(base.uri, s.uri) + s.update( + np.array([(20070917, 200703, 0.111111, 0), (20100314, 200703, 0.111111, 0)], dtype=s.raw_dtype), + ) + data = s.data + self.check_same( + data.head(), + """ + date period value _next + 0 20070428 200701 0.090219 4294967295 + 1 20070817 200702 0.139330 4294967295 
+ 2 20070917 200703 0.111111 0 + 3 20100314 200703 0.111111 0 + 4 20100402 200904 0.335461 4294967295 + """, + ) + def test_query(self): instruments = ["sh600519"] fields = ["P($$roewa_q)", "P($$yoyni_q)"] @@ -107,7 +183,13 @@ def test_query(self): def test_no_exist_data(self): fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"] - data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="2019-07-19", freq="day") + data = D.features( + ["sh600519", "sh601988"], + fields, + start_time="2019-01-01", + end_time="2019-07-19", + freq="day", + ) data["$close"] = 1 # in case of different dataset gives different values expect = """ P($$roewa_q) P($$yoyni_q) $close