From aba70196ff5c73baa0eb3b0246e4707be8a743e7 Mon Sep 17 00:00:00 2001 From: edtechre Date: Sat, 16 Mar 2024 18:46:33 -0700 Subject: [PATCH] Fix tests by removing lazy_fixture plugin. --- requirements.txt | 1 - setup.cfg | 1 - tests/test_data.py | 18 +++++++++-------- tests/test_indicator.py | 8 ++++---- tests/test_strategy.py | 45 +++++++++++++++++++++++++++++------------ tests/util.py | 4 ++++ 6 files changed, 50 insertions(+), 27 deletions(-) create mode 100644 tests/util.py diff --git a/requirements.txt b/requirements.txt index 4548818..899219c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,6 @@ progressbar2>=4.1.1 pytest>=7.2.0 pytest-cov>=4.0.0 pytest-instafail>=0.4.2 -pytest-lazy-fixture>=0.6.3 pytest-randomly>=3.12.0 pytest-xdist>=3.0.2 scikit-learn>=1.2.1 diff --git a/setup.cfg b/setup.cfg index dfda8a5..4af3d04 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,7 +56,6 @@ deps = pytest pytest-cov pytest-instafail - pytest-lazy-fixture pytest-randomly pytest-xdist commands = diff --git a/tests/test_data.py b/tests/test_data.py index 5110f70..af50762 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -13,6 +13,7 @@ import re import yfinance from .fixtures import * # noqa: F401 +from .util import * # noqa: F401 from datetime import datetime from pybroker.cache import DataSourceCacheKey from pybroker.common import to_seconds @@ -138,10 +139,9 @@ def test_set_cached(self, alpaca_df, symbols, mock_cache): assert sym_df.equals(alpaca_df[alpaca_df["symbol"] == sym]) @pytest.mark.usefixtures("scope") - @pytest.mark.parametrize( - "query_symbols", [[], pytest.lazy_fixture("symbols")] - ) - def test_get_cached_when_empty(self, mock_cache, query_symbols): + @pytest.mark.parametrize("query_symbols", [[], "symbols"]) + def test_get_cached_when_empty(self, mock_cache, query_symbols, request): + query_symbols = get_fixture(request, query_symbols) cache_mixin = DataSourceCacheMixin() df, uncached_syms = cache_mixin.get_cached( query_symbols, TIMEFRAME, START_DATE, END_DATE, ADJUST @@ -503,15 +503,17 @@ class TestYFinance: "param_symbols, expected_df, expected_rows", [ ( - pytest.lazy_fixture("symbols"), - pytest.lazy_fixture("yfinance_df"), + "symbols", + "yfinance_df", 2020, ), - (["SPY"], pytest.lazy_fixture("yfinance_single_df"), 505), + (["SPY"], "yfinance_single_df", 505), ], ) @pytest.mark.usefixtures("setup_ds_cache") - def test_query(self, param_symbols, expected_df, expected_rows): + def test_query(self, param_symbols, expected_df, expected_rows, request): + param_symbols = get_fixture(request, param_symbols) + expected_df = get_fixture(request, expected_df) yf = YFinance() with mock.patch.object(yfinance, "download", return_value=expected_df): df = yf.query(param_symbols, START_DATE, END_DATE) diff --git a/tests/test_indicator.py b/tests/test_indicator.py index fef5755..c087290 100644 --- a/tests/test_indicator.py +++ b/tests/test_indicator.py @@ -12,6 +12,7 @@ import re from pybroker.cache import CacheDateFields from .fixtures import * # noqa: F401 +from .util import * # noqa: F401 from pybroker.common import BarData, DataCol, IndicatorSymbol, to_datetime from pybroker.indicator import ( Indicator, @@ -214,10 +215,9 @@ def test_call_when_indicators_empty_then_error(self, data_source_df): with pytest.raises(ValueError, match="No indicators were added."): ind_set(data_source_df) - @pytest.mark.parametrize( - "df", [pd.DataFrame(), pytest.lazy_fixture("data_source_df")] - ) - def test_call(self, df, hhv_ind, llv_ind, disable_parallel): + @pytest.mark.parametrize("df", [pd.DataFrame(), "data_source_df"]) + def test_call(self, df, hhv_ind, llv_ind, disable_parallel, request): + df = get_fixture(request, df) ind_set = IndicatorSet() ind_set.add([hhv_ind, llv_ind]) result = ind_set(df, disable_parallel) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 738f733..1f0d8d4 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -11,7 +11,8 @@ import pandas as pd import pytest import re -from .fixtures import * +from .fixtures import * # noqa: F401 +from .util import * # noqa: F401 from collections import defaultdict, deque from datetime import datetime from decimal import Decimal @@ -1081,18 +1082,33 @@ def exec_fn_2(ctx): @pytest.fixture() -def executions_with_models(executions_only, model_source): +def exec_model_source(scope, data_source_df, indicators): + return model( + MODEL_NAME, + lambda sym, *_: FakeModel( + sym, + np.full( + data_source_df[data_source_df["symbol"] == sym].shape[0], 100 + ), + ), + indicators, + pretrained=False, + ) + + +@pytest.fixture() +def executions_with_models(executions_only, exec_model_source): def exec_fn(ctx): - assert isinstance(ctx.model(model_source.name), FakeModel) + assert isinstance(ctx.model(exec_model_source.name), FakeModel) - executions_only[0]["models"] = model_source + executions_only[0]["models"] = exec_model_source executions_only[0]["fn"] = exec_fn return executions_only @pytest.fixture() def executions_with_models_and_indicators( - executions_only, model_source, hhv_ind, llv_ind + executions_only, exec_model_source, hhv_ind, llv_ind ): def exec_fn_1(ctx): assert len(ctx.indicator(llv_ind.name)) @@ -1102,10 +1118,10 @@ def exec_fn_1(ctx): def exec_fn_2(ctx): assert len(ctx.indicator(hhv_ind.name)) - assert isinstance(ctx.model(model_source.name), FakeModel) + assert isinstance(ctx.model(exec_model_source.name), FakeModel) executions_only[1]["indicators"] = hhv_ind - executions_only[1]["models"] = model_source + executions_only[1]["models"] = exec_model_source executions_only[1]["fn"] = exec_fn_2 return executions_only @@ -1158,16 +1174,16 @@ def _fetch_data( class TestStrategy: @pytest.mark.parametrize( "data_source", - [FakeDataSource(), pytest.lazy_fixture("data_source_df")], + [FakeDataSource(), "data_source_df"], ) @pytest.mark.parametrize( "executions", [ - pytest.lazy_fixture("executions_train_only"), - pytest.lazy_fixture("executions_only"), - pytest.lazy_fixture("executions_with_indicators"), - pytest.lazy_fixture("executions_with_models"), - pytest.lazy_fixture("executions_with_models_and_indicators"), + "executions_train_only", + "executions_only", + "executions_with_indicators", + "executions_with_models", + "executions_with_models_and_indicators", ], ) def test_walkforward( @@ -1179,7 +1195,10 @@ def test_walkforward( between_time, calc_bootstrap, disable_parallel, + request, ): + data_source = get_fixture(request, data_source) + executions = get_fixture(request, executions) config = StrategyConfig( bootstrap_samples=100, bootstrap_sample_size=10 ) diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..d2fb160 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,4 @@ +def get_fixture(request, param): + if isinstance(param, str): + return request.getfixturevalue(param) + return param