Skip to content

Commit 5655a0c

Browse files
authored
mlir: Parallel ForLike ops remover fix (#2548)
* localize gradients helper * Change to ForLike interface Support postadd * rename var * add parallel test * fix * remove yield new values * remove gradientCreator * Fix subtle bug here the index need to be the number of result from the previous if op but the replace implementation would replace them.
1 parent f20d134 commit 5655a0c

File tree

8 files changed

+284
-115
lines changed

8 files changed

+284
-115
lines changed

enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -242,30 +242,16 @@ struct AffineParallelOpInterfaceReverse
242242
bool valid = true;
243243
bool wasAtomic = gutils->AtomicAdd;
244244
gutils->AtomicAdd = true;
245-
std::function<Value(Location, Type)> gradientCreator = [&](Location loc,
246-
Type t) {
247-
auto shadowty = getShadowType(t);
248-
OpBuilder builder(t.getContext());
249-
// Gradients of values defined within the parallel body should be local to
250-
// each iteration
251-
builder.setInsertionPointToStart(revPar.getBody());
252-
253-
auto shadow = enzyme::InitOp::create(
254-
builder, loc, enzyme::GradientType::get(t.getContext(), shadowty));
255-
auto toset =
256-
cast<AutoDiffTypeInterface>(shadowty).createNullValue(builder, loc);
257-
enzyme::SetOp::create(builder, loc, shadow, toset);
258-
return shadow;
259-
};
260-
gutils->registerGradientCreatorHook(gradientCreator);
261-
auto scope = llvm::make_scope_exit(
262-
[&]() { gutils->deregisterGradientCreatorHook(gradientCreator); });
263245

264246
{
265247
Block *oBB = parOp.getBody();
266248
Block *rBB = revPar.getBody();
267249

268250
OpBuilder bodyBuilder = revPar.getBodyBuilder();
251+
252+
bodyBuilder.setInsertionPointToStart(revPar.getBody());
253+
mlir::enzyme::localizeGradients(bodyBuilder, gutils, oBB);
254+
269255
bodyBuilder.setInsertionPoint(rBB->getTerminator());
270256

271257
auto first = oBB->rbegin();
@@ -390,15 +376,21 @@ struct AffineParallelOpEnzymeOpsRemover
390376
replaceWithNewOperands(PatternRewriter &rewriter,
391377
affine::AffineParallelOp otherParOp,
392378
ArrayRef<Value> operands) {
393-
auto reductionKinds = llvm::map_to_vector(
394-
otherParOp.getReductions().getAsRange<arith::AtomicRMWKindAttr>(),
395-
[](auto red) { return red.getValue(); });
379+
SmallVector<mlir::Attribute> reductionKinds(
380+
otherParOp.getReductions().begin(), otherParOp.getReductions().end());
381+
382+
for (unsigned i = otherParOp->getNumOperands(); i < operands.size(); i++) {
383+
reductionKinds.push_back(arith::AtomicRMWKindAttr::get(
384+
otherParOp.getContext(), arith::AtomicRMWKind::addf));
385+
}
386+
387+
ValueRange operands_(operands);
396388
auto newOtherParOp = affine::AffineParallelOp::create(
397-
rewriter, otherParOp.getLoc(), otherParOp.getResultTypes(),
398-
otherParOp.getReductions(), otherParOp.getLowerBoundsMap(),
399-
otherParOp.getLowerBoundsGroups(), otherParOp.getUpperBoundsMap(),
400-
otherParOp.getUpperBoundsGroups(), otherParOp.getSteps(),
401-
otherParOp.getMapOperands());
389+
rewriter, otherParOp.getLoc(), operands_.getTypes(),
390+
ArrayAttr::get(otherParOp.getContext(), reductionKinds),
391+
otherParOp.getLowerBoundsMap(), otherParOp.getLowerBoundsGroups(),
392+
otherParOp.getUpperBoundsMap(), otherParOp.getUpperBoundsGroups(),
393+
otherParOp.getSteps(), otherParOp.getMapOperands());
402394

403395
newOtherParOp.getRegion().takeBody(otherParOp.getRegion());
404396
rewriter.replaceOp(otherParOp, newOtherParOp->getResults().slice(
@@ -409,6 +401,17 @@ struct AffineParallelOpEnzymeOpsRemover
409401
static ValueRange getInits(affine::AffineParallelOp parOp) {
410402
return parOp.getInits();
411403
}
404+
405+
static bool mustPostAdd(affine::AffineParallelOp forOp) { return true; }
406+
407+
static Value initialValueInBlock(OpBuilder &builder, Block *body,
408+
Value grad) {
409+
OpBuilder::InsertionGuard guard(builder);
410+
builder.setInsertionPointToStart(body);
411+
return cast<AutoDiffTypeInterface>(
412+
cast<enzyme::GradientType>(grad.getType()).getBasetype())
413+
.createNullValue(builder, grad.getLoc());
414+
}
412415
};
413416

414417
static void computeAffineIndices(OpBuilder &builder, Location loc,
@@ -741,6 +744,14 @@ struct AffineForOpEnzymeOpsRemover
741744
static ValueRange getInits(affine::AffineForOp forOp) {
742745
return forOp.getInits();
743746
}
747+
748+
static bool mustPostAdd(affine::AffineForOp forOp) { return false; }
749+
750+
static Value initialValueInBlock(OpBuilder &builder, Block *body,
751+
Value grad) {
752+
auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
753+
return body->addArgument(Ty, grad.getLoc());
754+
}
744755
};
745756

746757
#include "Implementations/AffineDerivatives.inc"

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ struct ForOpEnzymeOpsRemover
121121
}
122122

123123
static ValueRange getInits(scf::ForOp forOp) { return forOp.getInitArgs(); }
124+
125+
static bool mustPostAdd(scf::ForOp forOp) { return false; }
126+
127+
static Value initialValueInBlock(OpBuilder &builder, Block *body,
128+
Value grad) {
129+
auto Ty = cast<enzyme::GradientType>(grad.getType()).getBasetype();
130+
return body->addArgument(Ty, grad.getLoc());
131+
}
124132
};
125133

126134
struct ForOpInterfaceReverse
@@ -395,26 +403,11 @@ struct ForOpInterfaceReverse
395403
if (revBB.empty()) {
396404
scf::YieldOp::create(bodyBuilder, repFor->getLoc());
397405
}
398-
bodyBuilder.setInsertionPoint(revBB.getTerminator());
399406

400-
// All values defined in the body should have no use outside this block
401-
// therefore we can set their diffe to zero upon entering the reverse
402-
// block to simplify the work of the remove-unnecessary-enzyme-ops pass.
403-
for (auto operand : oBB.getArguments().slice(1)) {
404-
if (!gutils->isConstantValue(operand)) {
405-
gutils->zeroDiffe(operand, bodyBuilder);
406-
}
407-
}
407+
bodyBuilder.setInsertionPointToStart(&revBB);
408+
mlir::enzyme::localizeGradients(bodyBuilder, gutils, &oBB);
408409

409-
for (auto &it : oBB.getOperations()) {
410-
for (auto res : it.getResults()) {
411-
if (!gutils->isConstantValue(res)) {
412-
auto iface = dyn_cast<AutoDiffTypeInterface>(res.getType());
413-
if (iface && !iface.isMutable())
414-
gutils->zeroDiffe(res, bodyBuilder);
415-
}
416-
}
417-
}
410+
bodyBuilder.setInsertionPoint(revBB.getTerminator());
418411

419412
auto term = oBB.getTerminator();
420413

@@ -714,19 +707,51 @@ struct ParallelOpEnzymeOpsRemover
714707
ArrayRef<Value> operands) {
715708
auto newOtherParOp = scf::ParallelOp::create(
716709
rewriter, otherParallelOp.getLoc(), otherParallelOp.getLowerBound(),
717-
otherParallelOp.getUpperBound(), otherParallelOp.getStep(),
718-
otherParallelOp.getInitVals());
710+
otherParallelOp.getUpperBound(), otherParallelOp.getStep(), operands);
719711

