Skip to content

Commit 0e89470

Browse files
authored
fix ops.where on GPU (#2111)
1 parent 50e4dcd commit 0e89470

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

mindnlp/core/_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def __getitem__(self, slices):
265265
return origin_getitem(self, slices)
266266

267267
Tensor.__getitem__ = __getitem__
268-
StubTensor.__getitem__ = __getitem__
268+
StubTensor.__getitem__ = _stub_method(__getitem__)
269269

270270
def _convert_numpy_slices(self, key):
271271
"""递归转换 key 中的 NumPy 整数为内置 int"""

mindnlp/core/ops/pointwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,9 +589,10 @@ def mul(input, other, *, out=None):
589589
if isinstance(other, bool):
590590
out = ops.bitwise_and(input, other)
591591
else:
592-
out = ops.mul(input.int(), other).bool()
592+
out = ops.mul(input.int(), other)
593593
else:
594594
out = ops.mul(input, other)
595+
return out
595596

596597
if isinstance(other, mindspore.Tensor):
597598
out_dtype = min(input.dtype, other.dtype)

0 commit comments

Comments
 (0)