Skip to content

Commit 05d67b3

Browse files
authored
Add multi pass portfolio analysis record (#1546)
* Add multi pass port ana record * Add list function * Add documentation and support <MODEL> tag * Add drop in replacement example * reformat * Change according to comments * update format * Update record_temp.py Fix type hint * Update record_temp.py
1 parent 38edac5 commit 05d67b3

File tree

2 files changed

+221
-10
lines changed

2 files changed

+221
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: RobustZScoreNorm
14+
kwargs:
15+
fields_group: feature
16+
clip_outlier: true
17+
- class: Fillna
18+
kwargs:
19+
fields_group: feature
20+
learn_processors:
21+
- class: DropnaLabel
22+
- class: CSRankNorm
23+
kwargs:
24+
fields_group: label
25+
port_analysis_config: &port_analysis_config
26+
strategy:
27+
class: TopkDropoutStrategy
28+
module_path: qlib.contrib.strategy
29+
kwargs:
30+
signal:
31+
- <MODEL>
32+
- <DATASET>
33+
topk: 50
34+
n_drop: 5
35+
backtest:
36+
start_time: 2017-01-01
37+
end_time: 2020-08-01
38+
account: 100000000
39+
benchmark: *benchmark
40+
exchange_kwargs:
41+
limit_threshold: 0.095
42+
deal_price: close
43+
open_cost: 0.0005
44+
close_cost: 0.0015
45+
min_cost: 5
46+
task:
47+
model:
48+
class: LinearModel
49+
module_path: qlib.contrib.model.linear
50+
kwargs:
51+
estimator: ols
52+
dataset:
53+
class: DatasetH
54+
module_path: qlib.data.dataset
55+
kwargs:
56+
handler:
57+
class: Alpha158
58+
module_path: qlib.contrib.data.handler
59+
kwargs: *data_handler_config
60+
segments:
61+
train: [2008-01-01, 2014-12-31]
62+
valid: [2015-01-01, 2016-12-31]
63+
test: [2017-01-01, 2020-08-01]
64+
record:
65+
- class: SignalRecord
66+
module_path: qlib.workflow.record_temp
67+
kwargs:
68+
model: <MODEL>
69+
dataset: <DATASET>
70+
- class: SigAnaRecord
71+
module_path: qlib.workflow.record_temp
72+
kwargs:
73+
ana_long_short: True
74+
ann_scaler: 252
75+
- class: MultiPassPortAnaRecord
76+
module_path: qlib.workflow.record_temp
77+
kwargs:
78+
config: *port_analysis_config

qlib/workflow/record_temp.py

+143-10
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import logging
55
import warnings
66
import pandas as pd
7+
import numpy as np
8+
from tqdm import trange
79
from pprint import pprint
8-
from typing import Union, List, Optional
10+
from typing import Union, List, Optional, Dict
911

1012
from qlib.utils.exceptions import LoadObjectError
1113
from ..contrib.evaluate import risk_analysis, indicator_analysis
@@ -17,6 +19,7 @@
1719
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
1820
from ..utils.time import Freq
1921
from ..utils.data import deepcopy_basic_type
22+
from ..utils.exceptions import QlibException
2023
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
2124

2225

@@ -230,9 +233,16 @@ def generate(self, *args, **kwargs):
230233
except FileNotFoundError:
231234
logger.warning("The dependent data does not exists. Generation skipped.")
232235
return
233-
return self._generate(*args, **kwargs)
236+
artifact_dict = self._generate(*args, **kwargs)
237+
if isinstance(artifact_dict, dict):
238+
self.save(**artifact_dict)
239+
return artifact_dict
234240

235-
def _generate(self, *args, **kwargs):
241+
def _generate(self, *args, **kwargs) -> Dict[str, object]:
242+
"""
243+
Run the concrete generating task, return the dictionary of the generated results.
244+
The caller method will save the results to the recorder.
245+
"""
236246
raise NotImplementedError(f"Please implement the `_generate` method")
237247

238248

@@ -336,8 +346,8 @@ def _generate(self, label: Optional[pd.DataFrame] = None, **kwargs):
336346
}
337347
)
338348
self.recorder.log_metrics(**metrics)
339-
self.save(**objects)
340349
pprint(metrics)
350+
return objects
341351

342352
def list(self):
343353
paths = ["ic.pkl", "ric.pkl"]
@@ -468,17 +478,18 @@ def _generate(self, **kwargs):
468478
if self.backtest_config["end_time"] is None:
469479
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
470480

481+
artifact_objects = {}
471482
# custom strategy and get backtest
472483
portfolio_metric_dict, indicator_dict = normal_backtest(
473484
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
474485
)
475486
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
476-
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
477-
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
487+
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
488+
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
478489

