Skip to content

Commit 94e9529

Browse files
committed
TST: test binops vs. np.generics
1 parent 25cc3d7 commit 94e9529

File tree

2 files changed

+142
-69
lines changed

2 files changed

+142
-69
lines changed

array_api_strict/_array_object.py

+43-43
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,15 @@ def _check_allowed_dtypes(
233233

234234
return other
235235

236-
def _check_device(self, other: Array | bool | int | float | complex) -> None:
237-
"""Check that other is on a device compatible with the current array"""
238-
if isinstance(other, (bool, int, float, complex)):
239-
return
240-
elif isinstance(other, Array):
236+
def _check_type_device(self, other: Array | bool | int | float | complex) -> None:
237+
"""Check that other is either a Python scalar or an array on a device
238+
compatible with the current array.
239+
"""
240+
if isinstance(other, Array):
241241
if self.device != other.device:
242242
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
243-
else:
244-
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
243+
elif not isinstance(other, bool | int | float | complex):
244+
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")
245245

246246
# Helper function to match the type promotion rules in the spec
247247
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
@@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
542542
"""
543543
Performs the operation __add__.
544544
"""
545-
self._check_device(other)
545+
self._check_type_device(other)
546546
other = self._check_allowed_dtypes(other, "numeric", "__add__")
547547
if other is NotImplemented:
548548
return other
@@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
554554
"""
555555
Performs the operation __and__.
556556
"""
557-
self._check_device(other)
557+
self._check_type_device(other)
558558
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
559559
if other is NotImplemented:
560560
return other
@@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
651651
"""
652652
Performs the operation __eq__.
653653
"""
654-
self._check_device(other)
654+
self._check_type_device(other)
655655
# Even though "all" dtypes are allowed, we still require them to be
656656
# promotable with each other.
657657
other = self._check_allowed_dtypes(other, "all", "__eq__")
@@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
677677
"""
678678
Performs the operation __floordiv__.
679679
"""
680-
self._check_device(other)
680+
self._check_type_device(other)
681681
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
682682
if other is NotImplemented:
683683
return other
@@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
689689
"""
690690
Performs the operation __ge__.
691691
"""
692-
self._check_device(other)
692+
self._check_type_device(other)
693693
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
694694
if other is NotImplemented:
695695
return other
@@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
741741
"""
742742
Performs the operation __gt__.
743743
"""
744-
self._check_device(other)
744+
self._check_type_device(other)
745745
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
746746
if other is NotImplemented:
747747
return other
@@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
796796
"""
797797
Performs the operation __le__.
798798
"""
799-
self._check_device(other)
799+
self._check_type_device(other)
800800
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
801801
if other is NotImplemented:
802802
return other
@@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808808
"""
809809
Performs the operation __lshift__.
810810
"""
811-
self._check_device(other)
811+
self._check_type_device(other)
812812
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
813813
if other is NotImplemented:
814814
return other
@@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
820820
"""
821821
Performs the operation __lt__.
822822
"""
823-
self._check_device(other)
823+
self._check_type_device(other)
824824
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
825825
if other is NotImplemented:
826826
return other
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832832
"""
833833
Performs the operation __matmul__.
834834
"""
835-
self._check_device(other)
835+
self._check_type_device(other)
836836
# matmul is not defined for scalars, but without this, we may get
837837
# the wrong error message from asarray.
838838
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
@@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
845845
"""
846846
Performs the operation __mod__.
847847
"""
848-
self._check_device(other)
848+
self._check_type_device(other)
849849
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
850850
if other is NotImplemented:
851851
return other
@@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
857857
"""
858858
Performs the operation __mul__.
859859
"""
860-
self._check_device(other)
860+
self._check_type_device(other)
861861
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
862862
if other is NotImplemented:
863863
return other
@@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
869869
"""
870870
Performs the operation __ne__.
871871
"""
872-
self._check_device(other)
872+
self._check_type_device(other)
873873
other = self._check_allowed_dtypes(other, "all", "__ne__")
874874
if other is NotImplemented:
875875
return other
@@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
890890
"""
891891
Performs the operation __or__.
892892
"""
893-
self._check_device(other)
893+
self._check_type_device(other)
894894
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
895895
if other is NotImplemented:
896896
return other
@@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
913913
"""
914914
from ._elementwise_functions import pow # type: ignore[attr-defined]
915915

916-
self._check_device(other)
916+
self._check_type_device(other)
917917
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
918918
if other is NotImplemented:
919919
return other
@@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925925
"""
926926
Performs the operation __rshift__.
927927
"""
928-
self._check_device(other)
928+
self._check_type_device(other)
929929
other = self._check_allowed_dtypes(other, "integer", "__rshift__")
930930
if other is NotImplemented:
931931
return other
@@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
961961
"""
962962
Performs the operation __sub__.
963963
"""
964-
self._check_device(other)
964+
self._check_type_device(other)
965965
other = self._check_allowed_dtypes(other, "numeric", "__sub__")
966966
if other is NotImplemented:
967967
return other
@@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975975
"""
976976
Performs the operation __truediv__.
977977
"""
978-
self._check_device(other)
978+
self._check_type_device(other)
979979
other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
980980
if other is NotImplemented:
981981
return other
@@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
987987
"""
988988
Performs the operation __xor__.
989989
"""
990-
self._check_device(other)
990+
self._check_type_device(other)
991991
other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
992992
if other is NotImplemented:
993993
return other
@@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
999999
"""
10001000
Performs the operation __iadd__.
10011001
"""
1002-
self._check_device(other)
1002+
self._check_type_device(other)
10031003
other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
10041004
if other is NotImplemented:
10051005
return other
@@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
10101010
"""
10111011
Performs the operation __radd__.
10121012
"""
1013-
self._check_device(other)
1013+
self._check_type_device(other)
10141014
other = self._check_allowed_dtypes(other, "numeric", "__radd__")
10151015
if other is NotImplemented:
10161016
return other
@@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
10221022
"""
10231023
Performs the operation __iand__.
10241024
"""
1025-
self._check_device(other)
1025+
self._check_type_device(other)
10261026
other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
10271027
if other is NotImplemented:
10281028
return other
@@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
10331033
"""
10341034
Performs the operation __rand__.
10351035
"""
1036-
self._check_device(other)
1036+
self._check_type_device(other)
10371037
other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
10381038
if other is NotImplemented:
10391039
return other
@@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10451045
"""
10461046
Performs the operation __ifloordiv__.
10471047
"""
1048-
self._check_device(other)
1048+
self._check_type_device(other)
10491049
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
10501050
if other is NotImplemented:
10511051
return other
@@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
10561056
"""
10571057
Performs the operation __rfloordiv__.
10581058
"""
1059-
self._check_device(other)
1059+
self._check_type_device(other)
10601060
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
10611061
if other is NotImplemented:
10621062
return other
@@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
10681068
"""
10691069
Performs the operation __ilshift__.
10701070
"""
1071-
self._check_device(other)
1071+
self._check_type_device(other)
10721072
other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
10731073
if other is NotImplemented:
10741074
return other
@@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
10791079
"""
10801080
Performs the operation __rlshift__.
10811081
"""
1082-
self._check_device(other)
1082+
self._check_type_device(other)
10831083
other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
10841084
if other is NotImplemented:
10851085
return other
@@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array:
10961096
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
10971097
if other is NotImplemented:
10981098
return other
1099-
self._check_device(other)
1099+
self._check_type_device(other)
11001100
res = self._array.__imatmul__(other._array)
11011101
return self.__class__._new(res, device=self.device)
11021102

@@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11091109
other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
11101110
if other is NotImplemented:
11111111
return other
1112-
self._check_device(other)
1112+
self._check_type_device(other)
11131113
res = self._array.__rmatmul__(other._array)
11141114
return self.__class__._new(res, device=self.device)
11151115

@@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
11301130
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
11311131
if other is NotImplemented:
11321132
return other
1133-
self._check_device(other)
1133+
self._check_type_device(other)
11341134
self, other = self._normalize_two_args(self, other)
11351135
res = self._array.__rmod__(other._array)
11361136
return self.__class__._new(res, device=self.device)
@@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11521152
other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
11531153
if other is NotImplemented:
11541154
return other
1155-
self._check_device(other)
1155+
self._check_type_device(other)
11561156
self, other = self._normalize_two_args(self, other)
11571157
res = self._array.__rmul__(other._array)
11581158
return self.__class__._new(res, device=self.device)
@@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
11711171
"""
11721172
Performs the operation __ror__.
11731173
"""
1174-
self._check_device(other)
1174+
self._check_type_device(other)
11751175
other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
11761176
if other is NotImplemented:
11771177
return other
@@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12191219
other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
12201220
if other is NotImplemented:
12211221
return other
1222-
self._check_device(other)
1222+
self._check_type_device(other)
12231223
self, other = self._normalize_two_args(self, other)
12241224
res = self._array.__rrshift__(other._array)
12251225
return self.__class__._new(res, device=self.device)
@@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12411241
other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
12421242
if other is NotImplemented:
12431243
return other
1244-
self._check_device(other)
1244+
self._check_type_device(other)
12451245
self, other = self._normalize_two_args(self, other)
12461246
res = self._array.__rsub__(other._array)
12471247
return self.__class__._new(res, device=self.device)
@@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12631263
other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
12641264
if other is NotImplemented:
12651265
return other
1266-
self._check_device(other)
1266+
self._check_type_device(other)
12671267
self, other = self._normalize_two_args(self, other)
12681268
res = self._array.__rtruediv__(other._array)
12691269
return self.__class__._new(res, device=self.device)
@@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
12851285
other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
12861286
if other is NotImplemented:
12871287
return other
1288-
self._check_device(other)
1288+
self._check_type_device(other)
12891289
self, other = self._normalize_two_args(self, other)
12901290
res = self._array.__rxor__(other._array)
12911291
return self.__class__._new(res, device=self.device)

0 commit comments

Comments
 (0)