From 89c25f69adbb04831e889fa66bf01d0e289b6e18 Mon Sep 17 00:00:00 2001 From: edtechre Date: Fri, 7 Jul 2023 04:11:16 -0700 Subject: [PATCH] Replace Collection hints with Iterable and concrete types. --- src/pybroker/indicator.py | 8 ++++---- src/pybroker/model.py | 25 ++++++++++++------------- src/pybroker/scope.py | 3 +-- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/pybroker/indicator.py b/src/pybroker/indicator.py index d729beb..b7714a3 100644 --- a/src/pybroker/indicator.py +++ b/src/pybroker/indicator.py @@ -171,7 +171,7 @@ class IndicatorsMixin: def compute_indicators( self, df: pd.DataFrame, - indicator_syms: Collection[IndicatorSymbol], + indicator_syms: Iterable[IndicatorSymbol], cache_date_fields: Optional[CacheDateFields], disable_parallel: bool, ) -> dict[IndicatorSymbol, pd.Series]: @@ -180,7 +180,7 @@ def compute_indicators( Args: df: :class:`pandas.DataFrame` used to compute the indicator values. - indicator_syms: ``Collection`` of + indicator_syms: ``Iterable`` of :class:`pybroker.common.IndicatorSymbol` pairs of indicators to compute. cache_date_fields: Date fields used to key cache data. Pass @@ -228,9 +228,9 @@ def compute_indicators( def _get_cached_indicators( self, - indicator_syms: Collection[IndicatorSymbol], + indicator_syms: Iterable[IndicatorSymbol], cache_date_fields: Optional[CacheDateFields], - ) -> tuple[dict[IndicatorSymbol, pd.Series], Collection[IndicatorSymbol]]: + ) -> tuple[dict[IndicatorSymbol, pd.Series], list[IndicatorSymbol]]: indicator_syms = sorted(indicator_syms) indicator_data: dict[IndicatorSymbol, pd.Series] = {} if cache_date_fields is None: diff --git a/src/pybroker/model.py b/src/pybroker/model.py index 013553b..8208859 100644 --- a/src/pybroker/model.py +++ b/src/pybroker/model.py @@ -22,7 +22,6 @@ from typing import ( Any, Callable, - Collection, Iterable, Mapping, NamedTuple, @@ -92,7 +91,7 @@ class ModelLoader(ModelSource): load_fn: ``Callable[[symbol: str, ...], DataFrame]`` used to load and return a pre-trained model. This is expected to return either a trained model instance, or a tuple containing a - trained model instance and a :class:`Collection` of column names to + trained model instance and a :class:`Iterable` of column names to to be used as input for the model when making predictions. indicator_names: :class:`Iterable` of names of :class:`pybroker.indicator.Indicator`\ s used as features of the @@ -111,7 +110,7 @@ class ModelLoader(ModelSource): def __init__( self, name: str, - load_fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]], + load_fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]], indicator_names: Iterable[str], input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]], predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]], @@ -122,7 +121,7 @@ def __init__( ) self._load_fn = functools.partial(load_fn, **kwargs) - def __call__(self, symbol: str) -> Union[Any, tuple[Any, Collection[str]]]: + def __call__(self, symbol: str) -> Union[Any, tuple[Any, Iterable[str]]]: """Loads pre-trained model. Args: @@ -149,7 +148,7 @@ class ModelTrainer(ModelSource): test_data: DataFrame, ...], DataFrame]`` used to train and return a model. This is expected to return either a trained model instance, or a tuple containing a trained model instance and a - :class:`Collection` of column names to to be used as input for the + :class:`Iterable` of column names to to be used as input for the model when making predictions. indicator_names: :class:`Iterable` of names of :class:`pybroker.indicator.Indicator`\ s used as features of the @@ -168,7 +167,7 @@ class ModelTrainer(ModelSource): def __init__( self, name: str, - train_fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]], + train_fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]], indicator_names: Iterable[str], input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]], predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]], @@ -181,7 +180,7 @@ def __init__( def __call__( self, symbol: str, train_data: pd.DataFrame, test_data: pd.DataFrame - ) -> Union[Any, tuple[Any, Collection[str]]]: + ) -> Union[Any, tuple[Any, Iterable[str]]]: """Trains model. Args: @@ -203,7 +202,7 @@ def __str__(self): def model( name: str, - fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]], + fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]], indicators: Optional[Iterable[Indicator]] = None, input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]] = None, @@ -221,7 +220,7 @@ def model( If for loading, then ``fn`` has signature ``Callable[[symbol: str, ...], DataFrame]``. This is expected to return either a trained model instance, or a tuple containing a - trained model instance and a :class:`Collection` of column names to + trained model instance and a :class:`Iterable` of column names to to be used as input for the model when making predictions. indicators: :class:`Iterable` of :class:`pybroker.indicator.Indicator`\ s used as features of the @@ -290,7 +289,7 @@ class ModelsMixin: def train_models( self, - model_syms: Collection[ModelSymbol], + model_syms: Iterable[ModelSymbol], train_data: pd.DataFrame, test_data: pd.DataFrame, indicator_data: Mapping[IndicatorSymbol, pd.Series], @@ -300,7 +299,7 @@ def train_models( pairs. Args: - model_syms: ``Collection`` of + model_syms: ``Iterable`` of :class:`pybroker.common.ModelSymbol` pairs of models to train. train_data: :class:`pandas.DataFrame` of training data. test_data: :class:`pandas.DataFrame` of test data. @@ -385,9 +384,9 @@ def _slice_by_symbol(self, symbol: str, df: pd.DataFrame) -> pd.DataFrame: def _get_cached_models( self, - model_syms: Collection[ModelSymbol], + model_syms: Iterable[ModelSymbol], cache_date_fields: CacheDateFields, - ) -> tuple[dict[ModelSymbol, TrainedModel], Collection[ModelSymbol]]: + ) -> tuple[dict[ModelSymbol, TrainedModel], list[ModelSymbol]]: model_syms = sorted(model_syms) models: dict[ModelSymbol, TrainedModel] = {} scope = StaticScope.instance() diff --git a/src/pybroker/scope.py b/src/pybroker/scope.py index 1c807de..d2dbf70 100644 --- a/src/pybroker/scope.py +++ b/src/pybroker/scope.py @@ -27,7 +27,6 @@ from typing import ( Any, Callable, - Collection, Final, Iterable, Literal, @@ -237,7 +236,7 @@ def __init__(self, df: pd.DataFrame): def fetch_dict( self, symbol: str, - names: Collection[str], + names: Iterable[str], end_index: Optional[int] = None, ) -> dict[str, Optional[NDArray]]: r"""Fetches a ``dict`` of column data for ``symbol``.