diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py index c34b31348d..b594721006 100644 --- a/scripts/data_collector/pit/collector.py +++ b/scripts/data_collector/pit/collector.py @@ -88,8 +88,12 @@ def get_instrument_list(self) -> List[str]: return symbols def normalize_symbol(self, symbol: str) -> str: - symbol, exchange = symbol.split(".") - exchange = "sh" if exchange == "ss" else "sz" + if symbol.startswith('6'): + exchange = 'sh' + elif symbol.startswith('0') or symbol.startswith('3'): + exchange = 'sz' + else: + exchange = 'bj' return f"{exchange}{symbol}" @staticmethod @@ -201,8 +205,13 @@ def get_data( ) -> pd.DataFrame: if interval != self.INTERVAL_QUARTERLY: raise ValueError(f"cannot support {interval}") - symbol, exchange = symbol.split(".") - exchange = "sh" if exchange == "ss" else "sz" + if symbol.startswith('6'): + exchange = 'sh' + elif symbol.startswith('0') or symbol.startswith('3'): + exchange = 'sz' + else: + exchange = 'bj' + code = f"{exchange}.{symbol}" start_date = start_datetime.strftime("%Y-%m-%d") end_date = end_datetime.strftime("%Y-%m-%d") diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 596eae60ef..4e0b140254 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -190,17 +190,21 @@ def get_hs_stock_symbols() -> list: global _HS_SYMBOLS # pylint: disable=W0603 def _get_symbol(): - _res = set() - for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")): - resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None) - _res |= set( - map( - lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640 - etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101 - ) - ) - time.sleep(3) - return _res + url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12" + resp = requests.get(url, timeout=None) + if resp.status_code != 200: + raise ValueError("request error") + + try: + _symbols = [_v["f12"] for _v in resp.json()["data"]["diff"]] + except Exception as e: + logger.warning(f"request error: {e}") + raise + + if len(_symbols) < 3900: + raise ValueError("request error") + + return set(_symbols) if _HS_SYMBOLS is None: symbols = set()