From e123add6d2663c94a03901187a58195ce85e2fd3 Mon Sep 17 00:00:00 2001 From: edtechre Date: Sun, 11 Aug 2024 19:08:23 -0700 Subject: [PATCH] Uses last bar in test split when windows=1 Guarantees that the last bar of data is used in the test split when the number of train/test windows is 1. --- src/pybroker/strategy.py | 8 +++++--- tests/test_strategy.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/pybroker/strategy.py b/src/pybroker/strategy.py index 6ee1aaa..9da6c99 100644 --- a/src/pybroker/strategy.py +++ b/src/pybroker/strategy.py @@ -698,12 +698,14 @@ def walkforward_split( raise ValueError(error_msg) train_length = int(res * train_size) test_length = int(res * (1 - train_size)) - train_start = 0 - train_end = train_length + train_start = ( + len(window_dates) - lookahead - train_length - test_length - 1 + ) + train_end = train_start + train_length test_start = train_end + lookahead if test_start >= len(window_dates): raise ValueError(error_msg) - test_end = test_start + test_length + test_end = len(window_dates) - 1 train_idx = dates[ (dates[date_col] >= window_dates[train_start]) & (dates[date_col] <= window_dates[train_end]) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index d84b29b..c0d32c4 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -101,7 +101,7 @@ def _verify_windows( ) dates = sorted(dates) assert len(results) == windows - for train_idx, test_idx in results: + for i, (train_idx, test_idx) in enumerate(results): assert len(dates) - (len(train_idx) + len(test_idx) * windows) >= 0 assert not (set(train_idx) & set(test_idx)) assert len(train_idx) or len(test_idx) @@ -112,6 +112,8 @@ def _verify_windows( assert dates[train_end_index - 2] != dates[test_start_index] if train_size == 0.5: assert len(train_idx) == len(test_idx) + if len(test_idx) and i == len(results) - 1: + assert dates[dates_length - 1] == dates[sorted(test_idx)[-1]] @pytest.mark.parametrize( "dates_length, windows, lookahead, train_size",