From 8efc8b92ef1ec4abb27cc70502a19aafc383c023 Mon Sep 17 00:00:00 2001 From: Chauncey <32626585+Chaoyingz@users.noreply.github.com> Date: Fri, 18 Mar 2022 21:51:36 +0800 Subject: [PATCH] Optimize the pit collector script (#982) * Optimize the pit collector script * Add copyright notice to collector.py * Remove unnecessary parameters for test_pit.py * Update test_pit.py * Update test_pit.py --- scripts/data_collector/pit/README.md | 15 +- scripts/data_collector/pit/collector.py | 466 ++++++++++-------------- scripts/data_collector/pit/test_pit.py | 33 +- 3 files changed, 211 insertions(+), 303 deletions(-) diff --git a/scripts/data_collector/pit/README.md b/scripts/data_collector/pit/README.md index e18dcd0c17..f7b4f9fbe6 100644 --- a/scripts/data_collector/pit/README.md +++ b/scripts/data_collector/pit/README.md @@ -16,12 +16,18 @@ pip install -r requirements.txt ```bash cd qlib/scripts/data_collector/pit/ # download from baostock.com -python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly +python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly ``` Downloading all data from the stock is very time consuming. 
If you just want to run a quick test on a few stocks, you can run the command below -``` bash -python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*" +```bash +python collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*" +``` + + +### Normalize Data +```bash +python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized ``` @@ -30,6 +36,5 @@ python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --en ```bash cd qlib/scripts -# data_collector/pit/csv_pit is the data you download just now. -python dump_pit.py dump --csv_path data_collector/pit/csv_pit --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly +python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly ``` diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py index c1e811bbdf..385de0f180 100644 --- a/scripts/data_collector/pit/collector.py +++ b/scripts/data_collector/pit/collector.py @@ -2,71 +2,69 @@ # Licensed under the MIT License. 
import re -import sys -import datetime +from datetime import datetime from pathlib import Path +from typing import List, Iterable, Optional, Union import fire -import numpy as np import pandas as pd import baostock as bs from loguru import logger -CUR_DIR = Path(__file__).resolve().parent -sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.base import BaseCollector, BaseRun -from data_collector.utils import get_calendar_list, get_hs_stock_symbols +from scripts.data_collector.base import BaseCollector, BaseRun, BaseNormalize +from scripts.data_collector.utils import get_hs_stock_symbols, get_calendar_list +BASE_DIR = Path(__file__).resolve().parent.parent -class PitCollector(BaseCollector): - DEFAULT_START_DATETIME_QUARTER = pd.Timestamp("2000-01-01") +class PitCollector(BaseCollector): + DEFAULT_START_DATETIME_QUARTERLY = pd.Timestamp("2000-01-01") DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01") - DEFAULT_END_DATETIME_QUARTER = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + DEFAULT_END_DATETIME_QUARTERLY = pd.Timestamp(datetime.now() + pd.Timedelta(days=1)) + DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.now() + pd.Timedelta(days=1)) - INTERVAL_quarterly = "quarterly" - INTERVAL_annual = "annual" + INTERVAL_QUARTERLY = "quarterly" + INTERVAL_ANNUAL = "annual" def __init__( self, - save_dir: [str, Path], - start=None, - end=None, - interval="quarterly", - max_workers=1, - max_collector_count=1, - delay=0, + save_dir: Union[str, Path], + start: Optional[str] = None, + end: Optional[str] = None, + interval: str = "quarterly", + max_workers: int = 1, + max_collector_count: int = 1, + delay: int = 0, check_data_length: bool = False, - limit_nums: int = None, - symbol_flt_regx=None, + limit_nums: Optional[int] = None, + symbol_regex: Optional[str] = None, ): """ - Parameters ---------- save_dir: str - pit save dir - interval: str: - 
value from ['quarterly', 'annual'] + instrument save dir max_workers: int - workers, default 1 + workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 max_collector_count: int - default 1 + default 2 delay: float time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d start: str start datetime, default None end: str end datetime, default None + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. limit_nums: int using for debug, by default None + symbol_regex: str + symbol regular expression, by default None. """ - if symbol_flt_regx is None: - self.symbol_flt_regx = None - else: - self.symbol_flt_regx = re.compile(symbol_flt_regx) - super(PitCollector, self).__init__( + self.symbol_regex = symbol_regex + super().__init__( save_dir=save_dir, start=start, end=end, @@ -78,186 +76,119 @@ def __init__( limit_nums=limit_nums, ) - def normalize_symbol(self, symbol): - symbol_s = symbol.split(".") - symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" - return symbol - - def get_instrument_list(self): + def get_instrument_list(self) -> List[str]: logger.info("get cn stock symbols......") symbols = get_hs_stock_symbols() - logger.info(f"get {symbols[:10]}[{len(symbols)}] symbols.") - if self.symbol_flt_regx is not None: - s_flt = [] - for s in symbols: - m = self.symbol_flt_regx.match(s) - if m is not None: - s_flt.append(s) - logger.info(f"after filtering, it becomes {s_flt[:10]}[{len(s_flt)}] symbols") - return s_flt - + if self.symbol_regex is not None: + regex_compile = re.compile(self.symbol_regex) + symbols = [symbol for symbol in symbols if regex_compile.match(symbol)] + logger.info(f"get {len(symbols)} symbols.") 
return symbols - def _get_data_from_baostock(self, symbol, interval, start_datetime, end_datetime): - error_msg = f"{symbol}-{interval}-{start_datetime}-{end_datetime}" - - def _str_to_float(r): - try: - return float(r) - except Exception as e: - return np.nan - + def normalize_symbol(self, symbol: str) -> str: + symbol, exchange = symbol.split(".") + exchange = "sh" if exchange == "ss" else "sz" + return f"{exchange}{symbol}" + + @staticmethod + def get_performance_express_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + column_mapping = { + "performanceExpPubDate": "date", + "performanceExpStatDate": "period", + "performanceExpressROEWa": "value", + } + + resp = bs.query_performance_express_report(code=code, start_date=start_date, end_date=end_date) + report_list = [] + while (resp.error_code == "0") and resp.next(): + report_list.append(resp.get_row_data()) + report_df = pd.DataFrame(report_list, columns=resp.fields) try: - code, market = symbol.split(".") - market = {"ss": "sh"}.get(market, market) # baostock's API naming is different from default symbol list - symbol = f"{market}.{code}" - rs_report = bs.query_performance_express_report( - code=symbol, - start_date=str(start_datetime.date()), - end_date=str(end_datetime.date()), - ) - report_list = [] - while (rs_report.error_code == "0") & rs_report.next(): - report_list.append(rs_report.get_row_data()) - - df_report = pd.DataFrame(report_list, columns=rs_report.fields) - if { - "performanceExpPubDate", - "performanceExpStatDate", - "performanceExpressROEWa", - } <= set(rs_report.fields): - df_report = df_report[ - [ - "performanceExpPubDate", - "performanceExpStatDate", - "performanceExpressROEWa", - ] - ] - df_report.rename( - columns={ - "performanceExpPubDate": "date", - "performanceExpStatDate": "period", - "performanceExpressROEWa": "value", - }, - inplace=True, - ) - df_report["value"] = df_report["value"].apply(lambda r: _str_to_float(r) / 100.0) - df_report["field"] = "roeWa" - 
- profit_list = [] - for year in range(start_datetime.year - 1, end_datetime.year + 1): - for q_num in range(0, 4): - rs_profit = bs.query_profit_data(code=symbol, year=year, quarter=q_num + 1) - while (rs_profit.error_code == "0") & rs_profit.next(): - row_data = rs_profit.get_row_data() - if "pubDate" in rs_profit.fields: - pub_date = pd.Timestamp(row_data[rs_profit.fields.index("pubDate")]) - if pub_date >= start_datetime and pub_date <= end_datetime: - profit_list.append(row_data) - - df_profit = pd.DataFrame(profit_list, columns=rs_profit.fields) - if {"pubDate", "statDate", "roeAvg"} <= set(rs_profit.fields): - df_profit = df_profit[["pubDate", "statDate", "roeAvg"]] - df_profit.rename( - columns={ - "pubDate": "date", - "statDate": "period", - "roeAvg": "value", - }, - inplace=True, - ) - df_profit["value"] = df_profit["value"].apply(_str_to_float) - df_profit["field"] = "roeWa" - - forecast_list = [] - rs_forecast = bs.query_forecast_report( - code=symbol, - start_date=str(start_datetime.date()), - end_date=str(end_datetime.date()), - ) - - while (rs_forecast.error_code == "0") & rs_forecast.next(): - forecast_list.append(rs_forecast.get_row_data()) - - df_forecast = pd.DataFrame(forecast_list, columns=rs_forecast.fields) - if { - "profitForcastExpPubDate", - "profitForcastExpStatDate", - "profitForcastChgPctUp", - "profitForcastChgPctDwn", - } <= set(rs_forecast.fields): - df_forecast = df_forecast[ - [ - "profitForcastExpPubDate", - "profitForcastExpStatDate", - "profitForcastChgPctUp", - "profitForcastChgPctDwn", - ] - ] - df_forecast.rename( - columns={ - "profitForcastExpPubDate": "date", - "profitForcastExpStatDate": "period", - }, - inplace=True, - ) - - df_forecast["profitForcastChgPctUp"] = df_forecast["profitForcastChgPctUp"].apply(_str_to_float) - df_forecast["profitForcastChgPctDwn"] = df_forecast["profitForcastChgPctDwn"].apply(_str_to_float) - df_forecast["value"] = ( - df_forecast["profitForcastChgPctUp"] + 
df_forecast["profitForcastChgPctDwn"] - ) / 200 - df_forecast["field"] = "YOYNI" - df_forecast.drop( - ["profitForcastChgPctUp", "profitForcastChgPctDwn"], - axis=1, - inplace=True, - ) - - growth_list = [] - for year in range(start_datetime.year - 1, end_datetime.year + 1): - for q_num in range(0, 4): - rs_growth = bs.query_growth_data(code=symbol, year=year, quarter=q_num + 1) - while (rs_growth.error_code == "0") & rs_growth.next(): - row_data = rs_growth.get_row_data() - if "pubDate" in rs_growth.fields: - pub_date = pd.Timestamp(row_data[rs_growth.fields.index("pubDate")]) - if pub_date >= start_datetime and pub_date <= end_datetime: - growth_list.append(row_data) - - df_growth = pd.DataFrame(growth_list, columns=rs_growth.fields) - if {"pubDate", "statDate", "YOYNI"} <= set(rs_growth.fields): - df_growth = df_growth[["pubDate", "statDate", "YOYNI"]] - df_growth.rename( - columns={"pubDate": "date", "statDate": "period", "YOYNI": "value"}, - inplace=True, - ) - df_growth["value"] = df_growth["value"].apply(_str_to_float) - df_growth["field"] = "YOYNI" - df_merge = df_report.append([df_profit, df_forecast, df_growth]) - - return df_merge - except Exception as e: - logger.warning(f"{error_msg}:{e}") - - def _process_data(self, df, symbol, interval): - error_msg = f"{symbol}-{interval}" - - def _process_period(r): - _date = pd.Timestamp(r) - return _date.year if interval == self.INTERVAL_annual else _date.year * 100 + (_date.month - 1) // 3 + 1 - + report_df = report_df[list(column_mapping.keys())] + except KeyError: + return pd.DataFrame() + report_df.rename(columns=column_mapping, inplace=True) + report_df["field"] = "roeWa" + report_df["value"] = pd.to_numeric(report_df["value"], errors="ignore") + report_df["value"] = report_df["value"].apply(lambda x: x / 100.0) + return report_df + + @staticmethod + def get_profit_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + column_mapping = {"pubDate": "date", "statDate": "period", "roeAvg": "value"} + 
fields = bs.query_profit_data(code="sh.600519", year=2020, quarter=1).fields + start_date = datetime.strptime(start_date, "%Y-%m-%d") + end_date = datetime.strptime(end_date, "%Y-%m-%d") + args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)] + profit_list = [] + for year, quarter in args: + resp = bs.query_profit_data(code=code, year=year, quarter=quarter) + while (resp.error_code == "0") and resp.next(): + if "pubDate" not in resp.fields: + continue + row_data = resp.get_row_data() + pub_date = pd.Timestamp(row_data[resp.fields.index("pubDate")]) + if start_date <= pub_date <= end_date and row_data: + profit_list.append(row_data) + profit_df = pd.DataFrame(profit_list, columns=fields) + try: + profit_df = profit_df[list(column_mapping.keys())] + except KeyError: + return pd.DataFrame() + profit_df.rename(columns=column_mapping, inplace=True) + profit_df["field"] = "roeWa" + profit_df["value"] = pd.to_numeric(profit_df["value"], errors="ignore") + return profit_df + + @staticmethod + def get_forecast_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + column_mapping = { + "profitForcastExpPubDate": "date", + "profitForcastExpStatDate": "period", + "value": "value", + } + resp = bs.query_forecast_report(code=code, start_date=start_date, end_date=end_date) + forecast_list = [] + while (resp.error_code == "0") and resp.next(): + forecast_list.append(resp.get_row_data()) + forecast_df = pd.DataFrame(forecast_list, columns=resp.fields) + numeric_fields = ["profitForcastChgPctUp", "profitForcastChgPctDwn"] try: - _date = df["period"].apply( - lambda x: ( - pd.to_datetime(x) + pd.DateOffset(days=(45 if interval == self.INTERVAL_quarterly else 90)) - ).date() - ) - df["date"] = df["date"].fillna(_date.astype(str)) - df["period"] = df["period"].apply(_process_period) - return df - except Exception as e: - logger.warning(f"{error_msg}:{e}") + forecast_df[numeric_fields] = 
forecast_df[numeric_fields].apply(pd.to_numeric, errors="ignore") + except KeyError: + return pd.DataFrame() + forecast_df["value"] = (forecast_df["profitForcastChgPctUp"] + forecast_df["profitForcastChgPctDwn"]) / 200 + forecast_df = forecast_df[list(column_mapping.keys())] + forecast_df.rename(columns=column_mapping, inplace=True) + forecast_df["field"] = "YOYNI" + return forecast_df + + @staticmethod + def get_growth_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + column_mapping = {"pubDate": "date", "statDate": "period", "YOYNI": "value"} + fields = bs.query_growth_data(code="sh.600519", year=2020, quarter=1).fields + start_date = datetime.strptime(start_date, "%Y-%m-%d") + end_date = datetime.strptime(end_date, "%Y-%m-%d") + args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)] + growth_list = [] + for year, quarter in args: + resp = bs.query_growth_data(code=code, year=year, quarter=quarter) + while (resp.error_code == "0") and resp.next(): + if "pubDate" not in resp.fields: + continue + row_data = resp.get_row_data() + pub_date = pd.Timestamp(row_data[resp.fields.index("pubDate")]) + if start_date <= pub_date <= end_date and row_data: + growth_list.append(row_data) + growth_df = pd.DataFrame(growth_list, columns=fields) + try: + growth_df = growth_df[list(column_mapping.keys())] + except KeyError: + return pd.DataFrame() + growth_df.rename(columns=column_mapping, inplace=True) + growth_df["field"] = "YOYNI" + growth_df["value"] = pd.to_numeric(growth_df["value"], errors="ignore") + return growth_df def get_data( self, @@ -265,91 +196,62 @@ def get_data( interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp, - ) -> [pd.DataFrame]: - - if interval == self.INTERVAL_quarterly: - _result = self._get_data_from_baostock(symbol, interval, start_datetime, end_datetime) - if _result is None or _result.empty: - return _result - else: - return self._process_data(_result, symbol, 
interval) - else: + ) -> pd.DataFrame: + if interval != self.INTERVAL_QUARTERLY: raise ValueError(f"cannot support {interval}") - return self._process_data(_result, interval) - - @property - def min_numbers_trading(self): - pass - + symbol, exchange = symbol.split(".") + exchange = "sh" if exchange == "ss" else "sz" + code = f"{exchange}.{symbol}" + start_date = start_datetime.strftime("%Y-%m-%d") + end_date = end_datetime.strftime("%Y-%m-%d") + + performance_express_report_df = self.get_performance_express_report_df(code, start_date, end_date) + profit_df = self.get_profit_df(code, start_date, end_date) + forecast_report_df = self.get_forecast_report_df(code, start_date, end_date) + growth_df = self.get_growth_df(code, start_date, end_date) + + df = pd.concat( + [performance_express_report_df, profit_df, forecast_report_df, growth_df], + axis=0, + ) + return df -class Run(BaseRun): - def __init__(self, source_dir=None, max_workers=1, interval="quarterly"): - """ - Parameters - ---------- - source_dir: str - The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - max_workers: int - Concurrent number, default is 4 - interval: str - freq, value from [quarterly, annual], default 1d - """ - super().__init__(source_dir=source_dir, max_workers=max_workers, interval=interval) +class PitNormalize(BaseNormalize): + def __init__(self, interval: str = "quarterly", *args, **kwargs): + super().__init__(*args, **kwargs) + self.interval = interval - @property - def collector_class_name(self): - return "PitCollector" + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + dt = df["period"].apply( + lambda x: ( + pd.to_datetime(x) + pd.DateOffset(days=(45 if self.interval == PitCollector.INTERVAL_QUARTERLY else 90)) + ).date() + ) + df["date"] = df["date"].fillna(dt.astype(str)) - @property - def default_base_dir(self) -> [Path, str]: - return CUR_DIR + df["period"] = pd.to_datetime(df["period"]) + df["period"] = 
df["period"].apply( + lambda x: x.year if self.interval == PitCollector.INTERVAL_ANNUAL else x.year * 100 + (x.month - 1) // 3 + 1 + ) + return df - def download_data( - self, - max_collector_count=1, - delay=0, - start=None, - end=None, - check_data_length=False, - limit_nums=None, - **kwargs, - ): - """download data from Internet + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: + return get_calendar_list() - Parameters - ---------- - max_collector_count: int - default 2 - delay: float - time.sleep(delay), default 0 - start: str - start datetime, default "2000-01-01" - end: str - end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` - check_data_length: bool # if this param useful? - check data length, by default False - limit_nums: int - using for debug, by default None - Examples - --------- - # get quarterly data - $ python collector.py download_data --source_dir ~/.qlib/cn_data/source/pit_quarter --start 2000-01-01 --end 2021-01-01 --interval quarterly - """ +class Run(BaseRun): + @property + def collector_class_name(self) -> str: + return f"PitCollector" - super(Run, self).download_data( - max_collector_count, - delay, - start, - end, - check_data_length, - limit_nums, - **kwargs, - ) + @property + def normalize_class_name(self) -> str: + return f"PitNormalize" - def normalize_class_name(self): - pass + @property + def default_base_dir(self) -> [Path, str]: + return BASE_DIR if __name__ == "__main__": diff --git a/scripts/data_collector/pit/test_pit.py b/scripts/data_collector/pit/test_pit.py index fa456670b0..4dedd85cf0 100644 --- a/scripts/data_collector/pit/test_pit.py +++ b/scripts/data_collector/pit/test_pit.py @@ -1,20 +1,28 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import pandas as pd + import qlib from qlib.data import D import unittest +pd.set_option("display.width", 1000) +pd.set_option("display.max_columns", None) + class TestPIT(unittest.TestCase): """ NOTE!!!!!! 
The assert of this test assumes that users follows the cmd below and only download 2 stock. - `python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"` + 1. `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn` + 2. `python scripts/data_collector/pit/collector.py download_data --source_dir ~/.qlib/stock_data/source/pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_regex "^(600519|000725).*"` + 3. `python scripts/data_collector/pit/collector.py normalize_data --interval quarterly --source_dir ~/.qlib/stock_data/source/pit --normalize_dir ~/.qlib/stock_data/source/pit_normalized` + 4. `python scripts/dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly` """ def setUp(self): # qlib.init(kernels=1) # NOTE: set kernel to 1 to make it debug easier - qlib.init() # NOTE: set kernel to 1 to make it debug easier + qlib.init() def to_str(self, obj): return "".join(str(obj).split()) @@ -27,10 +35,7 @@ def test_query(self): fields = ["P($$roewa_q)", "P($$yoyni_q)"] # Mao Tai published 2019Q2 report at 2019-07-13 & 2019-07-18 # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index - data = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day") - - print(data) - + data = D.features(instruments, fields, start_time="2019-01-01", end_time="2019-07-19", freq="day") res = """ P($$roewa_q) P($$yoyni_q) count 133.000000 133.000000 @@ -57,12 +62,11 @@ def test_query(self): def test_no_exist_data(self): fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"] - data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="20190719", freq="day") + data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="2019-07-19", freq="day") data["$close"] 
= 1 # in case of different dataset gives different values - print(data) expect = """ P($$roewa_q) P($$yoyni_q) $close - instrument datetime + instrument datetime sh600519 2019-01-02 0.25522 0.243892 1 2019-01-03 0.25522 0.243892 1 2019-01-04 0.25522 0.243892 1 @@ -74,7 +78,7 @@ def test_no_exist_data(self): 2019-07-17 NaN NaN 1 2019-07-18 NaN NaN 1 2019-07-19 NaN NaN 1 - + [266 rows x 3 columns] """ self.check_same(data, expect) @@ -115,12 +119,12 @@ def test_unlimit(self): fields = ["P($$roewa_q)"] instruments = ["sh600519"] _ = D.features(instruments, fields, freq="day") # this should not raise error - data = D.features(instruments, fields, end_time="20200101", freq="day") # this should not raise error + data = D.features(instruments, fields, end_time="2020-01-01", freq="day") # this should not raise error s = data.iloc[:, 0] # You can check the expected value based on the content in `docs/advanced/PIT.rst` expect = """ instrument datetime - sh600519 1999-11-10 NaN + sh600519 2005-01-04 NaN 2007-04-30 0.090219 2007-08-17 0.139330 2007-10-23 0.245863 @@ -156,7 +160,7 @@ def test_unlimit(self): 2014-10-30 0.234085 2015-04-21 0.078494 2015-08-28 0.137504 - 2015-10-26 0.201709 + 2015-10-23 0.201709 2016-03-24 0.264205 2016-04-21 0.073664 2016-08-29 0.136576 @@ -176,7 +180,6 @@ def test_unlimit(self): 2019-10-16 0.255819 Name: P($$roewa_q), dtype: float32 """ - self.check_same(s[~s.duplicated().values], expect) def test_expr2(self): @@ -186,8 +189,6 @@ def test_expr2(self): fields += ["P(Sum($$yoyni_q, 4))"] fields += ["$close", "P($$roewa_q) * $close"] data = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day") - print(data) - print(data.describe()) if __name__ == "__main__":