From 502e5fe4660f0fd2d659dda80e26e2dd6d437abc Mon Sep 17 00:00:00 2001 From: edtechre Date: Fri, 7 Jul 2023 12:30:51 -0700 Subject: [PATCH] Add start_date, end_date to model loader arguments. --- src/pybroker/model.py | 20 +++++++++++++++----- tests/test_model.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/pybroker/model.py b/src/pybroker/model.py index 8208859..6e6846e 100644 --- a/src/pybroker/model.py +++ b/src/pybroker/model.py @@ -14,10 +14,12 @@ IndicatorSymbol, ModelSymbol, TrainedModel, + to_datetime, ) from pybroker.indicator import Indicator from pybroker.scope import StaticScope from dataclasses import asdict +from datetime import datetime from numpy.typing import NDArray from typing import ( Any, @@ -88,7 +90,8 @@ class ModelLoader(ModelSource): Args: name: Name of model. - load_fn: ``Callable[[symbol: str, ...], DataFrame]`` used to load and + load_fn: ``Callable[[symbol: str, train_start_date: datetime, + train_end_date: datetime, ...], DataFrame]`` used to load and return a pre-trained model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and a :class:`Iterable` of column names to @@ -121,16 +124,20 @@ def __init__( ) self._load_fn = functools.partial(load_fn, **kwargs) - def __call__(self, symbol: str) -> Union[Any, tuple[Any, Iterable[str]]]: + def __call__( + self, symbol: str, train_start_date: datetime, train_end_date: datetime + ) -> Union[Any, tuple[Any, Iterable[str]]]: """Loads pre-trained model. Args: symbol: Ticker symbol for loading the pre-trained model. + train_start_date: Start date of training window. + train_end_date: End date of training window. Returns: Pre-trained model. """ - return self._load_fn(symbol) + return self._load_fn(symbol, train_start_date, train_end_date) def __repr__(self): return self.__str__() @@ -218,7 +225,8 @@ def model( for training, then ``fn`` has signature ``Callable[[symbol: str, train_data: DataFrame, test_data: DataFrame, ...], DataFrame]``. If for loading, then ``fn`` has signature - ``Callable[[symbol: str, ...], DataFrame]``. This is expected to + ``Callable[[symbol: str, train_start_date: datetime, + train_end_date: datetime, ...], DataFrame]``. This is expected to return either a trained model instance, or a tuple containing a trained model instance and a :class:`Iterable` of column names to to be used as input for the model when making predictions. @@ -353,7 +361,9 @@ def train_models( model_result = source(sym, sym_train_data, sym_test_data) scope.logger.info_train_model_completed(model_sym) elif isinstance(source, ModelLoader): - model_result = source(sym) + start_date = to_datetime(train_dates[0]) + end_date = to_datetime(train_dates[-1]) + model_result = source(sym, start_date, end_date) scope.logger.info_loaded_model(model_sym) else: raise TypeError(f"Invalid ModelSource type: {type(source)}") diff --git a/tests/test_model.py b/tests/test_model.py index 317b54c..7c4ff04 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -40,16 +40,36 @@ def cache_date_fields(train_data): ) +@pytest.fixture() +def start_date(train_data): + return to_datetime(sorted(train_data["date"].unique())[0]) + + @pytest.fixture() def end_date(train_data): return to_datetime(sorted(train_data["date"].unique())[-1]) @pytest.fixture() -def model_syms(train_data, model_source): +def model_loader(): + return model( + "loader", + lambda symbol, train_start_date, train_end_date: FakeModel( + symbol=symbol, preds=[] + ), + [], + pretrained=True, + ) + + +@pytest.fixture() +def model_syms(train_data, model_source, model_loader): return [ ModelSymbol(model_source.name, sym) for sym in train_data["symbol"].unique() + ] + [ + ModelSymbol(model_loader.name, sym) + for sym in train_data["symbol"].unique() ] @@ -116,11 +136,13 @@ def test_model_prepare_input_fn_when_indicators_not_found_then_error( ): source.prepare_input_data(ind_df) - def test_model_loader_call_with_kwargs(self): + def test_model_loader_call_with_kwargs(self, start_date, end_date): load_fn = Mock() kwargs = {"a": 1, "b": 2} - ModelLoader("loader", load_fn, [], None, None, kwargs)("SPY") - load_fn.assert_called_once_with("SPY", **kwargs) + ModelLoader("loader", load_fn, [], None, None, kwargs)( + "SPY", start_date, end_date + ) + load_fn.assert_called_once_with("SPY", start_date, end_date, **kwargs) def test_model_trainer_call_with_kwargs(self, train_data, test_data): train_fn = Mock()