@@ -1496,29 +1496,118 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
1496
1496
}
1497
1497
};
1498
1498
1499
+ class CIRSwitchOpLowering : public mlir ::OpConversionPattern<cir::SwitchOp> {
1500
+ public:
1501
+ using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1502
+
1503
+ mlir::LogicalResult
1504
+ matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1505
+ mlir::ConversionPatternRewriter &rewriter) const override {
1506
+ rewriter.setInsertionPointAfter (op);
1507
+ llvm::SmallVector<CaseOp> cases;
1508
+ if (!op.isSimpleForm (cases))
1509
+ mlir::emitError (op.getLoc (), " not yet implemented" );
1510
+
1511
+ llvm::SmallVector<int64_t > caseValues;
1512
+ // Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1513
+ // This is necessary because some CaseOp might carry 0 or multiple values.
1514
+ llvm::DenseMap<size_t , unsigned > indexMap;
1515
+ caseValues.reserve (cases.size ());
1516
+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1517
+ switch (caseOp.getKind ()) {
1518
+ case CaseOpKind::Equal: {
1519
+ auto valueAttr = caseOp.getValue ()[0 ];
1520
+ auto value = cast<cir::IntAttr>(valueAttr);
1521
+ indexMap[i] = caseValues.size ();
1522
+ caseValues.push_back (value.getUInt ());
1523
+ break ;
1524
+ }
1525
+ case CaseOpKind::Default:
1526
+ break ;
1527
+ case CaseOpKind::Range:
1528
+ case CaseOpKind::Anyof:
1529
+ mlir::emitError (op.getLoc (), " not yet implemented" );
1530
+ }
1531
+ }
1532
+
1533
+ auto operand = adaptor.getOperands ()[0 ];
1534
+ // `scf.index_switch` expects an index of type `index`.
1535
+ auto indexType = mlir::IndexType::get (getContext ());
1536
+ auto indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1537
+ op.getLoc (), indexType, operand);
1538
+ auto indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1539
+ op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1540
+
1541
+ bool metDefault = false ;
1542
+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1543
+ auto ®ion = caseOp.getRegion ();
1544
+ switch (caseOp.getKind ()) {
1545
+ case CaseOpKind::Equal: {
1546
+ auto &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1547
+ rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1548
+ break ;
1549
+ }
1550
+ case CaseOpKind::Default: {
1551
+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1552
+ rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1553
+ metDefault = true ;
1554
+ break ;
1555
+ }
1556
+ case CaseOpKind::Range:
1557
+ case CaseOpKind::Anyof:
1558
+ llvm_unreachable (" NYI" );
1559
+ }
1560
+ }
1561
+
1562
+ // `scf.index_switch` expects its default region to contain exactly one
1563
+ // block. If we don't have a default region in `cir.switch`, we need to
1564
+ // supply it here.
1565
+ if (!metDefault) {
1566
+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1567
+ mlir::Block *block =
1568
+ rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1569
+ rewriter.setInsertionPointToEnd (block);
1570
+ rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1571
+ }
1572
+
1573
+ // The final `cir.break` should be replaced to `scf.yield`.
1574
+ // After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1575
+ for (auto ®ion : indexSwitch.getCaseRegions ()) {
1576
+ auto &lastBlock = region.back ();
1577
+ auto &lastOp = lastBlock.back ();
1578
+ assert (isa<BreakOp>(lastOp));
1579
+ rewriter.setInsertionPointAfter (&lastOp);
1580
+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1581
+ }
1582
+
1583
+ rewriter.replaceOp (op, indexSwitch);
1584
+
1585
+ return mlir::success ();
1586
+ }
1587
+ };
1588
+
1499
1589
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1500
1590
mlir::TypeConverter &converter) {
1501
1591
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1502
1592
1503
- patterns
1504
- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1505
- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1506
- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1507
- CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1508
- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1509
- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1510
- CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1511
- CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1512
- CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1513
- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1514
- CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1515
- CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1516
- CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1517
- CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1518
- CIRVectorInsertLowering, CIRVectorExtractLowering,
1519
- CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1520
- CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1521
- converter, patterns.getContext ());
1593
+ patterns.add <
1594
+ CIRSwitchOpLowering, CIRGetElementOpLowering, CIRATanOpLowering,
1595
+ CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1596
+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1597
+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1598
+ CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1599
+ CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1600
+ CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1601
+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1602
+ CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1603
+ CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1604
+ CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1605
+ CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1606
+ CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1607
+ CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1608
+ CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1609
+ CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1610
+ CIRTrapOpLowering>(converter, patterns.getContext ());
1522
1611
}
1523
1612
1524
1613
static mlir::TypeConverter prepareTypeConverter () {
@@ -1624,6 +1713,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
1624
1713
1625
1714
mlir::PassManager pm (mlirCtx);
1626
1715
1716
+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
1627
1717
pm.addPass (createConvertCIRToMLIRPass ());
1628
1718
pm.addPass (createConvertMLIRToLLVMPass ());
1629
1719
@@ -1669,6 +1759,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
1669
1759
1670
1760
mlir::PassManager pm (mlirCtx);
1671
1761
1762
+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
1672
1763
pm.addPass (createConvertCIRToMLIRPass ());
1673
1764
1674
1765
auto result = !mlir::failed (pm.run (theModule));
0 commit comments