diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 4739290bf6e4b..a89c1ae475b96 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -710,7 +710,7 @@ void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) { void MemRefDependenceGraph::forEachMemRefInputEdge( unsigned id, const std::function &callback) { if (inEdges.count(id) > 0) - forEachMemRefEdge(inEdges[id], callback); + forEachMemRefEdge(inEdges.at(id), callback); } // Calls 'callback' for each output edge from node 'id' which carries a @@ -718,7 +718,7 @@ void MemRefDependenceGraph::forEachMemRefInputEdge( void MemRefDependenceGraph::forEachMemRefOutputEdge( unsigned id, const std::function &callback) { if (outEdges.count(id) > 0) - forEachMemRefEdge(outEdges[id], callback); + forEachMemRefEdge(outEdges.at(id), callback); } // Calls 'callback' for each edge in 'edges' which carries a memref @@ -730,9 +730,6 @@ void MemRefDependenceGraph::forEachMemRefEdge( if (!isa(edge.value.getType())) continue; assert(nodes.count(edge.id) > 0); - // Skip if 'edge.id' is not a loop nest. - if (!isa(getNode(edge.id)->op)) - continue; // Visit current input edge 'edge'. callback(edge); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 95848d0b67547..1d5a665bf6bb1 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -1473,9 +1473,11 @@ struct GreedyFusion { SmallVector inEdges; mdg->forEachMemRefInputEdge( dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) { - // Add 'inEdge' if it is a read-after-write dependence. + // Add 'inEdge' if it is a read-after-write dependence or an edge + // from a memref defining op (e.g. view-like op or alloc op). if (dstNode->getLoadOpCount(inEdge.value) > 0 && - mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0) + (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 || + inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op)) inEdges.push_back(inEdge); }); diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index b059b5a98405d..04c8c3ee809a1 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -743,3 +743,31 @@ module { return } } + +// SIBLING-MAXIMAL-LABEL: memref_cast_reused +func.func @memref_cast_reused(%arg: memref<*xf32>) { + %alloc = memref.cast %arg : memref<*xf32> to memref<10xf32> + %alloc_0 = memref.alloc() : memref<10xf32> + %alloc_1 = memref.alloc() : memref<10xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_2 = arith.constant 1.000000e+00 : f32 + affine.for %arg0 = 0 to 10 { + %0 = affine.load %alloc[%arg0] : memref<10xf32> + %1 = arith.addf %0, %cst_2 : f32 + affine.store %1, %alloc_0[%arg0] : memref<10xf32> + } + affine.for %arg0 = 0 to 10 { + %0 = affine.load %alloc[%arg0] : memref<10xf32> + %1 = affine.load %alloc_1[0] : memref<10xf32> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %alloc_1[0] : memref<10xf32> + } + // SIBLING-MAXIMAL: affine.for %{{.*}} = 0 to 10 + // SIBLING-MAXIMAL: addf + // SIBLING-MAXIMAL-NEXT: affine.store + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: addf + // SIBLING-MAXIMAL-NEXT: affine.store + return +}