From b1e0e77c97b55ae0305d3d8c3192c61db032fcc6 Mon Sep 17 00:00:00 2001
From: Chuan Xu <xuchuan0304@gmail.com>
Date: Fri, 10 May 2024 01:09:39 -0400
Subject: [PATCH] Fix the bug of reading string NA as NaN in the function
 exists_qlib_data. (#1736)

* Fix the bug of reading NA string as NaN in exists_qlib_data.

* Fix the .gitignore file.

* Update the fix and add some comments.

* format with black

---------

Co-authored-by: Chuan Xu <chuan.xu@sas.com>
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
---
 .gitignore             |  2 +-
 qlib/utils/__init__.py | 70 ++++++++++++++++++++++++++++++++++++++----
 2 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8854c25e99..29ea1cd5e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,4 +48,4 @@ tags
 *.swp
 
 ./pretrain
-.idea/
+.idea/
\ No newline at end of file
diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py
index 9e63c104a1..732638b236 100644
--- a/qlib/utils/__init__.py
+++ b/qlib/utils/__init__.py
@@ -25,7 +25,12 @@
 from pathlib import Path
 from typing import List, Union, Optional, Callable
 from packaging import version
-from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
+from .file import (
+    get_or_create_path,
+    save_multiple_parts_file,
+    unpack_archive_with_buffer,
+    get_tmp_file_with_buffer,
+)
 from ..config import C
 from ..log import get_module_logger, set_log_with_config
 
@@ -37,7 +42,12 @@
 #################### Server ####################
 def get_redis_connection():
     """get redis connection instance."""
-    return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
+    return redis.StrictRedis(
+        host=C.redis_host,
+        port=C.redis_port,
+        db=C.redis_task_db,
+        password=C.redis_password,
+    )
 
 
 #################### Data ####################
@@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly):
     return offset
 
 
-def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
+def read_period_data(
+    index_path,
+    data_path,
+    period,
+    cur_date_int: int,
+    quarterly,
+    last_period_index: int = None,
+):
     """
     At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
     Only the updating info before cur_date or at cur_date will be used.
@@ -273,7 +290,10 @@ def parse_field(field):
     # \uff09 -> )
     chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
     for pattern, new in [
-        (rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'),  # $$ must be before $
+        (
+            rf"\$\$([\w{chinese_punctuation_regex}]+)",
+            r'PFeature("\1")',
+        ),  # $$ must be before $
         (rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
         (r"(\w+\s*)\(", r"Operators.\1("),
     ]:  # Features  # Operators
@@ -383,7 +403,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
     return calendar
 
 
-def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
+def get_date_by_shift(
+    trading_date,
+    shift,
+    future=False,
+    clip_shift=True,
+    freq="day",
+    align: Optional[str] = None,
+):
     """get trading date with shift bias will cur_date
         e.g. : shift == 1,  return next trading date
                shift == -1, return previous trading date
@@ -569,7 +596,38 @@ def exists_qlib_data(qlib_dir):
     # check instruments
     code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
     _instrument = instruments_dir.joinpath("all.txt")
-    miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
+    # Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
+    miss_code = set(
+        pd.read_csv(
+            _instrument,
+            sep="\t",
+            header=None,
+            keep_default_na=False,
+            na_values={
+                0: [
+                    " ",
+                    "#N/A",
+                    "#N/A N/A",
+                    "#NA",
+                    "-1.#IND",
+                    "-1.#QNAN",
+                    "-NaN",
+                    "-nan",
+                    "1.#IND",
+                    "1.#QNAN",
+                    "<NA>",
+                    "N/A",
+                    "NaN",
+                    "None",
+                    "n/a",
+                    "nan",
+                    "null ",
+                ]
+            },
+        )
+        .loc[:, 0]
+        .apply(str.lower)
+    ) - set(code_names)
     if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
         return False