1717import dpctl
1818import dpctl .tensor as dpt
1919import dpctl .tensor ._tensor_impl as ti
20- from dpctl .tensor ._manipulation_functions import _broadcast_shapes
20+ from dpctl .tensor ._elementwise_common import (
21+ _get_dtype ,
22+ _get_queue_usm_type ,
23+ _get_shape ,
24+ _validate_dtype ,
25+ )
26+ from dpctl .tensor ._manipulation_functions import _broadcast_shape_impl
2127from dpctl .utils import ExecutionPlacementError , SequentialOrderManager
2228
2329from ._copy_utils import _empty_like_orderK , _empty_like_triple_orderK
24- from ._type_utils import _all_data_types , _can_cast
30+ from ._type_utils import (
31+ WeakBooleanType ,
32+ WeakComplexType ,
33+ WeakFloatingType ,
34+ WeakIntegralType ,
35+ _all_data_types ,
36+ _can_cast ,
37+ _is_weak_dtype ,
38+ _strong_dtype_num_kind ,
39+ _to_device_supported_dtype ,
40+ _weak_type_num_kind ,
41+ )
42+
43+
def _default_dtype_from_weak_type(dt, dev):
    """Map a weak (Python-scalar) type marker to a concrete data type.

    Args:
        dt: Object to classify; it is matched with ``isinstance`` against
            ``WeakBooleanType``, ``WeakIntegralType``, ``WeakFloatingType``
            and ``WeakComplexType``, in that order.
        dev: Device object forwarded to the ``ti.default_device_*_type``
            queries (presumably a SYCL device -- confirm against callers).

    Returns:
        ``dpt.bool`` for a weak boolean; for the other three kinds,
        ``dpt.dtype`` of whatever the matching ``ti.default_device_*_type``
        query returns for ``dev``. ``None`` when ``dt`` is an instance of
        none of the four weak types.
    """
    # Boolean is the only kind whose result does not depend on the device.
    if isinstance(dt, WeakBooleanType):
        return dpt.bool

    # Remaining kinds defer to the per-device default type queries.
    device_default_queries = (
        (WeakIntegralType, ti.default_device_int_type),
        (WeakFloatingType, ti.default_device_fp_type),
        (WeakComplexType, ti.default_device_complex_type),
    )
    for weak_cls, query_default in device_default_queries:
        if isinstance(dt, weak_cls):
            return dpt.dtype(query_default(dev))

    # Not a recognized weak type: keep the original fall-through result.
    return None
53+
54+
def _strong_dtype_for_dominant_weak(weak_dtype, strong_dtype, dev):
    """Pick the concrete dtype for a weak type whose kind outranks the strong one.

    Args:
        weak_dtype: Weak type marker (integral, complex, or anything else,
            which is treated as floating).
        strong_dtype: Concrete data type of the other operand.
        dev: Device object forwarded to the device type queries.

    Returns:
        The data type to use in place of ``weak_dtype``.
    """
    if isinstance(weak_dtype, WeakIntegralType):
        return dpt.dtype(ti.default_device_int_type(dev))
    if isinstance(weak_dtype, WeakComplexType):
        # Compare dtypes by equality rather than identity so that an equal
        # but non-identical dtype instance is still recognized; narrow
        # floating operands pair with single-precision complex.
        if strong_dtype in (dpt.float16, dpt.float32):
            return dpt.complex64
        return _to_device_supported_dtype(dpt.complex128, dev)
    return _to_device_supported_dtype(dpt.float64, dev)


