Skip to content

Commit e8517f6

Browse files
[CIR][MLIR][LoweringThroughMLIR] CIRUnaryOpLowering on float values.
UnaryOpKind Inc, Dec, Plus and Minus can accept float operands, the lowering should also handle those situations.
1 parent 4037fb6 commit e8517f6

File tree

3 files changed

+69
-19
lines changed

3 files changed

+69
-19
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,36 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
802802
public:
803803
using OpConversionPattern<cir::UnaryOp>::OpConversionPattern;
804804

805+
mlir::Operation *
806+
addImmediate(cir::UnaryOp op, mlir::Type type, mlir::Value input, int64_t n,
807+
mlir::ConversionPatternRewriter &rewriter) const {
808+
if (type.isFloat()) {
809+
auto imm = mlir::arith::ConstantOp::create(rewriter, op.getLoc(),
810+
mlir::FloatAttr::get(type, n));
811+
return rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, type, input,
812+
imm);
813+
}
814+
auto imm = mlir::arith::ConstantOp::create(rewriter, op.getLoc(),
815+
mlir::IntegerAttr::get(type, n));
816+
return rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input,
817+
imm);
818+
}
819+
820+
mlir::Operation *
821+
subByImmediate(cir::UnaryOp op, mlir::Type type, mlir::Value input, int64_t n,
822+
mlir::ConversionPatternRewriter &rewriter) const {
823+
if (type.isFloat()) {
824+
auto imm = mlir::arith::ConstantOp::create(rewriter, op.getLoc(),
825+
mlir::FloatAttr::get(type, n));
826+
return rewriter.replaceOpWithNewOp<mlir::arith::SubFOp>(op, type, imm,
827+
input);
828+
}
829+
auto imm = mlir::arith::ConstantOp::create(rewriter, op.getLoc(),
830+
mlir::IntegerAttr::get(type, n));
831+
return rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, imm,
832+
input);
833+
}
834+
805835
mlir::LogicalResult
806836
matchAndRewrite(cir::UnaryOp op, OpAdaptor adaptor,
807837
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -810,36 +840,28 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
810840

811841
switch (op.getKind()) {
812842
case cir::UnaryOpKind::Inc: {
813-
auto One = mlir::arith::ConstantOp::create(
814-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
815-
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input, One);
843+
addImmediate(op, type, input, 1, rewriter);
816844
break;
817845
}
818846
case cir::UnaryOpKind::Dec: {
819-
auto One = mlir::arith::ConstantOp::create(
820-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
821-
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, input, One);
847+
addImmediate(op, type, input, -1, rewriter);
822848
break;
823849
}
824850
case cir::UnaryOpKind::Plus: {
825851
rewriter.replaceOp(op, op.getInput());
826852
break;
827853
}
828854
case cir::UnaryOpKind::Minus: {
829-
auto Zero = mlir::arith::ConstantOp::create(
830-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 0));
831-
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, Zero, input);
855+
subByImmediate(op, type, input, 0, rewriter);
832856
break;
833857
}
834858
case cir::UnaryOpKind::Not: {
835-
auto MinusOne = mlir::arith::ConstantOp::create(
859+
auto o = mlir::arith::ConstantOp::create(
836860
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, -1));
837-
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, MinusOne,
838-
input);
861+
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, o, input);
839862
break;
840863
}
841864
}
842-
843865
return mlir::LogicalResult::success();
844866
}
845867
};

clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
2-
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
1+
// RUN: cir-opt %s --cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s --cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
33

44
!s32i = !cir.int<s, 32>
55
module {
@@ -17,14 +17,32 @@ module {
1717
%5 = cir.load %1 : !cir.ptr<!s32i>, !s32i
1818
%6 = cir.unary(dec, %5) : !s32i, !s32i
1919
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>
20+
21+
// test float
22+
%7 = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
2023
cir.return
2124
}
22-
}
2325

2426
// MLIR: = arith.constant 1
2527
// MLIR: = arith.addi
26-
// MLIR: = arith.constant 1
27-
// MLIR: = arith.subi
28+
// MLIR: = arith.constant -1
29+
// MLIR: = arith.addi
2830

2931
// LLVM: = add i32 %[[#]], 1
30-
// LLVM: = sub i32 %[[#]], 1
32+
// LLVM: = add i32 %[[#]], -1
33+
34+
35+
cir.func @floatingPoints(%arg0: !cir.double) {
36+
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
37+
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
38+
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
39+
%2 = cir.unary(inc, %1) : !cir.double, !cir.double
40+
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
41+
%4 = cir.unary(dec, %3) : !cir.double, !cir.double
42+
cir.return
43+
}
44+
// MLIR: = arith.constant 1.0
45+
// MLIR: = arith.addf
46+
// MLIR: = arith.constant -1.0
47+
// MLIR: = arith.addf
48+
}

clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ module {
1919
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>
2020
cir.return
2121
}
22+
23+
cir.func @floatingPoints(%arg0: !cir.double) {
24+
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
25+
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
26+
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
27+
%2 = cir.unary(plus, %1) : !cir.double, !cir.double
28+
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
29+
%4 = cir.unary(minus, %3) : !cir.double, !cir.double
30+
cir.return
31+
}
2232
}
2333

2434
// MLIR: %[[#INPUT_PLUS:]] = memref.load

0 commit comments

Comments
 (0)