Skip to content

Commit

Permalink
Update test_all_objects.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Aug 3, 2023
1 parent e83bcf0 commit f8d8d0a
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions skbase/testing/test_all_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,21 +659,17 @@ def test_no_between_test_case_side_effects(self, object_instance, a):
assert not hasattr(object_instance, "test__attr")
object_instance.test__attr = 42

@pytest.mark.skipif(
not _check_soft_dependencies("sklearn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_get_params(self, object_instance):
"""Check that get_params works correctly, against sklearn interface."""
from sklearn.utils.estimator_checks import (
check_get_params_invariance as _check_get_params_invariance,
)

params = object_instance.get_params()
assert isinstance(params, dict)
_check_get_params_invariance(
object_instance.__class__.__name__, object_instance
)

e = object_instance.clone()

shallow_params = e.get_params(deep=False)
deep_params = e.get_params(deep=True)

assert all(item in deep_params.items() for item in shallow_params.items())

def test_set_params(self, object_instance):
"""Check that set_params works correctly."""
Expand Down

0 comments on commit f8d8d0a

Please sign in to comment.