File tree Expand file tree Collapse file tree 2 files changed +16
-4
lines changed
include/mlir/Dialect/Mesh/Transforms Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns(
62
62
auto isEndomorphismOp = [reduction](Operation *op,
63
63
std::optional<Operation *> referenceOp) {
64
64
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65
+ if (!allReduceOp)
66
+ return false ;
65
67
auto inType = cast<ShapedType>(allReduceOp.getInput ().getType ());
66
68
auto outType = cast<ShapedType>(allReduceOp.getResult ().getType ());
67
- if (!allReduceOp || inType.getElementType () != outType.getElementType () ||
69
+ if (inType.getElementType () != outType.getElementType () ||
68
70
allReduceOp.getReduction () != reduction) {
69
71
return false ;
70
72
}
@@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns(
87
89
return refAllReduceOp->getAttrs () == allReduceOp->getAttrs () &&
88
90
inType.getElementType () == refType.getElementType ();
89
91
};
90
- auto isAlgebraicOp = [](Operation *op) {
91
- return static_cast <bool >(llvm::dyn_cast<AlgebraicOp>(op));
92
- };
92
+ auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
93
93
94
94
using ConcreteEndomorphismSimplification = EndomorphismSimplification<
95
95
std::decay_t <decltype (getEndomorphismOpOperand)>,
Original file line number Diff line number Diff line change @@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism(
165
165
// CHECK: return %[[ALL_REDUCE_RES]]
166
166
return %2 : tensor <5 xi32 >
167
167
}
168
+
169
+ // Ensure this case without endomorphism op not crash.
170
+ // CHECK-LABEL: func.func @no_endomorphism_op
171
+ func.func @no_endomorphism_op (%arg0: tensor <2 xi64 >) -> i64 {
172
+ %c0 = arith.constant 0 : index
173
+ %c1_i64 = arith.constant 1 : i64
174
+ // CHECK: tensor.extract
175
+ %extracted = tensor.extract %arg0 [%c0 ] : tensor <2 xi64 >
176
+ // CHECK: arith.maxsi
177
+ %0 = arith.maxsi %extracted , %c1_i64 : i64
178
+ return %0 : i64
179
+ }
You can’t perform that action at this time.
0 commit comments