Skip to content

Commit baa291b

Browse files
authored
[mlir][mesh] Add null check for dyn_cast to prevent crash (llvm#149266)
This PR adds a null check for dyn_cast result before use to prevent crash, and use `isa` instead `dyn_cast` to make code clean. Fixes llvm#148619.
1 parent 8aa4fc0 commit baa291b

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns(
6262
auto isEndomorphismOp = [reduction](Operation *op,
6363
std::optional<Operation *> referenceOp) {
6464
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65+
if (!allReduceOp)
66+
return false;
6567
auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
6668
auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
67-
if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
69+
if (inType.getElementType() != outType.getElementType() ||
6870
allReduceOp.getReduction() != reduction) {
6971
return false;
7072
}
@@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns(
8789
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
8890
inType.getElementType() == refType.getElementType();
8991
};
90-
auto isAlgebraicOp = [](Operation *op) {
91-
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
92-
};
92+
auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
9393

9494
using ConcreteEndomorphismSimplification = EndomorphismSimplification<
9595
std::decay_t<decltype(getEndomorphismOpOperand)>,

mlir/test/Dialect/Mesh/simplifications.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism(
165165
// CHECK: return %[[ALL_REDUCE_RES]]
166166
return %2 : tensor<5xi32>
167167
}
168+
169+
// Ensure this case without endomorphism op not crash.
170+
// CHECK-LABEL: func.func @no_endomorphism_op
171+
func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
172+
%c0 = arith.constant 0 : index
173+
%c1_i64 = arith.constant 1 : i64
174+
// CHECK: tensor.extract
175+
%extracted = tensor.extract %arg0[%c0] : tensor<2xi64>
176+
// CHECK: arith.maxsi
177+
%0 = arith.maxsi %extracted, %c1_i64 : i64
178+
return %0 : i64
179+
}

0 commit comments

Comments
 (0)