@@ -19,7 +19,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
1919 """
2020 if x .dtype not in _real_numeric_dtypes :
2121 raise TypeError ("Only real numeric dtypes are allowed in argmax" )
22- return Array ._new (np .asarray (np .argmax (x ._array , axis = axis , keepdims = keepdims )))
22+ return Array ._new (np .asarray (np .argmax (x ._array , axis = axis , keepdims = keepdims )), device = x . device )
2323
2424
2525def argmin (x : Array , / , * , axis : Optional [int ] = None , keepdims : bool = False ) -> Array :
@@ -30,7 +30,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
3030 """
3131 if x .dtype not in _real_numeric_dtypes :
3232 raise TypeError ("Only real numeric dtypes are allowed in argmin" )
33- return Array ._new (np .asarray (np .argmin (x ._array , axis = axis , keepdims = keepdims )))
33+ return Array ._new (np .asarray (np .argmin (x ._array , axis = axis , keepdims = keepdims )), device = x . device )
3434
3535
3636@requires_data_dependent_shapes
@@ -61,12 +61,16 @@ def searchsorted(
6161 """
6262 if x1 .dtype not in _real_numeric_dtypes or x2 .dtype not in _real_numeric_dtypes :
6363 raise TypeError ("Only real numeric dtypes are allowed in searchsorted" )
64+
65+ if x1 .device != x2 .device :
66+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
67+
6468 sorter = sorter ._array if sorter is not None else None
6569 # TODO: The sort order of nans and signed zeros is implementation
6670 # dependent. Should we error/warn if they are present?
6771
6872 # x1 must be 1-D, but NumPy already requires this.
69- return Array ._new (np .searchsorted (x1 ._array , x2 ._array , side = side , sorter = sorter ))
73+ return Array ._new (np .searchsorted (x1 ._array , x2 ._array , side = side , sorter = sorter ), device = x1 . device )
7074
7175def where (condition : Array , x1 : Array , x2 : Array , / ) -> Array :
7276 """
@@ -76,5 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
7680 """
7781 # Call result type here just to raise on disallowed type combinations
7882 _result_type (x1 .dtype , x2 .dtype )
83+
84+ if len ({a .device for a in (condition , x1 , x2 )}) > 1 :
85+ raise ValueError ("where inputs must all be on the same device" )
86+
7987 x1 , x2 = Array ._normalize_two_args (x1 , x2 )
8088 return Array ._new (np .where (condition ._array , x1 ._array , x2 ._array ))
0 commit comments