From 4ab021886522d5c778b692d5c9003bae8ace8ff4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Sat, 16 Nov 2024 20:25:58 +0100
Subject: [PATCH] [ENH] refactor `_check_estimator_types` to use record class interface (#7395)

This refactors `_check_estimator_types` to use the new record class interface, removing the need to construct a full lookup of base classes, reducing coupling in the process.
---
 sktime/registry/_base_classes.py | 24 ++++++++++++++++++++++++
 sktime/registry/_lookup.py       | 24 +++++++++++++++++-------
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/sktime/registry/_base_classes.py b/sktime/registry/_base_classes.py
index d9b49249bc2..c7819d8bc1b 100644
--- a/sktime/registry/_base_classes.py
+++ b/sktime/registry/_base_classes.py
@@ -399,6 +399,30 @@ def _construct_base_class_register(mixin=False):
     return register
 
 
+def get_base_class_for_str(scitype_str):
+    """Return base class for a given scitype string.
+
+    Parameters
+    ----------
+    scitype_str : str, or list of str
+        scitype shorthand, as in scitype_name field of scitype classes
+
+    Returns
+    -------
+    base_cls : class or list of class
+        base class corresponding to the scitype string,
+        or list of base classes if input was a list
+    """
+    if isinstance(scitype_str, list):
+        return [get_base_class_for_str(s) for s in scitype_str]
+
+    base_classes = _get_base_classes()
+    base_classes += _get_base_classes(mixin=True)
+    base_class_lookup = {cl.get_class_tags()["scitype_name"]: cl for cl in base_classes}
+    base_cls = base_class_lookup[scitype_str].get_base_class()
+    return base_cls
+
+
 def get_base_class_register(mixin=False, include_baseobjs=True):
     """Return register of object scitypes and base classes in sktime.
 
diff --git a/sktime/registry/_lookup.py b/sktime/registry/_lookup.py
index 93894f7e333..b72b7b1c198 100644
--- a/sktime/registry/_lookup.py
+++ b/sktime/registry/_lookup.py
@@ -21,7 +21,10 @@
 import pandas as pd
 from skbase.lookup import all_objects
 
-from sktime.registry._base_classes import get_base_class_lookup, get_obj_scitype_list
+from sktime.registry._base_classes import (
+    get_base_class_for_str,
+    get_obj_scitype_list,
+)
 from sktime.registry._tags import ESTIMATOR_TAG_REGISTER
 
 
@@ -362,7 +365,17 @@ def is_tag_for_type(tag, estimator_types):
 
 
 def _check_estimator_types(estimator_types):
-    """Return list of classes corresponding to type strings."""
+    """Return list of classes corresponding to type strings.
+
+    Parameters
+    ----------
+    estimator_types: str, or list of str
+
+    Returns
+    -------
+    estimator_types: list of classes
+        base classes corresponding to scitype strings in estimator_types
+    """
     estimator_types = deepcopy(estimator_types)
 
     if not isinstance(estimator_types, list):
@@ -370,14 +383,12 @@ def _check_estimator_types(estimator_types):
 
     def _get_err_msg(estimator_type):
         return (
-            f"Parameter `estimator_type` must be None, a string or a list of "
+            f"Parameter `estimator_type` must be a string or a list of "
             f"strings. Valid string values are: "
             f"{get_obj_scitype_list()}, but found: "
             f"{repr(estimator_type)}"
         )
 
-    BASE_CLASS_LOOKUP = get_base_class_lookup()
-
     for i, estimator_type in enumerate(estimator_types):
         if not isinstance(estimator_type, (type, str)):
             raise ValueError(
@@ -386,8 +397,7 @@ def _get_err_msg(estimator_type):
         if isinstance(estimator_type, str):
             if estimator_type not in get_obj_scitype_list():
                 raise ValueError(_get_err_msg(estimator_type))
-            estimator_type = BASE_CLASS_LOOKUP[estimator_type]
-            estimator_types[i] = estimator_type
+            estimator_types[i] = get_base_class_for_str(estimator_type)
         elif isinstance(estimator_type, type):
             pass
         else: