@@ -268,30 +268,39 @@ def aten_mm(self, x, y):
     def _register_binary_ops(self):
         binary_ops = {
             (torch.ops.aten.add.Tensor, "add"): (
+                None,
                 atb_op.Add,
                 atb_op.AclNnAdds,
                 atb_op.AclNnAdd,
             ),
             (torch.ops.aten.sub.Tensor, "sub"): (
+                None,
                 atb_op.Sub,
                 atb_op.AclNnSubs,
                 atb_op.AclNnSub,
             ),
             (torch.ops.aten.mul.Tensor, "mul"): (
+                atb_op.Muls,
                 atb_op.Mul,
                 atb_op.AclNnMuls,
                 atb_op.AclNnMul,
             ),
             (torch.ops.aten.div.Tensor, "div"): (
+                None,
                 atb_op.Div,
                 atb_op.AclNnDivs,
                 atb_op.AclNnDiv,
             ),
         }

-        for (aten_op, op_name), (tensor_op, scalar_op, aclnn_op) in binary_ops.items():
+        for (aten_op, op_name), (
+            scalar_op,
+            tensor_op,
+            aclnn_scalar_op,
+            aclnn_op,
+        ) in binary_ops.items():

-            def make_handler(tensor_op, scalar_op, aclnn_op):
+            def make_handler(scalar_op, tensor_op, aclnn_scalar_op, aclnn_op):
                 def handler(self, x, y):
                     atb_supported_dtype = [torch.float16, torch.bfloat16]
                     out_dtype = fx_traceback.get_current_meta()["val"].dtype
@@ -311,11 +320,16 @@ def handler(self, x, y):
                         return self.get_proxy(aclnn_op, (x, y, dtype))
                     else:
                         dtype = get_ascend_dtype(out_dtype)
-                        return self.get_proxy(scalar_op, (x, y, dtype))
+                        if out_dtype in atb_supported_dtype and scalar_op is not None:
+                            return self.get_proxy(scalar_op, (x, y, dtype))
+                        else:
+                            return self.get_proxy(aclnn_scalar_op, (x, y, dtype))

                 return handler

-            register_conversion(aten_op)(make_handler(tensor_op, scalar_op, aclnn_op))
+            register_conversion(aten_op)(
+                make_handler(scalar_op, tensor_op, aclnn_scalar_op, aclnn_op)
+            )

     @register_conversion(torch.ops.aten.pow.Tensor_Scalar)
     def aten_pow_tensor_scalar(self, x, y):
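
For context, each aten binary op now maps to a four-element tuple (scalar_op, tensor_op, aclnn_scalar_op, aclnn_op), and the scalar path prefers the native ATB scalar op (only Muls exists today, hence the None placeholders) when the output dtype is fp16/bf16, falling back to the AclNn scalar variant otherwise. Below is a minimal standalone sketch of just that scalar-path selection, with plain strings standing in for the real atb_op classes and a free function in place of the registered handler (both hypothetical stand-ins, not the repository's API):

import torch

def pick_scalar_op(ops, out_dtype):
    # ops mirrors the new 4-tuple layout: (scalar_op, tensor_op, aclnn_scalar_op, aclnn_op)
    scalar_op, _tensor_op, aclnn_scalar_op, _aclnn_op = ops
    atb_supported_dtype = [torch.float16, torch.bfloat16]
    # Prefer the native ATB scalar op when one exists and the dtype is supported,
    # otherwise fall back to the AclNn scalar variant.
    if out_dtype in atb_supported_dtype and scalar_op is not None:
        return scalar_op
    return aclnn_scalar_op

mul_ops = ("Muls", "Mul", "AclNnMuls", "AclNnMul")
add_ops = (None, "Add", "AclNnAdds", "AclNnAdd")
assert pick_scalar_op(mul_ops, torch.float16) == "Muls"        # new ATB fast path
assert pick_scalar_op(mul_ops, torch.float32) == "AclNnMuls"   # unsupported dtype
assert pick_scalar_op(add_ops, torch.bfloat16) == "AclNnAdds"  # no ATB scalar op registered
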
@@ -602,9 +616,10 @@ def select_int(self, x, dim, index):
     @register_conversion(torch.ops.aten.slice.Tensor)
     def slice_tensor(self, x, dim, start, end, step=1):
         dtype = fx_traceback.get_current_meta()["val"].dtype
+        x_shape = x.node.meta["val"].shape
         if dtype == torch.int64 or step != 1:
+            end = x_shape[dim] if end >= x_shape[dim] else end
             return self.get_proxy(atb_op.AclNnSlice, (x, dim, start, end, step))
-        x_shape = x.node.meta["val"].shape
         offsets = [0] * len(x_shape)
         size = [-1] * len(x_shape)