From ffabeea641659b19ee34f81b2f99b562459e7c7b Mon Sep 17 00:00:00 2001 From: edtechre Date: Thu, 14 Dec 2023 17:41:43 -0800 Subject: [PATCH] Add optional portfolio argument to backtest/walkforward. Allows overriding Portfolio that is used for backtests. --- src/pybroker/strategy.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/pybroker/strategy.py b/src/pybroker/strategy.py index b6076b1..b39b067 100644 --- a/src/pybroker/strategy.py +++ b/src/pybroker/strategy.py @@ -1022,6 +1022,7 @@ def backtest( calc_bootstrap: bool = False, disable_parallel: bool = False, warmup: Optional[int] = None, + portfolio: Optional[Portfolio] = None, ) -> TestResult: """Backtests the trading strategy by running executions that were added with :meth:`.add_execution`. @@ -1070,6 +1071,8 @@ def backtest( Defaults to ``False``. warmup: Number of bars that need to pass before running the executions. + portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for + backtests. Returns: :class:`.TestResult` containing portfolio balances, order @@ -1088,6 +1091,7 @@ def backtest( calc_bootstrap=calc_bootstrap, disable_parallel=disable_parallel, warmup=warmup, + portfolio=portfolio, ) def walkforward( @@ -1104,6 +1108,7 @@ def walkforward( calc_bootstrap: bool = False, disable_parallel: bool = False, warmup: Optional[int] = None, + portfolio: Optional[Portfolio] = None, ) -> TestResult: """Backtests the trading strategy using `Walkforward Analysis `_. @@ -1158,6 +1163,8 @@ def walkforward( Defaults to ``False``. warmup: Number of bars that need to pass before running the executions. + portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for + backtests. Returns: :class:`.TestResult` containing portfolio balances, order @@ -1217,14 +1224,15 @@ def walkforward( and self._after_exec_fn is None and all(map(lambda e: e.fn is None, self._executions)) ) - portfolio = Portfolio( - self._config.initial_cash, - self._config.fee_mode, - self._config.fee_amount, - self._fractional_shares_enabled(), - self._config.max_long_positions, - self._config.max_short_positions, - ) + if portfolio is None: + portfolio = Portfolio( + self._config.initial_cash, + self._config.fee_mode, + self._config.fee_amount, + self._fractional_shares_enabled(), + self._config.max_long_positions, + self._config.max_short_positions, + ) signals = self._run_walkforward( portfolio=portfolio, df=df,