Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace cir;
using namespace llvm;
Expand Down Expand Up @@ -570,15 +571,82 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
}
}

void optimizeOnCertainBreak(mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
// Collect all BreakOp inside this while.
llvm::SmallVector<cir::BreakOp> breaks;
whileOp->walk([&](mlir::Operation *op) {
if (auto breakOp = dyn_cast<BreakOp>(op))
breaks.push_back(breakOp);
});
if (breaks.empty())
return;
auto *pp = whileOp->getParentOp();
pp->dump();
for (auto breakOp : breaks) {
// When there is another loop between this WhileOp and the BreakOp,
// we should change that loop instead.
if (breakOp->getParentOfType<mlir::scf::WhileOp>() != whileOp)
continue;
// Similar to the case of ContinueOp, when there is an `IfOp`,
// we need to take special care.
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
parent = parent->getParentOp()) {
if (auto ifOp = dyn_cast<cir::IfOp>(parent))
llvm_unreachable("NYI");
}
// Operations after this BreakOp has to be removed.
for (mlir::Operation *runner = breakOp->getNextNode(); runner;) {
mlir::Operation *next = runner->getNextNode();
runner->erase();
runner = next;
}

// Blocks after this BreakOp also has to be removed.
for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) {
mlir::Block *next = block->getNextNode();
block->erase();
block = next;
}

// We know this BreakOp isn't nested in any IfOp.
// Therefore, the loop is executed only once.
// We pull everything out of the loop.
auto &beforeOps = whileOp.getBeforeBody()->getOperations();
for (mlir::Operation *op = &*beforeOps.begin(); op;) {
if (isa<ConditionOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}

auto &afterOps = whileOp.getAfterBody()->getOperations();
for (mlir::Operation *op = &*afterOps.begin(); op;) {
if (isa<YieldOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}
}

rewriter.eraseOp(whileOp);
pp->dump();
}

public:
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;

/// This rewrite will do some optimizations at the same time.
/// Unreachable code and unnecessary loops will be eliminated.
mlir::LogicalResult
matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFWhileLoop loop(op, adaptor, &rewriter);
auto whileOp = loop.transferToSCFWhileOp();
rewriteContinue(whileOp, rewriter);
optimizeOnCertainBreak(whileOp, rewriter);
rewriter.eraseOp(op);
return mlir::success();
}
Expand Down
76 changes: 48 additions & 28 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,17 @@
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/LowerToMLIR.h"
#include "clang/CIR/LoweringHelpers.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TimeProfiler.h"

using namespace cir;
Expand Down Expand Up @@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
} else {
// For scopes with results, use scf.execute_region
SmallVector<mlir::Type> types;
if (mlir::failed(
getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types)))
if (mlir::failed(getTypeConverter()->convertTypes(
scopeOp->getResultTypes(), types)))
return mlir::failure();
auto exec =
rewriter.create<mlir::scf::ExecuteRegionOp>(scopeOp.getLoc(), types);
Expand Down Expand Up @@ -1023,6 +1021,28 @@ class CIRYieldOpLowering : public mlir::OpConversionPattern<cir::YieldOp> {
}
};

class CIRBreakOpLowering : public mlir::OpConversionPattern<cir::BreakOp> {
public:
using OpConversionPattern<cir::BreakOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(cir::BreakOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
op, adaptor.getOperands());
return mlir::success();
})
.Case<mlir::memref::AllocaScopeOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::memref::AllocaScopeReturnOp>(
op, adaptor.getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
}
};

class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
public:
using mlir::OpConversionPattern<cir::IfOp>::OpConversionPattern;
Expand Down Expand Up @@ -1519,24 +1539,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns
.add<CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
CIRFuncOpLowering, CIRBrCondOpLowering,
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
CIRTrapOpLowering>(converter, patterns.getContext());
patterns.add<
CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down Expand Up @@ -1610,7 +1629,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
mlir::ModuleOp theModule = getOperation();

auto converter = prepareTypeConverter();

mlir::RewritePatternSet patterns(&getContext());

populateCIRLoopToSCFConversionPatterns(patterns, converter);
Expand All @@ -1628,10 +1647,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
// cir dialect, for example the `cir.continue`. If we marked cir as illegal
// here, then MLIR would think any remaining `cir.continue` indicates a
// failure, which is not what we want.

patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);

if (mlir::failed(mlir::applyPartialConversion(theModule, target,
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
CIRYieldOpLowering, CIRBreakOpLowering>(converter, context);

if (mlir::failed(mlir::applyPartialConversion(theModule, target,
std::move(patterns)))) {
signalPassFailure();
}
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void while_break() {
int i = 0;
while (i < 100) {
i++;
break;
i++;
}
// This should be compiled into the condition `i < 100` and a single `i++`,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks way more than lowering, you are really doing an "optimization" here. Please make sure you document that behavior in code, because it's super surprising.

Btw, we should be thinking on how to improve the CIR representation such that it can more easily be lowered, there's a lot of heavy lifting needed to lower to MLIR, we can probably do better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions! the function rewriteBreak got renamed and comments are added to its matchAndRewrite function.

// without the while-loop.

// CHECK: memref.alloca_scope {
// CHECK: %[[IV:.+]] = memref.load %alloca[]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]]
// CHECK: memref.alloca_scope {
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]]
// CHECK: memref.store %[[INCR]], %alloca[]
// CHECK: }
// CHECK: }
}
Loading