def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
    """Resolves two weak data types per NEP-0050.

    Args:
        o1_dtype: Data type (or weak type marker) of the first operand.
        o2_dtype: Data type (or weak type marker) of the second operand.
        dev: Device object forwarded to the device type queries.

    Returns:
        A 2-tuple ``(o1_resolved, o2_resolved)``:

        * both weak: each is replaced via ``_default_dtype_from_weak_type``;
        * exactly one weak: the strong type is returned unchanged. The weak
          one becomes the strong type when its kind number does not exceed
          the strong type's kind number; otherwise it becomes the device
          default integer type (weak integral), ``complex64`` or the
          device-supported form of ``complex128`` (weak complex, depending
          on whether the strong type is ``float16``/``float32``), or the
          device-supported form of ``float64`` (any other weak type);
        * neither weak: both are returned unchanged.
    """
    o1_is_weak = _is_weak_dtype(o1_dtype)

    if o1_is_weak:
        if _is_weak_dtype(o2_dtype):
            return (
                _default_dtype_from_weak_type(o1_dtype, dev),
                _default_dtype_from_weak_type(o2_dtype, dev),
            )
        weak_kind = _weak_type_num_kind(o1_dtype)
        strong_kind = _strong_dtype_num_kind(o2_dtype)
        if weak_kind > strong_kind:
            return (
                _strong_dtype_for_dominant_weak(o1_dtype, o2_dtype, dev),
                o2_dtype,
            )
        # The weak operand does not outrank the array: adopt the array dtype.
        return o2_dtype, o2_dtype

    if _is_weak_dtype(o2_dtype):
        strong_kind = _strong_dtype_num_kind(o1_dtype)
        weak_kind = _weak_type_num_kind(o2_dtype)
        if weak_kind > strong_kind:
            return (
                o1_dtype,
                _strong_dtype_for_dominant_weak(o2_dtype, o1_dtype, dev),
            )
        # The weak operand does not outrank the array: adopt the array dtype.
        return o1_dtype, o1_dtype

    # Both operands already carry concrete data types.
    return o1_dtype, o2_dtype