479490
for _freq, indicators_normal in indicator_dict.items():
480-
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
481-
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
491+
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
492+
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
482493

483494
for _analysis_freq in self.risk_analysis_freq:
484495
if _analysis_freq not in portfolio_metric_dict:
@@ -500,7 +511,7 @@ def _generate(self, **kwargs):
500511
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
501512
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
502513
# save results
503-
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
514+
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
504515
logger.info(
505516
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
506517
)
@@ -525,12 +536,13 @@ def _generate(self, **kwargs):
525536
analysis_dict = analysis_df["value"].to_dict()
526537
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
527538
# save results
528-
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
539+
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
529540
logger.info(
530541
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
531542
)
532543
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
533544
pprint(analysis_df)
545+
return artifact_objects
534546

535547
def list(self):
536548
list_path = []
@@ -553,3 +565,124 @@ def list(self):
553565
else:
554566
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
555567
return list_path
568+
569+
570+
class MultiPassPortAnaRecord(PortAnaRecord):
571+
"""
572+
This is the Multiple Pass Portfolio Analysis Record class that runs the backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
573+
574+
If shuffle_init_score is enabled, the prediction score of the first backtest date will be shuffled, so that the initial position will be random.
575+
The shuffle_init_score will only work when the signal is used as the <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in the recorder.
576+
577+
Parameters
578+
----------
579+
recorder : Recorder
580+
The recorder used to save the backtest results.
581+
pass_num : int
582+
The number of backtest passes.
583+
shuffle_init_score : bool
584+
Whether to shuffle the prediction score of the first backtest date.
585+
"""
586+
587+
depend_cls = SignalRecord
588+
589+
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
590+
"""
591+
Parameters
592+
----------
593+
recorder : Recorder
594+
The recorder used to save the backtest results.
595+
pass_num : int
596+
The number of backtest passes.
597+
shuffle_init_score : bool
598+
Whether to shuffle the prediction score of the first backtest date.
599+
"""
600+
self.pass_num = pass_num
601+
self.shuffle_init_score = shuffle_init_score
602+
603+
super().__init__(recorder, **kwargs)
604+
605+
# Save original strategy so that pred df can be replaced in next generate
606+
self.original_strategy = deepcopy_basic_type(self.strategy_config)
607+
if not isinstance(self.original_strategy, dict):
608+
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
609+
if "signal" not in self.original_strategy.get("kwargs", {}):
610+
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
611+
612+
def random_init(self):
613+
pred_df = self.load("pred.pkl")
614+
615+
all_pred_dates = pred_df.index.get_level_values("datetime")
616+
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
617+
if bt_start_date is None:
618+
first_bt_pred_date = all_pred_dates.min()
619+
else:
620+
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
621+
622+
# Shuffle the first backtest date's pred score
623+
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
624+
np.random.shuffle(first_date_score.values)
625+
626+
# Use shuffled signal as the strategy signal
627+
self.strategy_config = deepcopy_basic_type(self.original_strategy)
628+
self.strategy_config["kwargs"]["signal"] = pred_df
629+
630+
def _generate(self, **kwargs):
631+
risk_analysis_df_map = {}
632+
633+
# Collect each frequency's analysis df as df list
634+
for i in trange(self.pass_num):
635+
if self.shuffle_init_score:
636+
self.random_init()
637+
638+
# Not check for cache file list
639+
single_run_artifacts = super()._generate(**kwargs)
640+
641+
for _analysis_freq in self.risk_analysis_freq:
642+
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
643+
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
644+
645+
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
646+
analysis_df["run_id"] = i
647+
risk_analysis_df_list.append(analysis_df)
648+
649+
result_artifacts = {}
650+
# Concat df list
651+
for _analysis_freq in self.risk_analysis_freq:
652+
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
653+
654+
# Calculate return and information ratio's mean, std and mean/std
655+
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
656+
lambda x: pd.Series(
657+
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
658+
)
659+
)
660+
661+
# Only look at "annualized_return" and "information_ratio"
662+
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
663+
(slice(None), ["annualized_return", "information_ratio"]), :
664+
]
665+
pprint(multi_pass_port_analysis_df)
666+
667+
# Save new df
668+
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
669+
670+
# Log metrics
671+
metrics = flatten_dict(
672+
{
673+
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
674+
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
675+
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
676+
}
677+
)
678+
self.recorder.log_metrics(**metrics)
679+
return result_artifacts
680+
681+
def list(self):
682+
list_path = []
683+
for _analysis_freq in self.risk_analysis_freq:
684+
if _analysis_freq in self.all_freq:
685+
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
686+
else:
687+
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
688+
return list_path

0 commit comments

Comments
 (0)