Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace cir;
using namespace llvm;
Expand Down Expand Up @@ -570,6 +571,70 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
}
}

void rewriteBreak(mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
// Collect all BreakOp inside this while.
llvm::SmallVector<cir::BreakOp> breaks;
whileOp->walk([&](mlir::Operation *op) {
if (auto breakOp = dyn_cast<BreakOp>(op))
breaks.push_back(breakOp);
});
if (breaks.empty())
return;
auto *pp = whileOp->getParentOp();
pp->dump();
for (auto breakOp : breaks) {
// When there is another loop between this WhileOp and the BreakOp,
// we should change that loop instead.
if (breakOp->getParentOfType<mlir::scf::WhileOp>() != whileOp)
continue;
// Similar to the case of ContinueOp, when there is an `IfOp`,
// we need to take special care.
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
parent = parent->getParentOp()) {
if (auto ifOp = dyn_cast<cir::IfOp>(parent))
llvm_unreachable("NYI");
}
// Operations after this BreakOp has to be removed.
for (mlir::Operation *runner = breakOp->getNextNode(); runner;) {
mlir::Operation *next = runner->getNextNode();
runner->erase();
runner = next;
}

// Blocks after this BreakOp also has to be removed.
for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) {
mlir::Block *next = block->getNextNode();
block->erase();
block = next;
}

// We know this BreakOp isn't nested in any IfOp.
// Therefore, the loop is executed only once.
// We pull everything out of the loop.
auto &beforeOps = whileOp.getBeforeBody()->getOperations();
for (mlir::Operation *op = &*beforeOps.begin(); op;) {
if (isa<ConditionOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}

auto &afterOps = whileOp.getAfterBody()->getOperations();
for (mlir::Operation *op = &*afterOps.begin(); op;) {
if (isa<YieldOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}
}

rewriter.eraseOp(whileOp);
pp->dump();
}

public:
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;

Expand All @@ -579,6 +644,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
SCFWhileLoop loop(op, adaptor, &rewriter);
auto whileOp = loop.transferToSCFWhileOp();
rewriteContinue(whileOp, rewriter);
rewriteBreak(whileOp, rewriter);
rewriter.eraseOp(op);
return mlir::success();
}
Expand Down
76 changes: 48 additions & 28 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,17 @@
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/LowerToMLIR.h"
#include "clang/CIR/LoweringHelpers.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TimeProfiler.h"

using namespace cir;
Expand Down Expand Up @@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
} else {
// For scopes with results, use scf.execute_region
SmallVector<mlir::Type> types;
if (mlir::failed(
getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types)))
if (mlir::failed(getTypeConverter()->convertTypes(
scopeOp->getResultTypes(), types)))
return mlir::failure();
auto exec =
rewriter.create<mlir::scf::ExecuteRegionOp>(scopeOp.getLoc(), types);
Expand Down Expand Up @@ -1023,6 +1021,28 @@ class CIRYieldOpLowering : public mlir::OpConversionPattern<cir::YieldOp> {
}
};

class CIRBreakOpLowering : public mlir::OpConversionPattern<cir::BreakOp> {
public:
using OpConversionPattern<cir::BreakOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(cir::BreakOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
op, adaptor.getOperands());
return mlir::success();
})
.Case<mlir::memref::AllocaScopeOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::memref::AllocaScopeReturnOp>(
op, adaptor.getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
}
};

class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
public:
using mlir::OpConversionPattern<cir::IfOp>::OpConversionPattern;
Expand Down Expand Up @@ -1519,24 +1539,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns
.add<CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
CIRFuncOpLowering, CIRBrCondOpLowering,
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
CIRTrapOpLowering>(converter, patterns.getContext());
patterns.add<
CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down Expand Up @@ -1610,7 +1629,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
mlir::ModuleOp theModule = getOperation();

auto converter = prepareTypeConverter();

mlir::RewritePatternSet patterns(&getContext());

populateCIRLoopToSCFConversionPatterns(patterns, converter);
Expand All @@ -1628,10 +1647,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
// cir dialect, for example the `cir.continue`. If we marked cir as illegal
// here, then MLIR would think any remaining `cir.continue` indicates a
// failure, which is not what we want.

patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);

if (mlir::failed(mlir::applyPartialConversion(theModule, target,
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
CIRYieldOpLowering, CIRBreakOpLowering>(converter, context);

if (mlir::failed(mlir::applyPartialConversion(theModule, target,
std::move(patterns)))) {
signalPassFailure();
}
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void while_break() {
int i = 0;
while (i < 100) {
i++;
break;
i++;
}
// This should be compiled into the condition `i < 100` and a single `i++`,
// without the while-loop.

// CHECK: memref.alloca_scope {
// CHECK: %[[IV:.+]] = memref.load %alloca[]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was looking at the generated IR and I'm not sure this works right, the loop is gone but you still need to increment i, what am I missing here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the 2nd i++ (operations after break) that got removed, the 1st i++ is kept.

  while (i < 100) {
    i++;
    break;
    i++;
  }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got that but I'm still not seeing a loop there and I'm curious what's going on

// CHECK: memref.alloca_scope {
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]]
// CHECK: memref.store %[[INCR]], %alloca[]
// CHECK: }
// CHECK: }
}
Loading