2595
2696
2797def _where_result_type (dt1 , dt2 , dev ):
@@ -81,36 +151,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81151 raise TypeError (
82152 "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (condition )} "
83153 )
84- if not isinstance (x1 , dpt .usm_ndarray ):
85- raise TypeError (
86- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x1 )} "
154+ if order not in ["K" , "C" , "F" , "A" ]:
155+ order = "K"
156+ q1 , condition_usm_type = condition .sycl_queue , condition .usm_type
157+ q2 , x1_usm_type = _get_queue_usm_type (x1 )
158+ q3 , x2_usm_type = _get_queue_usm_type (x2 )
159+ if q2 is None and q3 is None :
160+ exec_q = q1
161+ out_usm_type = condition_usm_type
162+ elif q3 is None :
163+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
164+ if exec_q is None :
165+ raise ExecutionPlacementError (
166+ "Execution placement can not be unambiguously inferred "
167+ "from input arguments."
168+ )
169+ out_usm_type = dpctl .utils .get_coerced_usm_type (
170+ (
171+ condition_usm_type ,
172+ x1_usm_type ,
173+ )
87174 )
88- if not isinstance (x2 , dpt .usm_ndarray ):
175+ elif q2 is None :
176+ exec_q = dpctl .utils .get_execution_queue ((q1 , q3 ))
177+ if exec_q is None :
178+ raise ExecutionPlacementError (
179+ "Execution placement can not be unambiguously inferred "
180+ "from input arguments."
181+ )
182+ out_usm_type = dpctl .utils .get_coerced_usm_type (
183+ (
184+ condition_usm_type ,
185+ x2_usm_type ,
186+ )
187+ )
188+ else :
189+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 , q3 ))
190+ if exec_q is None :
191+ raise ExecutionPlacementError (
192+ "Execution placement can not be unambiguously inferred "
193+ "from input arguments."
194+ )
195+ out_usm_type = dpctl .utils .get_coerced_usm_type (
196+ (
197+ condition_usm_type ,
198+ x1_usm_type ,
199+ x2_usm_type ,
200+ )
201+ )
202+ dpctl .utils .validate_usm_type (out_usm_type , allow_none = False )
203+ condition_shape = condition .shape
204+ x1_shape = _get_shape (x1 )
205+ x2_shape = _get_shape (x2 )
206+ if not all (
207+ isinstance (s , (tuple , list ))
208+ for s in (
209+ x1_shape ,
210+ x2_shape ,
211+ )
212+ ):
89213 raise TypeError (
90- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x2 )} "
214+ "Shape of arguments can not be inferred. "
215+ "Arguments are expected to be "
216+ "lists, tuples, or both"
91217 )
92- if order not in [ "K" , "C" , "F" , "A" ] :
93- order = "K"
94- exec_q = dpctl . utils . get_execution_queue (
95- (
96- condition . sycl_queue ,
97- x1 . sycl_queue ,
98- x2 . sycl_queue ,
218+ try :
219+ res_shape = _broadcast_shape_impl (
220+ [
221+ condition_shape ,
222+ x1_shape ,
223+ x2_shape ,
224+ ]
99225 )
100- )
101- if exec_q is None :
102- raise dpctl .utils .ExecutionPlacementError
103- out_usm_type = dpctl .utils .get_coerced_usm_type (
104- (
105- condition .usm_type ,
106- x1 .usm_type ,
107- x2 .usm_type ,
226+ except ValueError :
227+ raise ValueError (
228+ "operands could not be broadcast together with shapes "
229+ f"{ condition_shape } , { x1_shape } , and { x2_shape } "
108230 )
109- )
110-
111- x1_dtype = x1 .dtype
112- x2_dtype = x2 .dtype
113- out_dtype = _where_result_type (x1_dtype , x2_dtype , exec_q .sycl_device )
231+ sycl_dev = exec_q .sycl_device
232+ x1_dtype = _get_dtype (x1 , sycl_dev )
233+ x2_dtype = _get_dtype (x2 , sycl_dev )
234+ if not all (_validate_dtype (o ) for o in (x1_dtype , x2_dtype )):
235+ raise ValueError ("Operands have unsupported data types" )
236+ x1_dtype , x2_dtype = _resolve_two_weak_types (x1_dtype , x2_dtype , sycl_dev )
237+ out_dtype = _where_result_type (x1_dtype , x2_dtype , sycl_dev )
114238 if out_dtype is None :
115239 raise TypeError (
116240 "function 'where' does not support input "
@@ -119,8 +243,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119243 "to any supported types according to the casting rule ''safe''."
120244 )
121245
122- res_shape = _broadcast_shapes (condition , x1 , x2 )
123-
124246 orig_out = out
125247 if out is not None :
126248 if not isinstance (out , dpt .usm_ndarray ):
@@ -149,16 +271,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149271 "Input and output allocation queues are not compatible"
150272 )
151273
152- if ti ._array_overlap (condition , out ):
153- if not ti ._same_logical_tensors (condition , out ):
154- out = dpt .empty_like (out )
274+ if ti ._array_overlap (condition , out ) and not ti ._same_logical_tensors (
275+ condition , out
276+ ):
277+ out = dpt .empty_like (out )
155278
156- if ti ._array_overlap (x1 , out ):
157- if not ti ._same_logical_tensors (x1 , out ):
279+ if isinstance (x1 , dpt .usm_ndarray ):
280+ if (
281+ ti ._array_overlap (x1 , out )
282+ and not ti ._same_logical_tensors (x1 , out )
283+ and x1_dtype == out_dtype
284+ ):
158285 out = dpt .empty_like (out )
159286
160- if ti ._array_overlap (x2 , out ):
161- if not ti ._same_logical_tensors (x2 , out ):
287+ if isinstance (x2 , dpt .usm_ndarray ):
288+ if (
289+ ti ._array_overlap (x2 , out )
290+ and not ti ._same_logical_tensors (x2 , out )
291+ and x2_dtype == out_dtype
292+ ):
162293 out = dpt .empty_like (out )
163294
164295 if order == "A" :
@@ -174,6 +305,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174305 )
175306 else "C"
176307 )
308+ if not isinstance (x1 , dpt .usm_ndarray ):
309+ x1 = dpt .asarray (x1 , dtype = x1_dtype , sycl_queue = exec_q )
310+ if not isinstance (x2 , dpt .usm_ndarray ):
311+ x2 = dpt .asarray (x2 , dtype = x2_dtype , sycl_queue = exec_q )
177312
178313 if condition .size == 0 :
179314 if out is not None :
@@ -236,9 +371,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236371 sycl_queue = exec_q ,
237372 )
238373
239- condition = dpt .broadcast_to (condition , res_shape )
240- x1 = dpt .broadcast_to (x1 , res_shape )
241- x2 = dpt .broadcast_to (x2 , res_shape )
374+ if condition_shape != res_shape :
375+ condition = dpt .broadcast_to (condition , res_shape )
376+ if x1_shape != res_shape :
377+ x1 = dpt .broadcast_to (x1 , res_shape )
378+ if x2_shape != res_shape :
379+ x2 = dpt .broadcast_to (x2 , res_shape )
242380
243381 dep_evs = _manager .submitted_events
244382 hev , where_ev = ti ._where (
0 commit comments