Skip to content

Commit

Permalink
[BUG] fix for get_fitted_params in _HeterogenousMetaEstimator (#191)
Browse files Browse the repository at this point in the history
Upstream version of `sktime` bugfix
sktime/sktime#4633

This fixes a bug in `get_fitted_params` of `_HeterogenousMetaEstimator`, which is at the root of sktime/sktime#4574.

`get_fitted_params` of `_HeterogenousMetaEstimator` was accidentally calling the private interface point of components rather than the public interface point that it should have.

For reviewers: the first call should be to `_get_fitted_params` of the base class, which is correct. Subsequent calls to components should be to the public interface, `get_fitted_params`.
  • Loading branch information
fkiraly authored Aug 14, 2023
1 parent 9fbe066 commit 9e570d0
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,16 @@ def _get_params(
"""
# Set variables that let us use same code for retrieving params or fitted params
if fitted:
method = "_get_fitted_params"
method_shallow = "_get_fitted_params"
method_public = "get_fitted_params"
deepkw = {}
else:
method = "get_params"
method_shallow = "get_params"
method_public = "get_params"
deepkw = {"deep": deep}

# Get the direct params/fitted params
out = getattr(super(), method)(**deepkw)
out = getattr(super(), method_shallow)(**deepkw)

if deep and hasattr(self, attr):
named_objects = getattr(self, attr)
Expand All @@ -207,8 +209,15 @@ def _get_params(
]
out.update(named_objects_)
for name, obj in named_objects_:
if hasattr(obj, method):
for key, value in getattr(obj, method)(**deepkw).items():
# checks estimator has the method we want to call
cond1 = hasattr(obj, method_public)
# checks estimator is fitted if calling get_fitted_params
is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted
# if we call get_params and not get_fitted_params, this is True
cond2 = not fitted or is_fitted
# check both conditions together
if cond1 and cond2:
for key, value in getattr(obj, method_public)(**deepkw).items():
out["%s__%s" % (name, key)] = value
return out

Expand Down

0 comments on commit 9e570d0

Please sign in to comment.