@@ -233,15 +233,15 @@ def _check_allowed_dtypes(
233
233
234
234
return other
235
235
236
- def _check_device (self , other : Array | bool | int | float | complex ) -> None :
237
- """Check that other is on a device compatible with the current array"""
238
- if isinstance ( other , ( bool , int , float , complex )):
239
- return
240
- elif isinstance (other , Array ):
236
+ def _check_type_device (self , other : Array | bool | int | float | complex ) -> None :
237
+ """Check that other is either a Python scalar or an array on a device
238
+ compatible with the current array.
239
+ """
240
+ if isinstance (other , Array ):
241
241
if self .device != other .device :
242
242
raise ValueError (f"Arrays from two different devices ({ self .device } and { other .device } ) can not be combined." )
243
- else :
244
- raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
243
+ elif not isinstance ( other , bool | int | float | complex ) :
244
+ raise TypeError (f"Expected Array or Python scalar; got { type (other )} " )
245
245
246
246
# Helper function to match the type promotion rules in the spec
247
247
def _promote_scalar (self , scalar : bool | int | float | complex ) -> Array :
@@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
542
542
"""
543
543
Performs the operation __add__.
544
544
"""
545
- self ._check_device (other )
545
+ self ._check_type_device (other )
546
546
other = self ._check_allowed_dtypes (other , "numeric" , "__add__" )
547
547
if other is NotImplemented :
548
548
return other
@@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
554
554
"""
555
555
Performs the operation __and__.
556
556
"""
557
- self ._check_device (other )
557
+ self ._check_type_device (other )
558
558
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__and__" )
559
559
if other is NotImplemented :
560
560
return other
@@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
651
651
"""
652
652
Performs the operation __eq__.
653
653
"""
654
- self ._check_device (other )
654
+ self ._check_type_device (other )
655
655
# Even though "all" dtypes are allowed, we still require them to be
656
656
# promotable with each other.
657
657
other = self ._check_allowed_dtypes (other , "all" , "__eq__" )
@@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
677
677
"""
678
678
Performs the operation __floordiv__.
679
679
"""
680
- self ._check_device (other )
680
+ self ._check_type_device (other )
681
681
other = self ._check_allowed_dtypes (other , "real numeric" , "__floordiv__" )
682
682
if other is NotImplemented :
683
683
return other
@@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
689
689
"""
690
690
Performs the operation __ge__.
691
691
"""
692
- self ._check_device (other )
692
+ self ._check_type_device (other )
693
693
other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
694
694
if other is NotImplemented :
695
695
return other
@@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
741
741
"""
742
742
Performs the operation __gt__.
743
743
"""
744
- self ._check_device (other )
744
+ self ._check_type_device (other )
745
745
other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
746
746
if other is NotImplemented :
747
747
return other
@@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
796
796
"""
797
797
Performs the operation __le__.
798
798
"""
799
- self ._check_device (other )
799
+ self ._check_type_device (other )
800
800
other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
801
801
if other is NotImplemented :
802
802
return other
@@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808
808
"""
809
809
Performs the operation __lshift__.
810
810
"""
811
- self ._check_device (other )
811
+ self ._check_type_device (other )
812
812
other = self ._check_allowed_dtypes (other , "integer" , "__lshift__" )
813
813
if other is NotImplemented :
814
814
return other
@@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
820
820
"""
821
821
Performs the operation __lt__.
822
822
"""
823
- self ._check_device (other )
823
+ self ._check_type_device (other )
824
824
other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
825
825
if other is NotImplemented :
826
826
return other
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832
832
"""
833
833
Performs the operation __matmul__.
834
834
"""
835
- self ._check_device (other )
835
+ self ._check_type_device (other )
836
836
# matmul is not defined for scalars, but without this, we may get
837
837
# the wrong error message from asarray.
838
838
other = self ._check_allowed_dtypes (other , "numeric" , "__matmul__" )
@@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
845
845
"""
846
846
Performs the operation __mod__.
847
847
"""
848
- self ._check_device (other )
848
+ self ._check_type_device (other )
849
849
other = self ._check_allowed_dtypes (other , "real numeric" , "__mod__" )
850
850
if other is NotImplemented :
851
851
return other
@@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
857
857
"""
858
858
Performs the operation __mul__.
859
859
"""
860
- self ._check_device (other )
860
+ self ._check_type_device (other )
861
861
other = self ._check_allowed_dtypes (other , "numeric" , "__mul__" )
862
862
if other is NotImplemented :
863
863
return other
@@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
869
869
"""
870
870
Performs the operation __ne__.
871
871
"""
872
- self ._check_device (other )
872
+ self ._check_type_device (other )
873
873
other = self ._check_allowed_dtypes (other , "all" , "__ne__" )
874
874
if other is NotImplemented :
875
875
return other
@@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
890
890
"""
891
891
Performs the operation __or__.
892
892
"""
893
- self ._check_device (other )
893
+ self ._check_type_device (other )
894
894
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__or__" )
895
895
if other is NotImplemented :
896
896
return other
@@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
913
913
"""
914
914
from ._elementwise_functions import pow # type: ignore[attr-defined]
915
915
916
- self ._check_device (other )
916
+ self ._check_type_device (other )
917
917
other = self ._check_allowed_dtypes (other , "numeric" , "__pow__" )
918
918
if other is NotImplemented :
919
919
return other
@@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925
925
"""
926
926
Performs the operation __rshift__.
927
927
"""
928
- self ._check_device (other )
928
+ self ._check_type_device (other )
929
929
other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
930
930
if other is NotImplemented :
931
931
return other
@@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
961
961
"""
962
962
Performs the operation __sub__.
963
963
"""
964
- self ._check_device (other )
964
+ self ._check_type_device (other )
965
965
other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
966
966
if other is NotImplemented :
967
967
return other
@@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975
975
"""
976
976
Performs the operation __truediv__.
977
977
"""
978
- self ._check_device (other )
978
+ self ._check_type_device (other )
979
979
other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
980
980
if other is NotImplemented :
981
981
return other
@@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
987
987
"""
988
988
Performs the operation __xor__.
989
989
"""
990
- self ._check_device (other )
990
+ self ._check_type_device (other )
991
991
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
992
992
if other is NotImplemented :
993
993
return other
@@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
999
999
"""
1000
1000
Performs the operation __iadd__.
1001
1001
"""
1002
- self ._check_device (other )
1002
+ self ._check_type_device (other )
1003
1003
other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
1004
1004
if other is NotImplemented :
1005
1005
return other
@@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
1010
1010
"""
1011
1011
Performs the operation __radd__.
1012
1012
"""
1013
- self ._check_device (other )
1013
+ self ._check_type_device (other )
1014
1014
other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
1015
1015
if other is NotImplemented :
1016
1016
return other
@@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
1022
1022
"""
1023
1023
Performs the operation __iand__.
1024
1024
"""
1025
- self ._check_device (other )
1025
+ self ._check_type_device (other )
1026
1026
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
1027
1027
if other is NotImplemented :
1028
1028
return other
@@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
1033
1033
"""
1034
1034
Performs the operation __rand__.
1035
1035
"""
1036
- self ._check_device (other )
1036
+ self ._check_type_device (other )
1037
1037
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
1038
1038
if other is NotImplemented :
1039
1039
return other
@@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
1045
1045
"""
1046
1046
Performs the operation __ifloordiv__.
1047
1047
"""
1048
- self ._check_device (other )
1048
+ self ._check_type_device (other )
1049
1049
other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
1050
1050
if other is NotImplemented :
1051
1051
return other
@@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
1056
1056
"""
1057
1057
Performs the operation __rfloordiv__.
1058
1058
"""
1059
- self ._check_device (other )
1059
+ self ._check_type_device (other )
1060
1060
other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
1061
1061
if other is NotImplemented :
1062
1062
return other
@@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
1068
1068
"""
1069
1069
Performs the operation __ilshift__.
1070
1070
"""
1071
- self ._check_device (other )
1071
+ self ._check_type_device (other )
1072
1072
other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
1073
1073
if other is NotImplemented :
1074
1074
return other
@@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
1079
1079
"""
1080
1080
Performs the operation __rlshift__.
1081
1081
"""
1082
- self ._check_device (other )
1082
+ self ._check_type_device (other )
1083
1083
other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
1084
1084
if other is NotImplemented :
1085
1085
return other
@@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array:
1096
1096
other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
1097
1097
if other is NotImplemented :
1098
1098
return other
1099
- self ._check_device (other )
1099
+ self ._check_type_device (other )
1100
1100
res = self ._array .__imatmul__ (other ._array )
1101
1101
return self .__class__ ._new (res , device = self .device )
1102
1102
@@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
1109
1109
other = self ._check_allowed_dtypes (other , "numeric" , "__rmatmul__" )
1110
1110
if other is NotImplemented :
1111
1111
return other
1112
- self ._check_device (other )
1112
+ self ._check_type_device (other )
1113
1113
res = self ._array .__rmatmul__ (other ._array )
1114
1114
return self .__class__ ._new (res , device = self .device )
1115
1115
@@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
1130
1130
other = self ._check_allowed_dtypes (other , "real numeric" , "__rmod__" )
1131
1131
if other is NotImplemented :
1132
1132
return other
1133
- self ._check_device (other )
1133
+ self ._check_type_device (other )
1134
1134
self , other = self ._normalize_two_args (self , other )
1135
1135
res = self ._array .__rmod__ (other ._array )
1136
1136
return self .__class__ ._new (res , device = self .device )
@@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
1152
1152
other = self ._check_allowed_dtypes (other , "numeric" , "__rmul__" )
1153
1153
if other is NotImplemented :
1154
1154
return other
1155
- self ._check_device (other )
1155
+ self ._check_type_device (other )
1156
1156
self , other = self ._normalize_two_args (self , other )
1157
1157
res = self ._array .__rmul__ (other ._array )
1158
1158
return self .__class__ ._new (res , device = self .device )
@@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
1171
1171
"""
1172
1172
Performs the operation __ror__.
1173
1173
"""
1174
- self ._check_device (other )
1174
+ self ._check_type_device (other )
1175
1175
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__ror__" )
1176
1176
if other is NotImplemented :
1177
1177
return other
@@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
1219
1219
other = self ._check_allowed_dtypes (other , "integer" , "__rrshift__" )
1220
1220
if other is NotImplemented :
1221
1221
return other
1222
- self ._check_device (other )
1222
+ self ._check_type_device (other )
1223
1223
self , other = self ._normalize_two_args (self , other )
1224
1224
res = self ._array .__rrshift__ (other ._array )
1225
1225
return self .__class__ ._new (res , device = self .device )
@@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
1241
1241
other = self ._check_allowed_dtypes (other , "numeric" , "__rsub__" )
1242
1242
if other is NotImplemented :
1243
1243
return other
1244
- self ._check_device (other )
1244
+ self ._check_type_device (other )
1245
1245
self , other = self ._normalize_two_args (self , other )
1246
1246
res = self ._array .__rsub__ (other ._array )
1247
1247
return self .__class__ ._new (res , device = self .device )
@@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
1263
1263
other = self ._check_allowed_dtypes (other , "floating-point" , "__rtruediv__" )
1264
1264
if other is NotImplemented :
1265
1265
return other
1266
- self ._check_device (other )
1266
+ self ._check_type_device (other )
1267
1267
self , other = self ._normalize_two_args (self , other )
1268
1268
res = self ._array .__rtruediv__ (other ._array )
1269
1269
return self .__class__ ._new (res , device = self .device )
@@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
1285
1285
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rxor__" )
1286
1286
if other is NotImplemented :
1287
1287
return other
1288
- self ._check_device (other )
1288
+ self ._check_type_device (other )
1289
1289
self , other = self ._normalize_two_args (self , other )
1290
1290
res = self ._array .__rxor__ (other ._array )
1291
1291
return self .__class__ ._new (res , device = self .device )
0 commit comments