Skip to content

Commit

Permalink
[ENH] refactor _check_estimator_types to use record class interface (
Browse files Browse the repository at this point in the history
…sktime#7395)

This refactors `_check_estimator_types` to use the new record class
interface, removing the need to construct a full lookup of base classes,
reducing coupling in the process.
  • Loading branch information
fkiraly authored Nov 16, 2024
1 parent dfcdcc5 commit 4ab0218
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
24 changes: 24 additions & 0 deletions sktime/registry/_base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,30 @@ def _construct_base_class_register(mixin=False):
return register


def get_base_class_for_str(scitype_str):
"""Return base class for a given scitype string.
Parameters
----------
scitype_str : str, or list of str
scitype shorthand, as in scitype_name field of scitype classes
Returns
-------
base_cls : class or list of class
base class corresponding to the scitype string,
or list of base classes if input was a list
"""
if isinstance(scitype_str, list):
return [get_base_class_for_str(s) for s in scitype_str]

base_classes = _get_base_classes()
base_classes += _get_base_classes(mixin=True)
base_class_lookup = {cl.get_class_tags()["scitype_name"]: cl for cl in base_classes}
base_cls = base_class_lookup[scitype_str].get_base_class()
return base_cls


def get_base_class_register(mixin=False, include_baseobjs=True):
"""Return register of object scitypes and base classes in sktime.
Expand Down
24 changes: 17 additions & 7 deletions sktime/registry/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import pandas as pd
from skbase.lookup import all_objects

from sktime.registry._base_classes import get_base_class_lookup, get_obj_scitype_list
from sktime.registry._base_classes import (
get_base_class_for_str,
get_obj_scitype_list,
)
from sktime.registry._tags import ESTIMATOR_TAG_REGISTER


Expand Down Expand Up @@ -362,22 +365,30 @@ def is_tag_for_type(tag, estimator_types):


def _check_estimator_types(estimator_types):
"""Return list of classes corresponding to type strings."""
"""Return list of classes corresponding to type strings.
Parameters
----------
estimator_types: str, or list of str
Returns
-------
estimator_types: list of classes
base classes corresponding to scitype strings in estimator_types
"""
estimator_types = deepcopy(estimator_types)

if not isinstance(estimator_types, list):
estimator_types = [estimator_types] # make iterable

def _get_err_msg(estimator_type):
return (
f"Parameter `estimator_type` must be None, a string or a list of "
f"Parameter `estimator_type` must be a string or a list of "
f"strings. Valid string values are: "
f"{get_obj_scitype_list()}, but found: "
f"{repr(estimator_type)}"
)

BASE_CLASS_LOOKUP = get_base_class_lookup()

for i, estimator_type in enumerate(estimator_types):
if not isinstance(estimator_type, (type, str)):
raise ValueError(
Expand All @@ -386,8 +397,7 @@ def _get_err_msg(estimator_type):
if isinstance(estimator_type, str):
if estimator_type not in get_obj_scitype_list():
raise ValueError(_get_err_msg(estimator_type))
estimator_type = BASE_CLASS_LOOKUP[estimator_type]
estimator_types[i] = estimator_type
estimator_types[i] = get_base_class_for_str(estimator_type)
elif isinstance(estimator_type, type):
pass
else:
Expand Down

0 comments on commit 4ab0218

Please sign in to comment.