Skip to content

Commit

Permalink
De-duplicates data in the walkforward result.
Browse files Browse the repository at this point in the history
Fixes a bug where test results were duplicated for each walkforward window.
  • Loading branch information
edtechre committed Mar 16, 2023
1 parent ce8e1e7 commit 6cc3e42
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 44 deletions.
49 changes: 19 additions & 30 deletions src/pybroker/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
Mapping,
NamedTuple,
Optional,
Sequence,
Union,
)

Expand Down Expand Up @@ -1172,7 +1171,16 @@ def walkforward(
disable_parallel=disable_parallel,
)
train_only = all(map(lambda e: e.fn is None, self._executions))
test_results = self._run_walkforward(
portfolio = Portfolio(
self._config.initial_cash,
self._config.fee_mode,
self._config.fee_amount,
self._fractional_shares_enabled(),
self._config.max_long_positions,
self._config.max_short_positions,
)
self._run_walkforward(
portfolio=portfolio,
df=df,
indicator_data=indicator_data,
tf_seconds=tf_seconds,
Expand All @@ -1188,7 +1196,7 @@ def walkforward(
self._logger.walkforward_completed()
return None
return self._to_test_result(
start_dt, end_dt, test_results, calc_bootstrap
start_dt, end_dt, portfolio, calc_bootstrap
)
finally:
scope.unfreeze_data_cols()
Expand All @@ -1215,6 +1223,7 @@ def _fractional_shares_enabled(self):

def _run_walkforward(
self,
portfolio: Portfolio,
df: pd.DataFrame,
indicator_data: dict[IndicatorSymbol, pd.Series],
tf_seconds: int,
Expand All @@ -1225,22 +1234,13 @@ def _run_walkforward(
train_size: float,
shuffle: bool,
train_only: bool,
) -> deque[BacktestResult]:
):
sessions: dict[ExecSymbol, dict] = {
ExecSymbol(execution.id, sym): {}
for execution in self._executions
if execution.fn is not None
for sym in execution.symbols
}
portfolio = Portfolio(
self._config.initial_cash,
self._config.fee_mode,
self._config.fee_amount,
self._fractional_shares_enabled(),
self._config.max_long_positions,
self._config.max_short_positions,
)
backtest_results: deque[BacktestResult] = deque()
for train_idx, test_idx in self.walkforward_split(
df=df,
windows=windows,
Expand Down Expand Up @@ -1275,7 +1275,7 @@ def _run_walkforward(
),
)
if not train_only and not test_data.empty:
backtest_result = self.backtest_executions(
self.backtest_executions(
executions=self._executions,
sessions=sessions,
models=models,
Expand All @@ -1289,8 +1289,6 @@ def _run_walkforward(
pos_size_handler=self._pos_size_handler,
enable_fractional_shares=self._fractional_shares_enabled(),
)
backtest_results.append(backtest_result)
return backtest_results

def _filter_dates(
self,
Expand Down Expand Up @@ -1365,20 +1363,11 @@ def _to_test_result(
self,
start_date: datetime,
end_date: datetime,
backtest_results: Sequence[BacktestResult],
portfolio: Portfolio,
calc_bootstrap: bool,
) -> TestResult:
portfolio_bars: deque[PortfolioBar] = deque()
pos_bars: deque[PositionBar] = deque()
orders: deque[Order] = deque()
trades: deque[Trade] = deque()
for result in backtest_results:
portfolio_bars.extend(result.portfolio_bars)
pos_bars.extend(result.position_bars)
orders.extend(result.orders)
trades.extend(result.trades)
pos_df = pd.DataFrame.from_records(
pos_bars, columns=PositionBar._fields
portfolio.position_bars, columns=PositionBar._fields
)
for col in (
"close",
Expand All @@ -1390,7 +1379,7 @@ def _to_test_result(
pos_df[col] = quantize(pos_df, col)
pos_df.set_index(["symbol", "date"], inplace=True)
portfolio_df = pd.DataFrame.from_records(
portfolio_bars, columns=PortfolioBar._fields, index="date"
portfolio.bars, columns=PortfolioBar._fields, index="date"
)
for col in (
"cash",
Expand All @@ -1402,12 +1391,12 @@ def _to_test_result(
):
portfolio_df[col] = quantize(portfolio_df, col)
orders_df = pd.DataFrame.from_records(
orders, columns=Order._fields, index="id"
portfolio.orders, columns=Order._fields, index="id"
)
for col in ("limit_price", "fill_price", "fees"):
orders_df[col] = quantize(orders_df, col)
trades_df = pd.DataFrame.from_records(
trades, columns=Trade._fields, index="id"
portfolio.trades, columns=Trade._fields, index="id"
)
trades_df["bars"] = trades_df["bars"].astype(int)
for col in (
Expand Down
22 changes: 8 additions & 14 deletions tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
)
from pybroker.strategy import (
BacktestMixin,
BacktestResult,
Execution,
ExecSymbol,
Strategy,
Expand Down Expand Up @@ -1548,7 +1547,8 @@ def test_to_test_result_when_fractional_shares(
expected_long_shares,
expected_short_shares,
):
portfolio_bars = deque(
portfolio = Portfolio(100_000)
portfolio.bars = deque(
(
PortfolioBar(
date=np.datetime64(START_DATE),
Expand All @@ -1561,7 +1561,7 @@ def test_to_test_result_when_fractional_shares(
),
)
)
position_bars = deque(
portfolio.position_bars = deque(
(
PositionBar(
symbol="SPY",
Expand All @@ -1576,7 +1576,7 @@ def test_to_test_result_when_fractional_shares(
),
)
)
orders = deque(
portfolio.orders = deque(
(
Order(
id=1,
Expand All @@ -1590,7 +1590,7 @@ def test_to_test_result_when_fractional_shares(
),
)
)
trades = deque(
portfolio.trades = deque(
(
Trade(
id=1,
Expand All @@ -1609,9 +1609,6 @@ def test_to_test_result_when_fractional_shares(
),
)
)
backtest_result = BacktestResult(
START_DATE, END_DATE, portfolio_bars, position_bars, orders, trades
)
config = StrategyConfig(
enable_fractional_shares=enable_fractional_shares
)
Expand All @@ -1622,7 +1619,7 @@ def test_to_test_result_when_fractional_shares(
config,
)
result = strategy._to_test_result(
START_DATE, END_DATE, (backtest_result,), calc_bootstrap=False
START_DATE, END_DATE, portfolio, calc_bootstrap=False
)
assert np.issubdtype(
result.positions["long_shares"].dtype, expected_shares_type
Expand All @@ -1646,15 +1643,12 @@ def test_to_test_result_when_fractional_shares(
assert result.trades["shares"].values[0] == expected_long_shares

def test_to_test_result_when_empty(self, data_source_df):
portfolio = Portfolio(100_000)
strategy = Strategy(data_source_df, START_DATE, END_DATE)
result = strategy._to_test_result(
START_DATE,
END_DATE,
(
BacktestResult(
START_DATE, END_DATE, deque(), deque(), deque(), deque()
),
),
portfolio,
calc_bootstrap=False,
)
assert result.portfolio.empty
Expand Down

0 comments on commit 6cc3e42

Please sign in to comment.