diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index affb115a10..ebafd6a521 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -324,7 +324,6 @@ def predict(self, dataset, segment="test"): class LSTM(nn.Module): - """LSTM Model Args: @@ -414,7 +413,6 @@ def forward(self, x): class Transformer(nn.Module): - """Transformer Model Args: @@ -475,7 +473,6 @@ def forward(self, x): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/examples/orderbook_data/README.md b/examples/orderbook_data/README.md index 059ee27056..890e11f41e 100644 --- a/examples/orderbook_data/README.md +++ b/examples/orderbook_data/README.md @@ -27,13 +27,11 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency 2. Please follow following steps to download example data ```bash cd examples/orderbook_data/ -wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2 -tar xf highfreq_orderboook_example_data.tar.bz2 +python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip ``` 3. Please import the example data to your mongo db ```bash -cd examples/orderbook_data/ python create_dataset.py initialize_library # Initialization Libraries python create_dataset.py import_data # Initialization Libraries ``` @@ -42,7 +40,6 @@ python create_dataset.py import_data # Initialization Libraries After importing these data, you run `example.py` to create some high-frequency features. ```bash -cd examples/orderbook_data/ pytest -s --disable-warnings example.py # If you want run all examples pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example ``` diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index d784aed57e..9daba91153 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -162,13 +162,15 @@ def create_account_instance( init_cash=init_cash, position_dict=position_dict, pos_type=pos_type, - benchmark_config={} - if benchmark is None - else { - "benchmark": benchmark, - "start_time": start_time, - "end_time": end_time, - }, + benchmark_config=( + {} + if benchmark is None + else { + "benchmark": benchmark, + "start_time": start_time, + "end_time": end_time, + } + ), ) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 8e7440ba9e..e7c6041efd 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -622,9 +622,11 @@ def cal_trade_indicators( print( "[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format( freq, - trade_start_time - if isinstance(trade_start_time, str) - else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"), + ( + trade_start_time + if isinstance(trade_start_time, str) + else trade_start_time.strftime("%Y-%m-%d %H:%M:%S") + ), fulfill_rate, price_advantage, positive_rate, diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 95ec9b91e9..86d366d205 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -3,6 +3,7 @@ The interface should be redesigned carefully in the future. """ + import pandas as pd from typing import Tuple from qlib import get_module_logger diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index 964febf11c..bc9a6aa977 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -511,7 +511,6 @@ def predict(self, dataset, segment="test"): class RNN(nn.Module): - """RNN Model Args: @@ -601,7 +600,6 @@ def forward(self, x): class Transformer(nn.Module): - """Transformer Model Args: @@ -649,7 +647,6 @@ def forward(self, x): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index 9ba960eebd..bad19ddfdc 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -373,7 +373,6 @@ def generate_trade_decision(self, execute_result=None): class EnhancedIndexingStrategy(WeightStrategyBase): - """Enhanced Indexing Strategy Enhanced indexing combines the arts of active management and passive management, diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index ede1f8e3ad..1ebb16f18b 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -30,7 +30,6 @@ def __call__(self, ensemble_dict: dict, *args, **kwargs): class SingleKeyEnsemble(Ensemble): - """ Extract the object if there is only one key and value in the dict. Make the result more readable. {Only key: Only value} -> Only value @@ -64,7 +63,6 @@ def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) - class RollingEnsemble(Ensemble): - """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index b2594f707d..c3c0e48ef8 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -247,9 +247,7 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np v1 = y.T.dot(z) / t - cov_mkt[:, None] * S roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt v3 = z.T.dot(z) / t - var_mkt * S - roff3 = ( - np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 - ) + roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index f2988d843f..d545e4bc9a 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -90,7 +90,6 @@ def get_collector(self) -> Collector: class RollingStrategy(OnlineStrategy): - """ This example strategy always uses the latest rolling model sas online models. """ diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 92abc8beec..a65b1f58ee 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -146,9 +146,7 @@ def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: return ( self._include_fields if self._include_fields - else set(df_columns) - set(self._exclude_fields) - if self._exclude_fields - else df_columns + else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns ) @staticmethod diff --git a/scripts/dump_pit.py b/scripts/dump_pit.py index 34d304ed78..1ca9cfc942 100644 --- a/scripts/dump_pit.py +++ b/scripts/dump_pit.py @@ -132,9 +132,11 @@ def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]: return ( set(self._include_fields) if self._include_fields - else set(df[self.field_column_name]) - set(self._exclude_fields) - if self._exclude_fields - else set(df[self.field_column_name]) + else ( + set(df[self.field_column_name]) - set(self._exclude_fields) + if self._exclude_fields + else set(df[self.field_column_name]) + ) ) def get_filenames(self, symbol, field, interval): diff --git a/setup.py b/setup.py index 508fd8c3a4..adafefd614 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,8 @@ def get_version(rel_path: str) -> str: # To ensure stable operation of the experiment manager, we have limited the version of mlflow, # and we need to verify whether version 2.0 of mlflow can serve qlib properly. "mlflow>=1.12.1, <=1.30.0", + # mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail. + "packaging<22", "tqdm", "loguru", "lightgbm>=3.3.0", diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 129abc0fbb..cf17b3d18a 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -9,7 +9,9 @@ class WorkflowTest(TestAutoData): - TMP_PATH = Path("./.mlruns_tmp/") + # Creating the directory manually doesn't work with mlflow, + # so we add a subfolder named .trash when we create the directory. + TMP_PATH = Path("./.mlruns_tmp/.trash") def tearDown(self) -> None: if self.TMP_PATH.exists(): @@ -17,6 +19,8 @@ def tearDown(self) -> None: def test_get_local_dir(self): """ """ + self.TMP_PATH.mkdir(parents=True, exist_ok=True) + with R.start(uri=str(self.TMP_PATH)): pass