Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
John Lyu committed Oct 20, 2023
1 parent afff257 commit 6c214aa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 34 deletions.
11 changes: 9 additions & 2 deletions qlib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,15 @@ def period_feature(self, instrument, field, start_offset, end_offset, cur_time,
# keep only the latest period value
df_remain = df_remain.sort_values(by=["period"]).drop_duplicates(subset=["period"], keep="last")
df_remain = df_remain.set_index("period")

cache_key = (instrument, field, last_observe_date, start_offset, end_offset, quarterly) # f"{instrument}.{field}.{last_observe_date}.{start_offset}.{end_offset}.{quarterly}"

cache_key = (
instrument,
field,
last_observe_date,
start_offset,
end_offset,
quarterly,
) # f"{instrument}.{field}.{last_observe_date}.{start_offset}.{end_offset}.{quarterly}"
if cache_key in H["p"]:
retur = H["p"][cache_key]
else:
Expand Down
42 changes: 10 additions & 32 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@
#################### Server ####################
def get_redis_connection():
"""get redis connection instance."""
return redis.StrictRedis(
host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password
)
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)


#################### Data ####################
Expand Down Expand Up @@ -95,9 +93,7 @@ def get_period_list(first: int, last: int, quarterly: bool) -> List[int]:
return res


def get_period_list_by_offset(
last: int, start_offset: int, end_offset: int, quarterly: bool
) -> List[int]:
def get_period_list_by_offset(last: int, start_offset: int, end_offset: int, quarterly: bool) -> List[int]:
"""
This method will be used in PIT database.
It return all the possible values between `first(offset-last)` and `end` (first and end is included)
Expand All @@ -122,9 +118,7 @@ def get_period_list_by_offset(
assert all(190000 <= x <= 209904 for x in (last,)), "invalid arguments"
res = []
# last minus offset quarters
for year in range(
int(last // 100 + start_offset // 4 - 1), int(last // 100 + 1) + end_offset
):
for year in range(int(last // 100 + start_offset // 4 - 1), int(last // 100 + 1) + end_offset):
for q in range(1, 5):
period = year * 100 + q
if period <= last:
Expand All @@ -140,9 +134,7 @@ def get_period_offset(first_year, period, quarterly):
return offset


def read_period_data(
index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None
):
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
"""
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
Only the updating info before cur_date or at cur_date will be used.
Expand Down Expand Up @@ -193,9 +185,7 @@ def read_period_data(
with open(data_path, "rb") as fd:
while _next != NAN_INDEX:
fd.seek(_next)
date, period, value, new_next = struct.unpack(
DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE))
)
date, period, value, new_next = struct.unpack(DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE)))
if date > cur_date_int:
break
prev_next = _next
Expand Down Expand Up @@ -431,9 +421,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar


def get_date_by_shift(
trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None
):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
"""get trading date with shift bias will cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
Expand Down Expand Up @@ -466,9 +454,7 @@ def get_date_by_shift(
if clip_shift:
shift_index = np.clip(shift_index, 0, len(cal) - 1)
else:
raise IndexError(
f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range"
)
raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range")
return cal[shift_index]


Expand Down Expand Up @@ -505,11 +491,7 @@ def transform_end_date(end_date=None, freq="day"):
from ..data import D # pylint: disable=C0415

last_date = D.calendar(freq=freq)[-1]
if (
end_date is None
or (str(end_date) == "-1")
or (pd.Timestamp(last_date) < pd.Timestamp(end_date))
):
if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):
log.warning(
"\nInfo: the end_date in the configuration file is {}, "
"so the default last date {} is used.".format(end_date, last_date)
Expand Down Expand Up @@ -625,9 +607,7 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
miss_code = set(
pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)
) - set(code_names)
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False

Expand Down Expand Up @@ -863,9 +843,7 @@ def register(self, provider):
self._provider = provider

def __repr__(self):
return "{name}(provider={provider})".format(
name=self.__class__.__name__, provider=self._provider
)
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)

def __getattr__(self, key):
if self.__dict__.get("_provider", None) is None:
Expand Down

0 comments on commit 6c214aa

Please sign in to comment.