From b0e4df9114ecbdd5c3975401a630727d66aec4b7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 14:09:33 +0100 Subject: [PATCH 1/2] TST: test binops vs. np.generics --- array_api_strict/_array_object.py | 86 +++++++------- array_api_strict/tests/test_array_object.py | 125 ++++++++++++++++---- 2 files changed, 142 insertions(+), 69 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 483952e..6f2c506 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -233,15 +233,15 @@ def _check_allowed_dtypes( return other - def _check_device(self, other: Array | bool | int | float | complex) -> None: - """Check that other is on a device compatible with the current array""" - if isinstance(other, (bool, int, float, complex)): - return - elif isinstance(other, Array): + def _check_type_device(self, other: Array | bool | int | float | complex) -> None: + """Check that other is either a Python scalar or an array on a device + compatible with the current array. + """ + if isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - else: - raise TypeError(f"Expected Array | python scalar; got {type(other)}") + elif not isinstance(other, bool | int | float | complex): + raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: @@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __add__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other @@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __and__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other @@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __eq__. """ - self._check_device(other) + self._check_type_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __floordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other @@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ge__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other @@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __gt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other @@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array: """ Performs the operation __le__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other @@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array: """ Performs the operation __lshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other @@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __lt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other @@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array: """ Performs the operation __matmul__. """ - self._check_device(other) + self._check_type_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") @@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __mod__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other @@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __mul__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other @@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __ne__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other @@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __or__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other @@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array: """ from ._elementwise_functions import pow # type: ignore[attr-defined] - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other @@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other @@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __sub__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other @@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __truediv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other @@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __xor__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other @@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __iadd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __radd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other @@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __iand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other @@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ifloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rfloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other @@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array: """ Performs the operation __ilshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rlshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other @@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__imatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array: other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) @@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ror__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other @@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) @@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) @@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) @@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dbab1af..91f3838 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -255,30 +255,37 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): func(s) return False +binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer or boolean", + "__eq__": "all", + "__floordiv__": "real numeric", + "__ge__": "real numeric", + "__gt__": "real numeric", + "__le__": "real numeric", + "__lshift__": "integer", + "__lt__": "real numeric", + "__mod__": "real numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer or boolean", + "__pow__": "numeric", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", +} +unary_op_dtypes = { + "__abs__": "numeric", + "__invert__": "integer or boolean", + "__neg__": "numeric", + "__pos__": "numeric", +} def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise - binary_op_dtypes = { - "__add__": "numeric", - "__and__": "integer or boolean", - "__eq__": "all", - "__floordiv__": "real numeric", - "__ge__": "real numeric", - "__gt__": "real numeric", - "__le__": "real numeric", - "__lshift__": "integer", - "__lt__": "real numeric", - "__mod__": "real numeric", - "__mul__": "numeric", - "__ne__": "all", - "__or__": "integer or boolean", - "__pow__": "numeric", - "__rshift__": "integer", - "__sub__": "numeric", - "__truediv__": "floating-point", - "__xor__": "integer or boolean", - } + # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -337,12 +344,6 @@ def _array_vals(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) - unary_op_dtypes = { - "__abs__": "numeric", - "__invert__": "integer or boolean", - "__neg__": "numeric", - "__pos__": "numeric", - } for op, dtypes in unary_op_dtypes.items(): for a in _array_vals(): if ( @@ -410,6 +411,78 @@ def _matmul_array_vals(): x.__imatmul__(y) +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("real numeric", "floating-point") + ], +) +def test_binary_operators_vs_numpy_int(op): + """np.int64 is not a subclass of int and must be disallowed""" + a = asarray(1) + i64 = np.int64(1) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + getattr(a, op)(i64) + + +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("integer", "integer or boolean") + ], +) +def test_binary_operators_vs_numpy_float(op): + """ + np.float64 is a subclass of float and must be allowed. + np.float32 is not and must be rejected. + """ + a = asarray(1.) + f64 = np.float64(1.) + f32 = np.float32(1.) + func = getattr(a, op) + for op in binary_op_dtypes: + assert isinstance(func(f64), Array) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + func(f32) + + +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("integer", "integer or boolean", "real numeric") + ], +) +def test_binary_operators_vs_numpy_complex(op): + """ + np.complex128 is a subclass of complex and must be allowed. + np.complex64 is not and must be rejected. + """ + a = asarray(1.) + c64 = np.complex64(1.) + c128 = np.complex128(1.) + func = getattr(a, op) + for op in binary_op_dtypes: + assert isinstance(func(c128), Array) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + func(c64) + + +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_device_mismatch(op, dtypes): + if dtypes in ("real numeric", "floating-point"): + dtype = float64 + else: + dtype = int64 + + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + with pytest.raises(ValueError, match="different devices"): + getattr(a, op)(b) + + def test_python_scalar_construtors(): b = asarray(False) i = asarray(0) From e7fcd348a454e95b354de3d36573668296fac3d2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 10:31:01 +0100 Subject: [PATCH 2/2] Disallow float64 and complex128 --- array_api_strict/_array_object.py | 3 +- array_api_strict/tests/test_array_object.py | 97 ++++++++------------- 2 files changed, 38 insertions(+), 62 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 6f2c506..579da90 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -240,7 +240,8 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non if isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - elif not isinstance(other, bool | int | float | complex): + # Disallow subclasses of Python scalars, such as np.float64 and np.complex128 + elif type(other) not in (bool, int, float, complex): raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 91f3838..e950be5 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -411,72 +411,47 @@ def _matmul_array_vals(): x.__imatmul__(y) -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("real numeric", "floating-point") - ], -) -def test_binary_operators_vs_numpy_int(op): - """np.int64 is not a subclass of int and must be disallowed""" - a = asarray(1) - i64 = np.int64(1) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - getattr(a, op)(i64) - - -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("integer", "integer or boolean") - ], -) -def test_binary_operators_vs_numpy_float(op): - """ - np.float64 is a subclass of float and must be allowed. - np.float32 is not and must be rejected. - """ - a = asarray(1.) - f64 = np.float64(1.) - f32 = np.float32(1.) - func = getattr(a, op) - for op in binary_op_dtypes: - assert isinstance(func(f64), Array) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - func(f32) - - -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("integer", "integer or boolean", "real numeric") - ], -) -def test_binary_operators_vs_numpy_complex(op): - """ - np.complex128 is a subclass of complex and must be allowed. - np.complex64 is not and must be rejected. +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_vs_numpy_generics(op, dtypes): + """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 + are disallowed in binary operators. + np.float64 and np.complex128 are subclasses of float and complex, so they need + special treatment in order to be rejected. """ - a = asarray(1.) - c64 = np.complex64(1.) - c128 = np.complex128(1.) - func = getattr(a, op) - for op in binary_op_dtypes: - assert isinstance(func(c128), Array) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - func(c64) + match = "Expected Array or Python scalar" + + if dtypes not in ("numeric", "integer", "real numeric", "floating-point"): + a = asarray(True) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.bool_(True)) + + if dtypes != "floating-point": + a = asarray(1) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.int64(1)) + + if dtypes not in ("integer", "integer or boolean"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.float32(1.)) + with pytest.raises(TypeError, match=match): + func(np.float64(1.)) + + if dtypes not in ("integer", "integer or boolean", "real numeric"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.complex64(1.)) + with pytest.raises(TypeError, match=match): + func(np.complex128(1.)) @pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) def test_binary_operators_device_mismatch(op, dtypes): - if dtypes in ("real numeric", "floating-point"): - dtype = float64 - else: - dtype = int64 - + dtype = float64 if dtypes == "floating-point" else int64 a = asarray(1, dtype=dtype, device=CPU_DEVICE) b = asarray(1, dtype=dtype, device=Device("device1")) with pytest.raises(ValueError, match="different devices"):