From b1e0e77c97b55ae0305d3d8c3192c61db032fcc6 Mon Sep 17 00:00:00 2001 From: Chuan Xu <xuchuan0304@gmail.com> Date: Fri, 10 May 2024 01:09:39 -0400 Subject: [PATCH] Fix the bug of reading string NA as NaN in the function exists_qlib_data. (#1736) * Fix the bug of reading NA string as NaN in exists_qlib_data. * Fix the .gitignore file. * Update the fix and add some comments. * format with black --------- Co-authored-by: Chuan Xu <chuan.xu@sas.com> Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com> --- .gitignore | 2 +- qlib/utils/__init__.py | 70 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 8854c25e99..29ea1cd5e3 100644 --- a/.gitignore +++ b/.gitignore @@ -48,4 +48,4 @@ tags *.swp ./pretrain -.idea/ +.idea/ \ No newline at end of file diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 9e63c104a1..732638b236 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -25,7 +25,12 @@ from pathlib import Path from typing import List, Union, Optional, Callable from packaging import version -from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer +from .file import ( + get_or_create_path, + save_multiple_parts_file, + unpack_archive_with_buffer, + get_tmp_file_with_buffer, +) from ..config import C from ..log import get_module_logger, set_log_with_config @@ -37,7 +42,12 @@ #################### Server #################### def get_redis_connection(): """get redis connection instance.""" - return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password) + return redis.StrictRedis( + host=C.redis_host, + port=C.redis_port, + db=C.redis_task_db, + password=C.redis_password, + ) #################### Data #################### @@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly): return offset -def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None): +def read_period_data( + index_path, + data_path, + period, + cur_date_int: int, + quarterly, + last_period_index: int = None, +): """ At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803). Only the updating info before cur_date or at cur_date will be used. @@ -273,7 +290,10 @@ def parse_field(field): # \uff09 -> ) chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09" for pattern, new in [ - (rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $ + ( + rf"\$\$([\w{chinese_punctuation_regex}]+)", + r'PFeature("\1")', + ), # $$ must be before $ (rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'), (r"(\w+\s*)\(", r"Operators.\1("), ]: # Features # Operators @@ -383,7 +403,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): return calendar -def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None): +def get_date_by_shift( + trading_date, + shift, + future=False, + clip_shift=True, + freq="day", + align: Optional[str] = None, +): """get trading date with shift bias will cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date @@ -569,7 +596,38 @@ def exists_qlib_data(qlib_dir): # check instruments code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") - miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) + # Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0 + miss_code = set( + pd.read_csv( + _instrument, + sep="\t", + header=None, + keep_default_na=False, + na_values={ + 0: [ + " ", + "#N/A", + "#N/A N/A", + "#NA", + "-1.#IND", + "-1.#QNAN", + "-NaN", + "-nan", + "1.#IND", + "1.#QNAN", + "<NA>", + "N/A", + "NaN", + "None", + "n/a", + "nan", + "null ", + ] + }, + ) + .loc[:, 0] + .apply(str.lower) + ) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): return False