Skip to content

Commit

Permalink
Merge pull request #111 from sktime/deep_equals_nan
Browse files Browse the repository at this point in the history
[BUG] fix faulty `BaseObject.__eq__` and `deep_equals` if an attribute or nested structure contains `np.nan`
  • Loading branch information
RNKuhns authored Feb 1, 2023
2 parents 86e0c4e + f38782e commit acaefb3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
22 changes: 18 additions & 4 deletions skbase/testing/utils/deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
__all__: List[str] = ["deep_equals"]


# flag variables for available soft dependencies
pandas_available = _check_soft_dependencies("pandas", severity="none")
numpy_available = _check_soft_dependencies("numpy", severity="none")


def deep_equals(x, y, return_msg=False):
"""Test two objects for equality in value.
Expand Down Expand Up @@ -70,10 +75,6 @@ def ret(is_equal, msg):
# we now know all types are the same
# so now we compare values

# flag variables for available soft dependencies
pandas_available = _check_soft_dependencies("pandas", severity="none")
numpy_available = _check_soft_dependencies("numpy", severity="none")

if numpy_available:
import numpy as np

Expand All @@ -93,6 +94,8 @@ def ret(is_equal, msg):
return ret(*_tuple_equals(x, y, return_msg=True))
elif isinstance(x, dict):
return ret(*_dict_equals(x, y, return_msg=True))
elif _is_npnan(x):
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
elif isclass(x):
return ret(x == y, f".class, x={x.__name__} != y={y.__name__}")
elif type(x).__name__ == "ForecastingHorizon":
Expand Down Expand Up @@ -125,6 +128,17 @@ def _is_npndarray(x):
return clstr == "ndarray"


def _is_npnan(x):

if numpy_available:
import numpy as np

return isinstance(x, float) and np.isnan(x)

else:
return False


def _coerce_list(x):
"""Coerce x to list."""
if not isinstance(x, (list, tuple)):
Expand Down
3 changes: 3 additions & 0 deletions skbase/testing/utils/tests/test_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
[([([([()])])])],
np.array([2, 3, 4]),
np.array([2, 4, 5]),
3.5,
4.2,
np.nan,
]


Expand Down
2 changes: 2 additions & 0 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
"_pandas_equals",
"_dict_equals",
"_is_pandas",
"_is_npnan",
"_tuple_equals",
"_fh_equals",
"deep_equals",
Expand Down Expand Up @@ -152,6 +153,7 @@ class Parent(BaseObject):
_tags = {"A": "1", "B": 2, "C": 1234, "3": "D"}

def __init__(self, a="something", b=7, c=None):
"""Initialize the class."""
self.a = a
self.b = b
self.c = c
Expand Down

0 comments on commit acaefb3

Please sign in to comment.