From 9210fdb9dd9e4f441dda42aa2f61db6d5bd8da28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 14 Aug 2023 21:12:54 +0100 Subject: [PATCH] [ENH] `_HeterogenousMetaObject` to accept list of tuples of any length (#206) Mirror PR of https://github.com/sktime/sktime/pull/4793 from `sktime`. This improves the `_HeterogenousMetaObject` by widening its functionality. `_HeterogenousMetaObject` now allows tuples of any length in the `_steps_attr`, as long as the zeroth elements are str names, and the first elements are estimators --- skbase/base/_meta.py | 11 +++++--- skbase/tests/test_meta.py | 57 ++++++++++++++++++++++++++++++++------- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/skbase/base/_meta.py b/skbase/base/_meta.py index 57681a1d..a5042484 100644 --- a/skbase/base/_meta.py +++ b/skbase/base/_meta.py @@ -243,8 +243,8 @@ def _set_params(self, attr: str, **params): # 2. Step replacement items = getattr(self, attr) names = [] - if items: - names, _ = zip(*items) + if items and isinstance(items, (list, tuple)): + names = list(zip(*items))[0] for name in list(params.keys()): if "__" not in name and name in names: self._replace_object(attr, name, params.pop(name)) @@ -256,9 +256,12 @@ def _replace_object(self, attr: str, name: str, new_val: Any) -> None: """Replace an object in attribute that contains named objects.""" # assumes `name` is a valid object name new_objects = list(getattr(self, attr)) - for i, (object_name, _) in enumerate(new_objects): + for i, obj_tpl in enumerate(new_objects): + object_name = obj_tpl[0] if object_name == name: - new_objects[i] = (name, new_val) + new_tpl = list(obj_tpl) + new_tpl[1] = new_val + new_objects[i] = tuple(new_tpl) break setattr(self, attr, new_objects) diff --git a/skbase/tests/test_meta.py b/skbase/tests/test_meta.py index 75414992..f2c6c8fa 100644 --- a/skbase/tests/test_meta.py +++ b/skbase/tests/test_meta.py @@ -1,13 +1,8 @@ # -*- coding: utf-8 -*- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) -"""Tests for BaseMetaObject and BaseMetaEstimator mixins. +"""Tests for BaseMetaObject and BaseMetaEstimator mixins.""" -tests in this module: - - -""" - -__author__ = ["RNKuhns"] +__author__ = ["RNKuhns", "fkiraly"] import inspect import pytest @@ -23,37 +18,51 @@ class MetaObjectTester(BaseMetaObject): - """Class to test meta object functionality.""" + """Class to test meta-object functionality.""" def __init__(self, a=7, b="something", c=None, steps=None): self.a = a self.b = b self.c = c self.steps = steps + super().__init__() class MetaEstimatorTester(BaseMetaEstimator): - """Class to test meta estimator functionality.""" + """Class to test meta-estimator functionality.""" def __init__(self, a=7, b="something", c=None, steps=None): self.a = a self.b = b self.c = c self.steps = steps + super().__init__() + + +class ComponentDummy(BaseObject): + """Class to use as components in meta-estimator.""" + + def __init__(self, a=7, b="something"): + self.a = a + self.b = b + super().__init__() @pytest.fixture def fixture_metaestimator_instance(): + """BaseMetaEstimator instance fixture.""" return BaseMetaEstimator() @pytest.fixture def fixture_meta_object(): + """MetaObjectTester instance fixture.""" return MetaObjectTester() @pytest.fixture def fixture_meta_estimator(): + """MetaEstimatorTester instance fixture.""" return MetaEstimatorTester() @@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted( fixture_metaestimator_instance._is_fitted = True assert fixture_metaestimator_instance.check_is_fitted() is None + + +@pytest.mark.parametrize("long_steps", (True, False)) +def test_metaestimator_composite(long_steps): + """Test composite meta-estimator functionality.""" + if long_steps: + steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))] + else: + steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)] + + meta_est = MetaEstimatorTester(steps=steps) + + meta_est_params = meta_est.get_params() + assert isinstance(meta_est_params, dict) + expected_keys = [ + "a", + "b", + "c", + "steps", + "foo", + "bar", + "foo__a", + "foo__b", + "bar__a", + "bar__b", + ] + assert set(meta_est_params.keys()) == set(expected_keys) + + meta_est.set_params(bar__b="something else") + assert meta_est.get_params()["bar__b"] == "something else"