Skip to content

Commit

Permalink
Add start_date, end_date to model loader arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
edtechre committed Jul 7, 2023
1 parent 89c25f6 commit 502e5fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
20 changes: 15 additions & 5 deletions src/pybroker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
IndicatorSymbol,
ModelSymbol,
TrainedModel,
to_datetime,
)
from pybroker.indicator import Indicator
from pybroker.scope import StaticScope
from dataclasses import asdict
from datetime import datetime
from numpy.typing import NDArray
from typing import (
Any,
Expand Down Expand Up @@ -88,7 +90,8 @@ class ModelLoader(ModelSource):
Args:
name: Name of model.
load_fn: ``Callable[[symbol: str, ...], DataFrame]`` used to load and
load_fn: ``Callable[[symbol: str, train_start_date: datetime,
train_end_date: datetime, ...], DataFrame]`` used to load and
return a pre-trained model. This is expected to
return either a trained model instance, or a tuple containing a
trained model instance and a :class:`Iterable` of column names to
Expand Down Expand Up @@ -121,16 +124,20 @@ def __init__(
)
self._load_fn = functools.partial(load_fn, **kwargs)

def __call__(self, symbol: str) -> Union[Any, tuple[Any, Iterable[str]]]:
def __call__(
self, symbol: str, train_start_date: datetime, train_end_date: datetime
) -> Union[Any, tuple[Any, Iterable[str]]]:
"""Loads pre-trained model.
Args:
symbol: Ticker symbol for loading the pre-trained model.
train_start_date: Start date of training window.
train_end_date: End date of training window.
Returns:
Pre-trained model.
"""
return self._load_fn(symbol)
return self._load_fn(symbol, train_start_date, train_end_date)

def __repr__(self):
return self.__str__()
Expand Down Expand Up @@ -218,7 +225,8 @@ def model(
for training, then ``fn`` has signature ``Callable[[symbol: str,
train_data: DataFrame, test_data: DataFrame, ...], DataFrame]``.
If for loading, then ``fn`` has signature
``Callable[[symbol: str, ...], DataFrame]``. This is expected to
``Callable[[symbol: str, train_start_date: datetime,
train_end_date: datetime, ...], DataFrame]``. This is expected to
return either a trained model instance, or a tuple containing a
trained model instance and a :class:`Iterable` of column names to
to be used as input for the model when making predictions.
Expand Down Expand Up @@ -353,7 +361,9 @@ def train_models(
model_result = source(sym, sym_train_data, sym_test_data)
scope.logger.info_train_model_completed(model_sym)
elif isinstance(source, ModelLoader):
model_result = source(sym)
start_date = to_datetime(train_dates[0])
end_date = to_datetime(train_dates[-1])
model_result = source(sym, start_date, end_date)
scope.logger.info_loaded_model(model_sym)
else:
raise TypeError(f"Invalid ModelSource type: {type(source)}")
Expand Down
30 changes: 26 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,36 @@ def cache_date_fields(train_data):
)


@pytest.fixture()
def start_date(train_data):
return to_datetime(sorted(train_data["date"].unique())[0])


@pytest.fixture()
def end_date(train_data):
return to_datetime(sorted(train_data["date"].unique())[-1])


@pytest.fixture()
def model_syms(train_data, model_source):
def model_loader():
return model(
"loader",
lambda symbol, train_start_date, train_end_date: FakeModel(
symbol=symbol, preds=[]
),
[],
pretrained=True,
)


@pytest.fixture()
def model_syms(train_data, model_source, model_loader):
return [
ModelSymbol(model_source.name, sym)
for sym in train_data["symbol"].unique()
] + [
ModelSymbol(model_loader.name, sym)
for sym in train_data["symbol"].unique()
]


Expand Down Expand Up @@ -116,11 +136,13 @@ def test_model_prepare_input_fn_when_indicators_not_found_then_error(
):
source.prepare_input_data(ind_df)

def test_model_loader_call_with_kwargs(self):
def test_model_loader_call_with_kwargs(self, start_date, end_date):
load_fn = Mock()
kwargs = {"a": 1, "b": 2}
ModelLoader("loader", load_fn, [], None, None, kwargs)("SPY")
load_fn.assert_called_once_with("SPY", **kwargs)
ModelLoader("loader", load_fn, [], None, None, kwargs)(
"SPY", start_date, end_date
)
load_fn.assert_called_once_with("SPY", start_date, end_date, **kwargs)

def test_model_trainer_call_with_kwargs(self, train_data, test_data):
train_fn = Mock()
Expand Down

0 comments on commit 502e5fe

Please sign in to comment.