Skip to content

Commit

Permalink
[ENH] _HeterogenousMetaObject to accept list of tuples of any length (
Browse files Browse the repository at this point in the history
#206)

Mirror PR of sktime/sktime#4793 from `sktime`.

This improves the `_HeterogenousMetaObject` by widening its
functionality.

`_HeterogenousMetaObject` now allows tuples of any length in the
`_steps_attr`, as long as the zeroth elements are str names, and the
first elements are estimators
  • Loading branch information
fkiraly authored Aug 14, 2023
1 parent 9e570d0 commit 9210fdb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
11 changes: 7 additions & 4 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def _set_params(self, attr: str, **params):
# 2. Step replacement
items = getattr(self, attr)
names = []
if items:
names, _ = zip(*items)
if items and isinstance(items, (list, tuple)):
names = list(zip(*items))[0]
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
Expand All @@ -256,9 +256,12 @@ def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
"""Replace an object in attribute that contains named objects."""
# assumes `name` is a valid object name
new_objects = list(getattr(self, attr))
for i, (object_name, _) in enumerate(new_objects):
for i, obj_tpl in enumerate(new_objects):
object_name = obj_tpl[0]
if object_name == name:
new_objects[i] = (name, new_val)
new_tpl = list(obj_tpl)
new_tpl[1] = new_val
new_objects[i] = tuple(new_tpl)
break
setattr(self, attr, new_objects)

Expand Down
57 changes: 48 additions & 9 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests for BaseMetaObject and BaseMetaEstimator mixins.
"""Tests for BaseMetaObject and BaseMetaEstimator mixins."""

tests in this module:
"""

__author__ = ["RNKuhns"]
__author__ = ["RNKuhns", "fkiraly"]
import inspect

import pytest
Expand All @@ -23,37 +18,51 @@


class MetaObjectTester(BaseMetaObject):
"""Class to test meta object functionality."""
"""Class to test meta-object functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class MetaEstimatorTester(BaseMetaEstimator):
"""Class to test meta estimator functionality."""
"""Class to test meta-estimator functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class ComponentDummy(BaseObject):
"""Class to use as components in meta-estimator."""

def __init__(self, a=7, b="something"):
self.a = a
self.b = b
super().__init__()


@pytest.fixture
def fixture_metaestimator_instance():
"""BaseMetaEstimator instance fixture."""
return BaseMetaEstimator()


@pytest.fixture
def fixture_meta_object():
"""MetaObjectTester instance fixture."""
return MetaObjectTester()


@pytest.fixture
def fixture_meta_estimator():
"""MetaEstimatorTester instance fixture."""
return MetaEstimatorTester()


Expand Down Expand Up @@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(

fixture_metaestimator_instance._is_fitted = True
assert fixture_metaestimator_instance.check_is_fitted() is None


@pytest.mark.parametrize("long_steps", (True, False))
def test_metaestimator_composite(long_steps):
"""Test composite meta-estimator functionality."""
if long_steps:
steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
else:
steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]

meta_est = MetaEstimatorTester(steps=steps)

meta_est_params = meta_est.get_params()
assert isinstance(meta_est_params, dict)
expected_keys = [
"a",
"b",
"c",
"steps",
"foo",
"bar",
"foo__a",
"foo__b",
"bar__a",
"bar__b",
]
assert set(meta_est_params.keys()) == set(expected_keys)

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"

0 comments on commit 9210fdb

Please sign in to comment.