diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py index b594721006..caabbde380 100644 --- a/scripts/data_collector/pit/collector.py +++ b/scripts/data_collector/pit/collector.py @@ -88,12 +88,12 @@ def get_instrument_list(self) -> List[str]: return symbols def normalize_symbol(self, symbol: str) -> str: - if symbol.startswith('6'): - exchange = 'sh' - elif symbol.startswith('0') or symbol.startswith('3'): - exchange = 'sz' + if symbol.startswith("6"): + exchange = "sh" + elif symbol.startswith("0") or symbol.startswith("3"): + exchange = "sz" else: - exchange = 'bj' + exchange = "bj" return f"{exchange}{symbol}" @staticmethod @@ -205,12 +205,12 @@ def get_data( ) -> pd.DataFrame: if interval != self.INTERVAL_QUARTERLY: raise ValueError(f"cannot support {interval}") - if symbol.startswith('6'): - exchange = 'sh' - elif symbol.startswith('0') or symbol.startswith('3'): - exchange = 'sz' + if symbol.startswith("6"): + exchange = "sh" + elif symbol.startswith("0") or symbol.startswith("3"): + exchange = "sz" else: - exchange = 'bj' + exchange = "bj" code = f"{exchange}.{symbol}" start_date = start_datetime.strftime("%Y-%m-%d") diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index b642b4404a..1962281918 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -204,7 +204,10 @@ def _get_symbol(): if len(_symbols) < 3900: raise ValueError("request error") - _symbols = [_symbol + '.ss' if _symbol.startswith('6') else _symbol + '.sz' if _symbol.startswith(('0', '3')) else None for _symbol in _symbols] + _symbols = [ + _symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None + for _symbol in _symbols + ] _symbols = [_symbol for _symbol in _symbols if _symbol is not None] return set(_symbols)