720712
newOtherParOp.getRegion().takeBody(otherParallelOp.getRegion());
721713
rewriter.replaceOp(
722714
otherParallelOp,
723715
newOtherParOp.getResults().slice(0, otherParallelOp.getNumResults()));
716+
717+
if (operands.size() >= 1) {
718+
OpBuilder::InsertionGuard guard(rewriter);
719+
Operation *oldTerm = newOtherParOp.getBody()->getTerminator();
720+
rewriter.setInsertionPointToEnd(newOtherParOp.getBody());
721+
auto term = scf::ReduceOp::create(rewriter, newOtherParOp.getLoc(),
722+
oldTerm->getOperands());
723+
724+
for (auto [reg, operand] :
725+
llvm::zip_equal(term->getRegions(), operands)) {
726+
Block *b = &reg.front();
727+
rewriter.setInsertionPointToEnd(b);
728+
729+
auto Ty = cast<AutoDiffTypeInterface>(operand.getType());
730+
Value reduced = Ty.createAddOp(rewriter, operand.getLoc(),
731+
b->getArgument(0), b->getArgument(1));
732+
scf::ReduceReturnOp::create(rewriter, reduced.getLoc(), reduced);
733+
}
734+
735+
oldTerm->erase();
736+
}
737+
724738
return newOtherParOp;
725739
}
726740

