diff --git a/openeo_processes_dask/process_implementations/comparison.py b/openeo_processes_dask/process_implementations/comparison.py index a99a6ba..f5e688b 100644 --- a/openeo_processes_dask/process_implementations/comparison.py +++ b/openeo_processes_dask/process_implementations/comparison.py @@ -13,6 +13,8 @@ __all__ = [ "is_infinite", "is_valid", + "is_nan", + "is_nodata", "eq", "neq", "gt", @@ -35,6 +37,16 @@ def is_valid(x: ArrayLike): return np.logical_and(notnull(x), finite) +def is_nodata(x: ArrayLike): + return x is None + + +def is_nan(x: ArrayLike): + if is_nodata(x): + return is_nodata(x) + return np.isnan(x) + + def eq( x: ArrayLike, y: ArrayLike, diff --git a/tests/test_comparison.py b/tests/test_comparison.py index d115884..d1ae993 100644 --- a/tests/test_comparison.py +++ b/tests/test_comparison.py @@ -9,13 +9,7 @@ from openeo_pg_parser_networkx.pg_schema import ParameterReference from openeo_processes_dask.process_implementations import merge_cubes -from openeo_processes_dask.process_implementations.comparison import ( - between, - eq, - is_infinite, - is_valid, - neq, -) +from openeo_processes_dask.process_implementations.comparison import * from openeo_processes_dask.process_implementations.cubes.apply import apply from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension from tests.general_checks import assert_numpy_equals_dask_numpy, general_output_checks @@ -73,6 +67,33 @@ def test_is_inf(value, expected, is_dask): assert hasattr(output, "dask") +@pytest.mark.parametrize( + "value,expected", + [ + (1, False), + (np.nan, True), + ], +) +def test_is_nan(value, expected): + value = np.asarray(value) + + is_dask = da.from_array(value) + + output = is_nan(value) + np.testing.assert_array_equal(output, expected) + + assert hasattr(is_nan(is_dask), "dask") + + +@pytest.mark.parametrize( + "value,expected", + [(1, False), ("Test", False), (None, True), ([np.nan, np.nan], False)], +) +def test_is_nodata(value, expected): + output = is_nodata(value) + np.testing.assert_array_equal(output, expected) + + @pytest.mark.parametrize( "x, y, delta, case_sensitive", [