diff --git a/qlib/data/data.py b/qlib/data/data.py index e9e0c803da..bbfaeb9c12 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -835,8 +835,15 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time, # keep only the latest period value df_remain = df_remain.sort_values(by=["period"]).drop_duplicates(subset=["period"], keep="last") df_remain = df_remain.set_index("period") - - cache_key = (instrument, field, last_observe_date, start_offset, end_offset, quarterly) # f"{instrument}.{field}.{last_observe_date}.{start_offset}.{end_offset}.{quarterly}" + + cache_key = ( + instrument, + field, + last_observe_date, + start_offset, + end_offset, + quarterly, + ) # f"{instrument}.{field}.{last_observe_date}.{start_offset}.{end_offset}.{quarterly}" if cache_key in H["p"]: retur = H["p"][cache_key] else: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index b04b459f1d..66aa2ef07d 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -42,9 +42,7 @@ #################### Server #################### def get_redis_connection(): """get redis connection instance.""" - return redis.StrictRedis( - host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password - ) + return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password) #################### Data #################### @@ -95,9 +93,7 @@ def get_period_list(first: int, last: int, quarterly: bool) -> List[int]: return res -def get_period_list_by_offset( - last: int, start_offset: int, end_offset: int, quarterly: bool -) -> List[int]: +def get_period_list_by_offset(last: int, start_offset: int, end_offset: int, quarterly: bool) -> List[int]: """ This method will be used in PIT database. It return all the possible values between `first(offset-last)` and `end` (first and end is included) @@ -122,9 +118,7 @@ def get_period_list_by_offset( assert all(190000 <= x <= 209904 for x in (last,)), "invalid arguments" res = [] # last minus offset quarters - for year in range( - int(last // 100 + start_offset // 4 - 1), int(last // 100 + 1) + end_offset - ): + for year in range(int(last // 100 + start_offset // 4 - 1), int(last // 100 + 1) + end_offset): for q in range(1, 5): period = year * 100 + q if period <= last: @@ -140,9 +134,7 @@ def get_period_offset(first_year, period, quarterly): return offset -def read_period_data( - index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None -): +def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None): """ At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803). Only the updating info before cur_date or at cur_date will be used. @@ -193,9 +185,7 @@ def read_period_data( with open(data_path, "rb") as fd: while _next != NAN_INDEX: fd.seek(_next) - date, period, value, new_next = struct.unpack( - DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE)) - ) + date, period, value, new_next = struct.unpack(DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE))) if date > cur_date_int: break prev_next = _next @@ -431,9 +421,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): return calendar -def get_date_by_shift( - trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None -): +def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None): """get trading date with shift bias will cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date @@ -466,9 +454,7 @@ def get_date_by_shift( if clip_shift: shift_index = np.clip(shift_index, 0, len(cal) - 1) else: - raise IndexError( - f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range" - ) + raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range") return cal[shift_index] @@ -505,11 +491,7 @@ def transform_end_date(end_date=None, freq="day"): from ..data import D # pylint: disable=C0415 last_date = D.calendar(freq=freq)[-1] - if ( - end_date is None - or (str(end_date) == "-1") - or (pd.Timestamp(last_date) < pd.Timestamp(end_date)) - ): + if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)): log.warning( "\nInfo: the end_date in the configuration file is {}, " "so the default last date {} is used.".format(end_date, last_date) @@ -625,9 +607,7 @@ def exists_qlib_data(qlib_dir): # check instruments code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") - miss_code = set( - pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower) - ) - set(code_names) + miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): return False @@ -863,9 +843,7 @@ def register(self, provider): self._provider = provider def __repr__(self): - return "{name}(provider={provider})".format( - name=self.__class__.__name__, provider=self._provider - ) + return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider) def __getattr__(self, key): if self.__dict__.get("_provider", None) is None: