Skip to content

Commit

Permalink
Add backtest example to online simulation (microsoft#984)
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g authored Mar 18, 2022
1 parent 8f93065 commit f7b2b63
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
9 changes: 3 additions & 6 deletions docs/component/strategy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,9 @@ Running backtest
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
)
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"], freq=analysis_freq
)
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
)
# default frequency will be daily (i.e. "day")
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
analysis_df = pd.concat(analysis) # type: pd.DataFrame
pprint(analysis_df)
Expand Down
43 changes: 39 additions & 4 deletions examples/online_srv/online_management_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.

"""
This example is about how can simulate the OnlineManager based on rolling tasks.
This example is about how can simulate the OnlineManager based on rolling tasks.
"""

from pprint import pprint
Expand All @@ -15,6 +15,10 @@
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
import pandas as pd
from qlib.contrib.evaluate import backtest_daily
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy


class OnlineSimulationExample:
Expand All @@ -30,6 +34,7 @@ def __init__(
start_time="2018-09-10",
end_time="2018-10-31",
tasks=None,
trainer="TrainerR",
):
"""
Init OnlineManagerExample.
Expand Down Expand Up @@ -60,7 +65,13 @@ def __init__(
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
if trainer == "TrainerRM":
self.trainer = TrainerRM(self.exp_name, self.task_pool)
elif trainer == "TrainerR":
self.trainer = TrainerR(self.exp_name)
else:
# TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR
raise NotImplementedError(f"This type of input is not supported")
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
Expand All @@ -70,7 +81,8 @@ def __init__(

# Reset all things to the first status, be careful to save important data
def reset(self):
TaskManager(self.task_pool).remove()
if isinstance(self.trainer, TrainerRM):
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
Expand All @@ -84,7 +96,30 @@ def main(self):
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
signals = self.rolling_online_manager.get_signals()
print(signals)
# Backtesting
# - the code is based on this example https://qlib.readthedocs.io/en/latest/component/strategy.html
CSI300_BENCH = "SH000903"
STRATEGY_CONFIG = {
"topk": 30,
"n_drop": 3,
"signal": signals.to_frame("score"),
}
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = backtest_daily(
start_time=signals.index.get_level_values("datetime").min(),
end_time=signals.index.get_level_values("datetime").max(),
strategy=strategy_obj,
)
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)

analysis_df = pd.concat(analysis) # type: pd.DataFrame
pprint(analysis_df)

def worker(self):
# train tasks by other progress or machines for multiprocessing
Expand Down
1 change: 1 addition & 0 deletions qlib/contrib/model/gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def fit(
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)
# NOTE: if you encounter error here. Please upgrade your lightgbm
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(
Expand Down

0 comments on commit f7b2b63

Please sign in to comment.