Skip to content

Commit

Permalink
Fix tests by removing lazy_fixture plugin.
Browse files Browse the repository at this point in the history
  • Loading branch information
edtechre committed Mar 17, 2024
1 parent 42e306f commit aba7019
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 27 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ progressbar2>=4.1.1
pytest>=7.2.0
pytest-cov>=4.0.0
pytest-instafail>=0.4.2
pytest-lazy-fixture>=0.6.3
pytest-randomly>=3.12.0
pytest-xdist>=3.0.2
scikit-learn>=1.2.1
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ deps =
pytest
pytest-cov
pytest-instafail
pytest-lazy-fixture
pytest-randomly
pytest-xdist
commands =
Expand Down
18 changes: 10 additions & 8 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import yfinance
from .fixtures import * # noqa: F401
from .util import * # noqa: F401
from datetime import datetime
from pybroker.cache import DataSourceCacheKey
from pybroker.common import to_seconds
Expand Down Expand Up @@ -138,10 +139,9 @@ def test_set_cached(self, alpaca_df, symbols, mock_cache):
assert sym_df.equals(alpaca_df[alpaca_df["symbol"] == sym])

@pytest.mark.usefixtures("scope")
@pytest.mark.parametrize(
"query_symbols", [[], pytest.lazy_fixture("symbols")]
)
def test_get_cached_when_empty(self, mock_cache, query_symbols):
@pytest.mark.parametrize("query_symbols", [[], "symbols"])
def test_get_cached_when_empty(self, mock_cache, query_symbols, request):
query_symbols = get_fixture(request, query_symbols)
cache_mixin = DataSourceCacheMixin()
df, uncached_syms = cache_mixin.get_cached(
query_symbols, TIMEFRAME, START_DATE, END_DATE, ADJUST
Expand Down Expand Up @@ -503,15 +503,17 @@ class TestYFinance:
"param_symbols, expected_df, expected_rows",
[
(
pytest.lazy_fixture("symbols"),
pytest.lazy_fixture("yfinance_df"),
"symbols",
"yfinance_df",
2020,
),
(["SPY"], pytest.lazy_fixture("yfinance_single_df"), 505),
(["SPY"], "yfinance_single_df", 505),
],
)
@pytest.mark.usefixtures("setup_ds_cache")
def test_query(self, param_symbols, expected_df, expected_rows):
def test_query(self, param_symbols, expected_df, expected_rows, request):
param_symbols = get_fixture(request, param_symbols)
expected_df = get_fixture(request, expected_df)
yf = YFinance()
with mock.patch.object(yfinance, "download", return_value=expected_df):
df = yf.query(param_symbols, START_DATE, END_DATE)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
from pybroker.cache import CacheDateFields
from .fixtures import * # noqa: F401
from .util import * # noqa: F401
from pybroker.common import BarData, DataCol, IndicatorSymbol, to_datetime
from pybroker.indicator import (
Indicator,
Expand Down Expand Up @@ -214,10 +215,9 @@ def test_call_when_indicators_empty_then_error(self, data_source_df):
with pytest.raises(ValueError, match="No indicators were added."):
ind_set(data_source_df)

@pytest.mark.parametrize(
"df", [pd.DataFrame(), pytest.lazy_fixture("data_source_df")]
)
def test_call(self, df, hhv_ind, llv_ind, disable_parallel):
@pytest.mark.parametrize("df", [pd.DataFrame(), "data_source_df"])
def test_call(self, df, hhv_ind, llv_ind, disable_parallel, request):
df = get_fixture(request, df)
ind_set = IndicatorSet()
ind_set.add([hhv_ind, llv_ind])
result = ind_set(df, disable_parallel)
Expand Down
45 changes: 32 additions & 13 deletions tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import pandas as pd
import pytest
import re
from .fixtures import *
from .fixtures import * # noqa: F401
from .util import * # noqa: F401
from collections import defaultdict, deque
from datetime import datetime
from decimal import Decimal
Expand Down Expand Up @@ -1081,18 +1082,33 @@ def exec_fn_2(ctx):


@pytest.fixture()
def executions_with_models(executions_only, model_source):
def exec_model_source(scope, data_source_df, indicators):
return model(
MODEL_NAME,
lambda sym, *_: FakeModel(
sym,
np.full(
data_source_df[data_source_df["symbol"] == sym].shape[0], 100
),
),
indicators,
pretrained=False,
)


@pytest.fixture()
def executions_with_models(executions_only, exec_model_source):
def exec_fn(ctx):
assert isinstance(ctx.model(model_source.name), FakeModel)
assert isinstance(ctx.model(exec_model_source.name), FakeModel)

executions_only[0]["models"] = model_source
executions_only[0]["models"] = exec_model_source
executions_only[0]["fn"] = exec_fn
return executions_only


@pytest.fixture()
def executions_with_models_and_indicators(
executions_only, model_source, hhv_ind, llv_ind
executions_only, exec_model_source, hhv_ind, llv_ind
):
def exec_fn_1(ctx):
assert len(ctx.indicator(llv_ind.name))
Expand All @@ -1102,10 +1118,10 @@ def exec_fn_1(ctx):

def exec_fn_2(ctx):
assert len(ctx.indicator(hhv_ind.name))
assert isinstance(ctx.model(model_source.name), FakeModel)
assert isinstance(ctx.model(exec_model_source.name), FakeModel)

executions_only[1]["indicators"] = hhv_ind
executions_only[1]["models"] = model_source
executions_only[1]["models"] = exec_model_source
executions_only[1]["fn"] = exec_fn_2
return executions_only

Expand Down Expand Up @@ -1158,16 +1174,16 @@ def _fetch_data(
class TestStrategy:
@pytest.mark.parametrize(
"data_source",
[FakeDataSource(), pytest.lazy_fixture("data_source_df")],
[FakeDataSource(), "data_source_df"],
)
@pytest.mark.parametrize(
"executions",
[
pytest.lazy_fixture("executions_train_only"),
pytest.lazy_fixture("executions_only"),
pytest.lazy_fixture("executions_with_indicators"),
pytest.lazy_fixture("executions_with_models"),
pytest.lazy_fixture("executions_with_models_and_indicators"),
"executions_train_only",
"executions_only",
"executions_with_indicators",
"executions_with_models",
"executions_with_models_and_indicators",
],
)
def test_walkforward(
Expand All @@ -1179,7 +1195,10 @@ def test_walkforward(
between_time,
calc_bootstrap,
disable_parallel,
request,
):
data_source = get_fixture(request, data_source)
executions = get_fixture(request, executions)
config = StrategyConfig(
bootstrap_samples=100, bootstrap_sample_size=10
)
Expand Down
4 changes: 4 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def get_fixture(request, param):
if isinstance(param, str):
return request.getfixturevalue(param)
return param

0 comments on commit aba7019

Please sign in to comment.