Skip to content

Commit 64eed26

Browse files
authored
fix dsv2 (#218)
1 parent 4317831 commit 64eed26

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

dlinfer/graph/dicp/vendor/AtbGraph/conversion.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,30 +268,39 @@ def aten_mm(self, x, y):
268268
def _register_binary_ops(self):
269269
binary_ops = {
270270
(torch.ops.aten.add.Tensor, "add"): (
271+
None,
271272
atb_op.Add,
272273
atb_op.AclNnAdds,
273274
atb_op.AclNnAdd,
274275
),
275276
(torch.ops.aten.sub.Tensor, "sub"): (
277+
None,
276278
atb_op.Sub,
277279
atb_op.AclNnSubs,
278280
atb_op.AclNnSub,
279281
),
280282
(torch.ops.aten.mul.Tensor, "mul"): (
283+
atb_op.Muls,
281284
atb_op.Mul,
282285
atb_op.AclNnMuls,
283286
atb_op.AclNnMul,
284287
),
285288
(torch.ops.aten.div.Tensor, "div"): (
289+
None,
286290
atb_op.Div,
287291
atb_op.AclNnDivs,
288292
atb_op.AclNnDiv,
289293
),
290294
}
291295

292-
for (aten_op, op_name), (tensor_op, scalar_op, aclnn_op) in binary_ops.items():
296+
for (aten_op, op_name), (
297+
scalar_op,
298+
tensor_op,
299+
aclnn_scalar_op,
300+
aclnn_op,
301+
) in binary_ops.items():
293302

294-
def make_handler(tensor_op, scalar_op, aclnn_op):
303+
def make_handler(scalar_op, tensor_op, aclnn_scalar_op, aclnn_op):
295304
def handler(self, x, y):
296305
atb_supported_dtype = [torch.float16, torch.bfloat16]
297306
out_dtype = fx_traceback.get_current_meta()["val"].dtype
@@ -311,11 +320,16 @@ def handler(self, x, y):
311320
return self.get_proxy(aclnn_op, (x, y, dtype))
312321
else:
313322
dtype = get_ascend_dtype(out_dtype)
314-
return self.get_proxy(scalar_op, (x, y, dtype))
323+
if out_dtype in atb_supported_dtype and scalar_op is not None:
324+
return self.get_proxy(scalar_op, (x, y, dtype))
325+
else:
326+
return self.get_proxy(aclnn_scalar_op, (x, y, dtype))
315327

316328
return handler
317329

318-
register_conversion(aten_op)(make_handler(tensor_op, scalar_op, aclnn_op))
330+
register_conversion(aten_op)(
331+
make_handler(scalar_op, tensor_op, aclnn_scalar_op, aclnn_op)
332+
)
319333

320334
@register_conversion(torch.ops.aten.pow.Tensor_Scalar)
321335
def aten_pow_tensor_scalar(self, x, y):
@@ -602,9 +616,10 @@ def select_int(self, x, dim, index):
602616
@register_conversion(torch.ops.aten.slice.Tensor)
603617
def slice_tensor(self, x, dim, start, end, step=1):
604618
dtype = fx_traceback.get_current_meta()["val"].dtype
619+
x_shape = x.node.meta["val"].shape
605620
if dtype == torch.int64 or step != 1:
621+
end = x_shape[dim] if end >= x_shape[dim] else end
606622
return self.get_proxy(atb_op.AclNnSlice, (x, dim, start, end, step))
607-
x_shape = x.node.meta["val"].shape
608623
offsets = [0] * len(x_shape)
609624
size = [-1] * len(x_shape)
610625

0 commit comments

Comments
 (0)