@@ -234,14 +234,15 @@ def _check_allowed_dtypes(
234234 return other
235235
236236 def _check_device (self , other : Array | complex ) -> None :
237- """Check that other is on a device compatible with the current array"""
238- if isinstance ( other , ( bool , int , float , complex )):
239- return
240- elif isinstance (other , Array ):
237+ """Check that other is either a Python scalar or an array on a device
238+ compatible with the current array.
239+ """
240+ if isinstance (other , Array ):
241241 if self .device != other .device :
242242 raise ValueError (f"Arrays from two different devices ({ self .device } and { other .device } ) can not be combined." )
243- else :
244- raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
243+ # Disallow subclasses of Python scalars, such as np.float64 and np.complex128
244+ elif type (other ) not in (bool , int , float , complex ):
245+ raise TypeError (f"Expected Array or Python scalar; got { type (other )} " )
245246
246247 # Helper function to match the type promotion rules in the spec
247248 def _promote_scalar (self , scalar : complex ) -> Array :
@@ -542,7 +543,7 @@ def __add__(self, other: Array | complex, /) -> Array:
542543 """
543544 Performs the operation __add__.
544545 """
545- self ._check_device (other )
546+ self ._check_type_device (other )
546547 other = self ._check_allowed_dtypes (other , "numeric" , "__add__" )
547548 if other is NotImplemented :
548549 return other
@@ -554,7 +555,7 @@ def __and__(self, other: Array | int, /) -> Array:
554555 """
555556 Performs the operation __and__.
556557 """
557- self ._check_device (other )
558+ self ._check_type_device (other )
558559 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__and__" )
559560 if other is NotImplemented :
560561 return other
@@ -651,7 +652,7 @@ def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
651652 """
652653 Performs the operation __eq__.
653654 """
654- self ._check_device (other )
655+ self ._check_type_device (other )
655656 # Even though "all" dtypes are allowed, we still require them to be
656657 # promotable with each other.
657658 other = self ._check_allowed_dtypes (other , "all" , "__eq__" )
@@ -677,7 +678,7 @@ def __floordiv__(self, other: Array | float, /) -> Array:
677678 """
678679 Performs the operation __floordiv__.
679680 """
680- self ._check_device (other )
681+ self ._check_type_device (other )
681682 other = self ._check_allowed_dtypes (other , "real numeric" , "__floordiv__" )
682683 if other is NotImplemented :
683684 return other
@@ -689,7 +690,7 @@ def __ge__(self, other: Array | float, /) -> Array:
689690 """
690691 Performs the operation __ge__.
691692 """
692- self ._check_device (other )
693+ self ._check_type_device (other )
693694 other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
694695 if other is NotImplemented :
695696 return other
@@ -741,7 +742,7 @@ def __gt__(self, other: Array | float, /) -> Array:
741742 """
742743 Performs the operation __gt__.
743744 """
744- self ._check_device (other )
745+ self ._check_type_device (other )
745746 other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
746747 if other is NotImplemented :
747748 return other
@@ -796,7 +797,7 @@ def __le__(self, other: Array | float, /) -> Array:
796797 """
797798 Performs the operation __le__.
798799 """
799- self ._check_device (other )
800+ self ._check_type_device (other )
800801 other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
801802 if other is NotImplemented :
802803 return other
@@ -808,7 +809,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808809 """
809810 Performs the operation __lshift__.
810811 """
811- self ._check_device (other )
812+ self ._check_type_device (other )
812813 other = self ._check_allowed_dtypes (other , "integer" , "__lshift__" )
813814 if other is NotImplemented :
814815 return other
@@ -820,7 +821,7 @@ def __lt__(self, other: Array | float, /) -> Array:
820821 """
821822 Performs the operation __lt__.
822823 """
823- self ._check_device (other )
824+ self ._check_type_device (other )
824825 other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
825826 if other is NotImplemented :
826827 return other
@@ -832,7 +833,7 @@ def __matmul__(self, other: Array, /) -> Array:
832833 """
833834 Performs the operation __matmul__.
834835 """
835- self ._check_device (other )
836+ self ._check_type_device (other )
836837 # matmul is not defined for scalars, but without this, we may get
837838 # the wrong error message from asarray.
838839 other = self ._check_allowed_dtypes (other , "numeric" , "__matmul__" )
@@ -845,7 +846,7 @@ def __mod__(self, other: Array | float, /) -> Array:
845846 """
846847 Performs the operation __mod__.
847848 """
848- self ._check_device (other )
849+ self ._check_type_device (other )
849850 other = self ._check_allowed_dtypes (other , "real numeric" , "__mod__" )
850851 if other is NotImplemented :
851852 return other
@@ -857,7 +858,7 @@ def __mul__(self, other: Array | complex, /) -> Array:
857858 """
858859 Performs the operation __mul__.
859860 """
860- self ._check_device (other )
861+ self ._check_type_device (other )
861862 other = self ._check_allowed_dtypes (other , "numeric" , "__mul__" )
862863 if other is NotImplemented :
863864 return other
@@ -869,7 +870,7 @@ def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
869870 """
870871 Performs the operation __ne__.
871872 """
872- self ._check_device (other )
873+ self ._check_type_device (other )
873874 other = self ._check_allowed_dtypes (other , "all" , "__ne__" )
874875 if other is NotImplemented :
875876 return other
@@ -890,7 +891,7 @@ def __or__(self, other: Array | int, /) -> Array:
890891 """
891892 Performs the operation __or__.
892893 """
893- self ._check_device (other )
894+ self ._check_type_device (other )
894895 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__or__" )
895896 if other is NotImplemented :
896897 return other
@@ -913,7 +914,7 @@ def __pow__(self, other: Array | complex, /) -> Array:
913914 """
914915 from ._elementwise_functions import pow # type: ignore[attr-defined]
915916
916- self ._check_device (other )
917+ self ._check_type_device (other )
917918 other = self ._check_allowed_dtypes (other , "numeric" , "__pow__" )
918919 if other is NotImplemented :
919920 return other
@@ -925,7 +926,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925926 """
926927 Performs the operation __rshift__.
927928 """
928- self ._check_device (other )
929+ self ._check_type_device (other )
929930 other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
930931 if other is NotImplemented :
931932 return other
@@ -961,7 +962,7 @@ def __sub__(self, other: Array | complex, /) -> Array:
961962 """
962963 Performs the operation __sub__.
963964 """
964- self ._check_device (other )
965+ self ._check_type_device (other )
965966 other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
966967 if other is NotImplemented :
967968 return other
@@ -975,7 +976,7 @@ def __truediv__(self, other: Array | complex, /) -> Array:
975976 """
976977 Performs the operation __truediv__.
977978 """
978- self ._check_device (other )
979+ self ._check_type_device (other )
979980 other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
980981 if other is NotImplemented :
981982 return other
@@ -987,7 +988,7 @@ def __xor__(self, other: Array | int, /) -> Array:
987988 """
988989 Performs the operation __xor__.
989990 """
990- self ._check_device (other )
991+ self ._check_type_device (other )
991992 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
992993 if other is NotImplemented :
993994 return other
@@ -999,7 +1000,7 @@ def __iadd__(self, other: Array | complex, /) -> Array:
9991000 """
10001001 Performs the operation __iadd__.
10011002 """
1002- self ._check_device (other )
1003+ self ._check_type_device (other )
10031004 other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
10041005 if other is NotImplemented :
10051006 return other
@@ -1010,7 +1011,7 @@ def __radd__(self, other: Array | complex, /) -> Array:
10101011 """
10111012 Performs the operation __radd__.
10121013 """
1013- self ._check_device (other )
1014+ self ._check_type_device (other )
10141015 other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
10151016 if other is NotImplemented :
10161017 return other
@@ -1022,7 +1023,7 @@ def __iand__(self, other: Array | int, /) -> Array:
10221023 """
10231024 Performs the operation __iand__.
10241025 """
1025- self ._check_device (other )
1026+ self ._check_type_device (other )
10261027 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
10271028 if other is NotImplemented :
10281029 return other
@@ -1033,7 +1034,7 @@ def __rand__(self, other: Array | int, /) -> Array:
10331034 """
10341035 Performs the operation __rand__.
10351036 """
1036- self ._check_device (other )
1037+ self ._check_type_device (other )
10371038 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
10381039 if other is NotImplemented :
10391040 return other
@@ -1045,7 +1046,7 @@ def __ifloordiv__(self, other: Array | float, /) -> Array:
10451046 """
10461047 Performs the operation __ifloordiv__.
10471048 """
1048- self ._check_device (other )
1049+ self ._check_type_device (other )
10491050 other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
10501051 if other is NotImplemented :
10511052 return other
@@ -1056,7 +1057,7 @@ def __rfloordiv__(self, other: Array | float, /) -> Array:
10561057 """
10571058 Performs the operation __rfloordiv__.
10581059 """
1059- self ._check_device (other )
1060+ self ._check_type_device (other )
10601061 other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
10611062 if other is NotImplemented :
10621063 return other
@@ -1068,7 +1069,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
10681069 """
10691070 Performs the operation __ilshift__.
10701071 """
1071- self ._check_device (other )
1072+ self ._check_type_device (other )
10721073 other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
10731074 if other is NotImplemented :
10741075 return other
@@ -1079,7 +1080,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
10791080 """
10801081 Performs the operation __rlshift__.
10811082 """
1082- self ._check_device (other )
1083+ self ._check_type_device (other )
10831084 other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
10841085 if other is NotImplemented :
10851086 return other
@@ -1096,7 +1097,7 @@ def __imatmul__(self, other: Array, /) -> Array:
10961097 other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
10971098 if other is NotImplemented :
10981099 return other
1099- self ._check_device (other )
1100+ self ._check_type_device (other )
11001101 res = self ._array .__imatmul__ (other ._array )
11011102 return self .__class__ ._new (res , device = self .device )
11021103
@@ -1109,7 +1110,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11091110 other = self ._check_allowed_dtypes (other , "numeric" , "__rmatmul__" )
11101111 if other is NotImplemented :
11111112 return other
1112- self ._check_device (other )
1113+ self ._check_type_device (other )
11131114 res = self ._array .__rmatmul__ (other ._array )
11141115 return self .__class__ ._new (res , device = self .device )
11151116
@@ -1130,7 +1131,7 @@ def __rmod__(self, other: Array | float, /) -> Array:
11301131 other = self ._check_allowed_dtypes (other , "real numeric" , "__rmod__" )
11311132 if other is NotImplemented :
11321133 return other
1133- self ._check_device (other )
1134+ self ._check_type_device (other )
11341135 self , other = self ._normalize_two_args (self , other )
11351136 res = self ._array .__rmod__ (other ._array )
11361137 return self .__class__ ._new (res , device = self .device )
@@ -1152,7 +1153,7 @@ def __rmul__(self, other: Array | complex, /) -> Array:
11521153 other = self ._check_allowed_dtypes (other , "numeric" , "__rmul__" )
11531154 if other is NotImplemented :
11541155 return other
1155- self ._check_device (other )
1156+ self ._check_type_device (other )
11561157 self , other = self ._normalize_two_args (self , other )
11571158 res = self ._array .__rmul__ (other ._array )
11581159 return self .__class__ ._new (res , device = self .device )
@@ -1171,7 +1172,7 @@ def __ror__(self, other: Array | int, /) -> Array:
11711172 """
11721173 Performs the operation __ror__.
11731174 """
1174- self ._check_device (other )
1175+ self ._check_type_device (other )
11751176 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__ror__" )
11761177 if other is NotImplemented :
11771178 return other
@@ -1219,7 +1220,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12191220 other = self ._check_allowed_dtypes (other , "integer" , "__rrshift__" )
12201221 if other is NotImplemented :
12211222 return other
1222- self ._check_device (other )
1223+ self ._check_type_device (other )
12231224 self , other = self ._normalize_two_args (self , other )
12241225 res = self ._array .__rrshift__ (other ._array )
12251226 return self .__class__ ._new (res , device = self .device )
@@ -1241,7 +1242,7 @@ def __rsub__(self, other: Array | complex, /) -> Array:
12411242 other = self ._check_allowed_dtypes (other , "numeric" , "__rsub__" )
12421243 if other is NotImplemented :
12431244 return other
1244- self ._check_device (other )
1245+ self ._check_type_device (other )
12451246 self , other = self ._normalize_two_args (self , other )
12461247 res = self ._array .__rsub__ (other ._array )
12471248 return self .__class__ ._new (res , device = self .device )
@@ -1263,7 +1264,7 @@ def __rtruediv__(self, other: Array | complex, /) -> Array:
12631264 other = self ._check_allowed_dtypes (other , "floating-point" , "__rtruediv__" )
12641265 if other is NotImplemented :
12651266 return other
1266- self ._check_device (other )
1267+ self ._check_type_device (other )
12671268 self , other = self ._normalize_two_args (self , other )
12681269 res = self ._array .__rtruediv__ (other ._array )
12691270 return self .__class__ ._new (res , device = self .device )
@@ -1285,7 +1286,7 @@ def __rxor__(self, other: Array | int, /) -> Array:
12851286 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rxor__" )
12861287 if other is NotImplemented :
12871288 return other
1288- self ._check_device (other )
1289+ self ._check_type_device (other )
12891290 self , other = self ._normalize_two_args (self , other )
12901291 res = self ._array .__rxor__ (other ._array )
12911292 return self .__class__ ._new (res , device = self .device )
0 commit comments