Skip to content

Commit

Permalink
optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
SunsetWolf committed Oct 10, 2023
1 parent 065479e commit bbf47df
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
2 changes: 2 additions & 0 deletions scripts/data_collector/baostock_5min/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
"""
bs.login()
qlib.init(provider_uri=qlib_data_1d_dir)
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
self.qlib_data_1d_dir = qlib_data_1d_dir
super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)

Expand Down Expand Up @@ -257,6 +258,7 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
_date_field_name=self._date_field_name,
_symbol_field_name=self._symbol_field_name,
frequence="5min",
_1d_data_all=self.all_1d_data,
)
return df

Expand Down
22 changes: 15 additions & 7 deletions scripts/data_collector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import re
import copy
import importlib
import time
import bisect
Expand All @@ -21,7 +22,6 @@
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from bs4 import BeautifulSoup
from qlib.data import D

HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"

Expand Down Expand Up @@ -606,17 +606,22 @@ def get_instruments(
getattr(obj, method)()


def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str):
# qlib.init(provider_uri=qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):
df = copy.deepcopy(_1d_data_all)
df.reset_index(inplace=True)
df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df


def get_1d_data(
qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, symbol: str, start: str, end: str
qlib_data_1d_dir: str,
_date_field_name: str,
_symbol_field_name: str,
symbol: str,
start: str,
end: str,
_1d_data_all: pd.DataFrame,
) -> pd.DataFrame:
"""get 1d data
Expand All @@ -626,7 +631,7 @@ def get_1d_data(
data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"]
"""
_all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name)
_all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, _1d_data_all)
return _all_1d_data[
(_all_1d_data[_symbol_field_name] == symbol.upper())
& (_all_1d_data[_date_field_name] >= pd.Timestamp(start))
Expand All @@ -636,6 +641,7 @@ def get_1d_data(

def calc_adjusted_price(
df: pd.DataFrame,
_1d_data_all: pd.DataFrame,
qlib_data_1d_dir: str,
_date_field_name: str,
_symbol_field_name: str,
Expand All @@ -654,7 +660,9 @@ def calc_adjusted_price(
# get 1d data from qlib
_start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d")
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
data_1d: pd.DataFrame = get_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end)
data_1d: pd.DataFrame = get_1d_data(
qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all
)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
Expand Down
3 changes: 2 additions & 1 deletion scripts/data_collector/yahoo/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ def __init__(
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
self.qlib_data_1d_dir = qlib_data_1d_dir
qlib.init(provider_uri=self.qlib_data_1d_dir)
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")

def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
return list(D.calendar(freq="day"))
Expand All @@ -604,6 +605,7 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
frequence="1min",
consistent_1d=self.CONSISTENT_1d,
calc_paused=self.CALC_PAUSED_NUM,
_1d_data_all=self.all_1d_data,
)
return df

Expand Down Expand Up @@ -959,7 +961,6 @@ def update_data_to_bin(
Examples
-------
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
# get 1m data
"""

if self.interval.lower() != "1d":
Expand Down

0 comments on commit bbf47df

Please sign in to comment.