Skip to content

Commit

Permalink
Replace Collection hints with Iterable and concrete types.
Browse files Browse the repository at this point in the history
  • Loading branch information
edtechre committed Jul 7, 2023
1 parent 681bce7 commit 89c25f6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
8 changes: 4 additions & 4 deletions src/pybroker/indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class IndicatorsMixin:
def compute_indicators(
self,
df: pd.DataFrame,
indicator_syms: Collection[IndicatorSymbol],
indicator_syms: Iterable[IndicatorSymbol],
cache_date_fields: Optional[CacheDateFields],
disable_parallel: bool,
) -> dict[IndicatorSymbol, pd.Series]:
Expand All @@ -180,7 +180,7 @@ def compute_indicators(
Args:
df: :class:`pandas.DataFrame` used to compute the indicator values.
indicator_syms: ``Collection`` of
indicator_syms: ``Iterable`` of
:class:`pybroker.common.IndicatorSymbol` pairs of indicators
to compute.
cache_date_fields: Date fields used to key cache data. Pass
Expand Down Expand Up @@ -228,9 +228,9 @@ def compute_indicators(

def _get_cached_indicators(
self,
indicator_syms: Collection[IndicatorSymbol],
indicator_syms: Iterable[IndicatorSymbol],
cache_date_fields: Optional[CacheDateFields],
) -> tuple[dict[IndicatorSymbol, pd.Series], Collection[IndicatorSymbol]]:
) -> tuple[dict[IndicatorSymbol, pd.Series], list[IndicatorSymbol]]:
indicator_syms = sorted(indicator_syms)
indicator_data: dict[IndicatorSymbol, pd.Series] = {}
if cache_date_fields is None:
Expand Down
25 changes: 12 additions & 13 deletions src/pybroker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import (
Any,
Callable,
Collection,
Iterable,
Mapping,
NamedTuple,
Expand Down Expand Up @@ -92,7 +91,7 @@ class ModelLoader(ModelSource):
load_fn: ``Callable[[symbol: str, ...], DataFrame]`` used to load and
return a pre-trained model. This is expected to
return either a trained model instance, or a tuple containing a
trained model instance and a :class:`Collection` of column names to
trained model instance and a :class:`Iterable` of column names to
to be used as input for the model when making predictions.
indicator_names: :class:`Iterable` of names of
:class:`pybroker.indicator.Indicator`\ s used as features of the
Expand All @@ -111,7 +110,7 @@ class ModelLoader(ModelSource):
def __init__(
self,
name: str,
load_fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]],
load_fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]],
indicator_names: Iterable[str],
input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]],
predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]],
Expand All @@ -122,7 +121,7 @@ def __init__(
)
self._load_fn = functools.partial(load_fn, **kwargs)

def __call__(self, symbol: str) -> Union[Any, tuple[Any, Collection[str]]]:
def __call__(self, symbol: str) -> Union[Any, tuple[Any, Iterable[str]]]:
"""Loads pre-trained model.
Args:
Expand All @@ -149,7 +148,7 @@ class ModelTrainer(ModelSource):
test_data: DataFrame, ...], DataFrame]`` used to train and return a
model. This is expected to return either a trained model instance,
or a tuple containing a trained model instance and a
:class:`Collection` of column names to to be used as input for the
:class:`Iterable` of column names to to be used as input for the
model when making predictions.
indicator_names: :class:`Iterable` of names of
:class:`pybroker.indicator.Indicator`\ s used as features of the
Expand All @@ -168,7 +167,7 @@ class ModelTrainer(ModelSource):
def __init__(
self,
name: str,
train_fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]],
train_fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]],
indicator_names: Iterable[str],
input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]],
predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]],
Expand All @@ -181,7 +180,7 @@ def __init__(

def __call__(
self, symbol: str, train_data: pd.DataFrame, test_data: pd.DataFrame
) -> Union[Any, tuple[Any, Collection[str]]]:
) -> Union[Any, tuple[Any, Iterable[str]]]:
"""Trains model.
Args:
Expand All @@ -203,7 +202,7 @@ def __str__(self):

def model(
name: str,
fn: Callable[..., Union[Any, tuple[Any, Collection[str]]]],
fn: Callable[..., Union[Any, tuple[Any, Iterable[str]]]],
indicators: Optional[Iterable[Indicator]] = None,
input_data_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
predict_fn: Optional[Callable[[Any, pd.DataFrame], NDArray]] = None,
Expand All @@ -221,7 +220,7 @@ def model(
If for loading, then ``fn`` has signature
``Callable[[symbol: str, ...], DataFrame]``. This is expected to
return either a trained model instance, or a tuple containing a
trained model instance and a :class:`Collection` of column names to
trained model instance and a :class:`Iterable` of column names to
to be used as input for the model when making predictions.
indicators: :class:`Iterable` of
:class:`pybroker.indicator.Indicator`\ s used as features of the
Expand Down Expand Up @@ -290,7 +289,7 @@ class ModelsMixin:

def train_models(
self,
model_syms: Collection[ModelSymbol],
model_syms: Iterable[ModelSymbol],
train_data: pd.DataFrame,
test_data: pd.DataFrame,
indicator_data: Mapping[IndicatorSymbol, pd.Series],
Expand All @@ -300,7 +299,7 @@ def train_models(
pairs.
Args:
model_syms: ``Collection`` of
model_syms: ``Iterable`` of
:class:`pybroker.common.ModelSymbol` pairs of models to train.
train_data: :class:`pandas.DataFrame` of training data.
test_data: :class:`pandas.DataFrame` of test data.
Expand Down Expand Up @@ -385,9 +384,9 @@ def _slice_by_symbol(self, symbol: str, df: pd.DataFrame) -> pd.DataFrame:

def _get_cached_models(
self,
model_syms: Collection[ModelSymbol],
model_syms: Iterable[ModelSymbol],
cache_date_fields: CacheDateFields,
) -> tuple[dict[ModelSymbol, TrainedModel], Collection[ModelSymbol]]:
) -> tuple[dict[ModelSymbol, TrainedModel], list[ModelSymbol]]:
model_syms = sorted(model_syms)
models: dict[ModelSymbol, TrainedModel] = {}
scope = StaticScope.instance()
Expand Down
3 changes: 1 addition & 2 deletions src/pybroker/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from typing import (
Any,
Callable,
Collection,
Final,
Iterable,
Literal,
Expand Down Expand Up @@ -237,7 +236,7 @@ def __init__(self, df: pd.DataFrame):
def fetch_dict(
self,
symbol: str,
names: Collection[str],
names: Iterable[str],
end_index: Optional[int] = None,
) -> dict[str, Optional[NDArray]]:
r"""Fetches a ``dict`` of column data for ``symbol``.
Expand Down

0 comments on commit 89c25f6

Please sign in to comment.