727741
static ValueRange getInits(scf::ParallelOp parallelOp) {
728742
return parallelOp.getInitVals();
729743
}
744+
745+
static bool mustPostAdd(scf::ParallelOp forOp) { return false; }
746+
747+
static Value initialValueInBlock(OpBuilder &builder, Block *body,
748+
Value grad) {
749+
OpBuilder::InsertionGuard guard(builder);
750+
builder.setInsertionPointToStart(body);
751+
return cast<AutoDiffTypeInterface>(
752+
cast<enzyme::GradientType>(grad.getType()).getBasetype())
753+
.createNullValue(builder, grad.getLoc());
754+
}
730755
};
731756

732757
struct ParallelOpInterfaceReverse
@@ -754,30 +779,16 @@ struct ParallelOpInterfaceReverse
754779
bool valid = true;
755780
bool wasAtomic = gutils->AtomicAdd;
756781
gutils->AtomicAdd = true;
757-
std::function<Value(Location, Type)> gradientCreator = [&](Location loc,
758-
Type t) {
759-
auto shadowty = getShadowType(t);
760-
OpBuilder builder(t.getContext());
761-
// Gradients of values defined within the parallel body should be local to
762-
// each iteration
763-
builder.setInsertionPointToStart(revPar.getBody());
764-
765-
auto shadow = enzyme::InitOp::create(
766-
builder, loc, enzyme::GradientType::get(t.getContext(), shadowty));
767-
auto toset =
768-
cast<AutoDiffTypeInterface>(shadowty).createNullValue(builder, loc);
769-
enzyme::SetOp::create(builder, loc, shadow, toset);
770-
return shadow;
771-
};
772-
gutils->registerGradientCreatorHook(gradientCreator);
773-
auto scope = llvm::make_scope_exit(
774-
[&]() { gutils->deregisterGradientCreatorHook(gradientCreator); });
775782

776783
{
777784
Block *oBB = parallelOp.getBody();
778785
Block *revBB = revPar.getBody();
779786

780787
OpBuilder bodyBuilder(revBB, revBB->end());
788+
789+
bodyBuilder.setInsertionPointToStart(revBB);
790+
mlir::enzyme::localizeGradients(bodyBuilder, gutils, oBB);
791+
781792
bodyBuilder.setInsertionPoint(revBB->getTerminator());
782793

783794
auto first = oBB->rbegin();

enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "RemovalUtils.h"
1010
#include "Interfaces/AutoDiffOpInterface.h"
1111
#include "Interfaces/AutoDiffTypeInterface.h"
12+
#include "Interfaces/GradientUtilsReverse.h"
1213
#include "Utils.h"
1314
#include "mlir/Analysis/TopologicalSortUtils.h"
1415
#include "mlir/IR/PatternMatch.h"
@@ -22,6 +23,55 @@ using namespace mlir::enzyme;
2223

2324
#define DEBUG_TYPE "enzyme-mincut"
2425

26+
void mlir::enzyme::localizeGradients(OpBuilder &builder,
27+
MGradientUtilsReverse *gutils,
28+
Block *fwd) {
29+
Operation *parent = fwd->getParentOp();
30+
31+
auto localizeGradientValue = [&](Value val) {
32+
if (gutils->isConstantValue(val))
33+
return;
34+
auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType());
35+
if (iface && !iface.isMutable()) {
36+
auto grad = gutils->getDifferential(val);
37+
38+
enzyme::SetOp initialSet = nullptr;
39+
for (auto user : grad.getUsers()) {
40+
if (!parent->isProperAncestor(user)) {
41+
assert(!initialSet);
42+
initialSet = dyn_cast<enzyme::SetOp>(user);
43+
assert(initialSet);
44+
}
45+
}
46+
47+
auto initOp = grad.getDefiningOp<enzyme::InitOp>();
48+
49+
{
50+
OpBuilder::InsertionGuard g(builder);
51+
Value zero =
52+
iface.createNullValue(builder, initialSet.getValue().getLoc());
53+
builder.setInsertionPointAfter(zero.getDefiningOp());
54+
enzyme::SetOp::create(builder, initialSet.getLoc(), grad, zero);
55+
initialSet->erase();
56+
}
57+
58+
builder.setInsertionPointToStart(builder.getBlock());
59+
initOp->remove();
60+
builder.insert(initOp);
61+
}
62+
};
63+
64+
for (auto operand : fwd->getArguments()) {
65+
localizeGradientValue(operand);
66+
}
67+
68+
for (auto &it : fwd->getOperations()) {
69+
for (auto res : it.getResults()) {
70+
localizeGradientValue(res);
71+
}
72+
}
73+
}
74+
2575
void mlir::enzyme::removalBlockExplore(
2676
Block *block, IRMapping &mapping, PatternRewriter &rewriter,
2777
llvm::SetVector<Value> &gradients,

0 commit comments

Comments
 (0)