diff --git a/skbase/testing/utils/deep_equals.py b/skbase/testing/utils/deep_equals.py index 22370fde..c225fb38 100644 --- a/skbase/testing/utils/deep_equals.py +++ b/skbase/testing/utils/deep_equals.py @@ -15,6 +15,11 @@ __all__: List[str] = ["deep_equals"] +# flag variables for available soft dependencies +pandas_available = _check_soft_dependencies("pandas", severity="none") +numpy_available = _check_soft_dependencies("numpy", severity="none") + + def deep_equals(x, y, return_msg=False): """Test two objects for equality in value. @@ -70,10 +75,6 @@ def ret(is_equal, msg): # we now know all types are the same # so now we compare values - # flag variables for available soft dependencies - pandas_available = _check_soft_dependencies("pandas", severity="none") - numpy_available = _check_soft_dependencies("numpy", severity="none") - if numpy_available: import numpy as np @@ -93,6 +94,8 @@ def ret(is_equal, msg): return ret(*_tuple_equals(x, y, return_msg=True)) elif isinstance(x, dict): return ret(*_dict_equals(x, y, return_msg=True)) + elif _is_npnan(x): + return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}") elif isclass(x): return ret(x == y, f".class, x={x.__name__} != y={y.__name__}") elif type(x).__name__ == "ForecastingHorizon": @@ -125,6 +128,17 @@ def _is_npndarray(x): return clstr == "ndarray" +def _is_npnan(x): + + if numpy_available: + import numpy as np + + return isinstance(x, float) and np.isnan(x) + + else: + return False + + def _coerce_list(x): """Coerce x to list.""" if not isinstance(x, (list, tuple)): diff --git a/skbase/testing/utils/tests/test_deep_equals.py b/skbase/testing/utils/tests/test_deep_equals.py index 684c5b6f..a832e05a 100644 --- a/skbase/testing/utils/tests/test_deep_equals.py +++ b/skbase/testing/utils/tests/test_deep_equals.py @@ -16,6 +16,9 @@ [([([([()])])])], np.array([2, 3, 4]), np.array([2, 4, 5]), + 3.5, + 4.2, + np.nan, ] diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 7bb22766..b84ce692 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -122,6 +122,7 @@ "_pandas_equals", "_dict_equals", "_is_pandas", + "_is_npnan", "_tuple_equals", "_fh_equals", "deep_equals", @@ -152,6 +153,7 @@ class Parent(BaseObject): _tags = {"A": "1", "B": 2, "C": 1234, "3": "D"} def __init__(self, a="something", b=7, c=None): + """Initialize the class.""" self.a = a self.b = b self.c = c