From e57e967a1656500e35b2e64c91c9695efc920a54 Mon Sep 17 00:00:00 2001 From: xlauko Date: Fri, 30 May 2025 09:29:34 +0200 Subject: [PATCH] XXX --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 12 ++++-- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 5 --- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 4 +- .../Transforms/TargetLowering/CIRCXXABI.h | 2 - .../TargetLowering/ItaniumCXXABI.cpp | 32 +++++++++------ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 39 +++---------------- 6 files changed, 36 insertions(+), 58 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 72ff40e10b9c..7f187641800a 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3010,7 +3010,7 @@ def CIR_ExtractMemberOp : CIR_Op<"extract_member", [Pure]> { ``` }]; - let arguments = (ins CIRRecordType:$record, IndexAttr:$index_attr); + let arguments = (ins CIR_AnyRecordType:$record, IndexAttr:$index_attr); let results = (outs CIR_AnyType:$result); let assemblyFormat = [{ @@ -3075,9 +3075,13 @@ def CIR_InsertMemberOp : CIR_Op<"insert_member", [ ``` }]; - let arguments = (ins CIRRecordType:$record, IndexAttr:$index_attr, - CIR_AnyType:$value); - let results = (outs CIRRecordType:$result); + let arguments = (ins + CIR_AnyRecordType:$record, + IndexAttr:$index_attr, + CIR_AnyType:$value + ); + + let results = (outs CIR_AnyRecordType:$result); let builders = [ OpBuilder<(ins "mlir::Value":$record, "uint64_t":$index, diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 830c8f9b82fc..d4ee0c53e6c8 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -770,11 +770,6 @@ def CIR_RecordType : CIR_Type<"Record", "record", [ let hasCustomAssemblyFormat = 1; } -// Note CIRRecordType is used instead of CIR_RecordType -// because of tablegen conflicts. -def CIRRecordType : Type< - CPred<"::mlir::isa<::cir::RecordType>($_self)">, "CIR record type">; - //===----------------------------------------------------------------------===// // Global type constraints //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index b499796c00b4..4230cf624e20 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -3746,7 +3746,7 @@ LogicalResult cir::GetMemberOp::verify() { //===----------------------------------------------------------------------===// LogicalResult cir::ExtractMemberOp::verify() { - auto recordTy = mlir::cast(getRecord().getType()); + cir::RecordType recordTy = getRecord().getType(); if (recordTy.getKind() == cir::RecordType::Union) return emitError() << "cir.extract_member currently does not work on unions"; @@ -3762,7 +3762,7 @@ LogicalResult cir::ExtractMemberOp::verify() { //===----------------------------------------------------------------------===// LogicalResult cir::InsertMemberOp::verify() { - auto recordTy = mlir::cast(getRecord().getType()); + cir::RecordType recordTy = getRecord().getType(); if (recordTy.getKind() == cir::RecordType::Union) return emitError() << "cir.update_member currently does not work on unions"; if (recordTy.getMembers().size() <= getIndex()) diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index ec62ed1dbacf..9352a89206b5 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -121,13 +121,11 @@ class CIRCXXABI { /// Lower the given cir.base_method op to a sequence of more "primitive" CIR /// operations that act on the ABI types. virtual mlir::Value lowerBaseMethod(cir::BaseMethodOp op, - mlir::Value loweredSrc, mlir::OpBuilder &builder) const = 0; /// Lower the given cir.derived_method op to a sequence of more "primitive" /// CIR operations that act on the ABI types. virtual mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op, - mlir::Value loweredSrc, mlir::OpBuilder &builder) const = 0; virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index 17ff1433bf4a..e17c033f28f6 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -24,6 +24,8 @@ #include "CIRCXXABI.h" #include "LowerModule.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/IR/CIRTypes.h" #include "llvm/Support/ErrorHandling.h" namespace cir { @@ -99,11 +101,10 @@ class ItaniumCXXABI : public CIRCXXABI { mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; - mlir::Value lowerBaseMethod(cir::BaseMethodOp op, mlir::Value loweredSrc, + mlir::Value lowerBaseMethod(cir::BaseMethodOp op, mlir::OpBuilder &builder) const override; mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op, - mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, @@ -472,16 +473,27 @@ static mlir::Value lowerDataMemberCast(mlir::Operation *op, isNull, nullValue, adjustedPtr); } -static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value loweredSrc, +static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value src, std::int64_t offset, bool isDerivedToBase, LowerModule &lowerMod, mlir::OpBuilder &builder) { if (offset == 0) - return loweredSrc; + return src; + + if (auto load = mlir::dyn_cast(src.getDefiningOp())) { + // If the source is a load of method, we can just adjust the base pointer. + load->dump(); + load->getParentOp()->dump(); + // src = load.getAddr(); + llvm_unreachable("NYI: ItaniumCXXABI::lowerMethodCast for cir::LoadOp"); + } + + if (!mlir::isa(src.getType())) + llvm_unreachable("Expected a record type for method pointer"); cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lowerMod); - auto adjField = builder.create( - op->getLoc(), ptrdiffCIRTy, loweredSrc, 1); + auto adjField = + builder.create(op->getLoc(), ptrdiffCIRTy, src, 1); auto offsetValue = builder.create( op->getLoc(), cir::IntAttr::get(ptrdiffCIRTy, offset)); @@ -489,7 +501,7 @@ static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value loweredSrc, auto adjustedAdjField = builder.create( op->getLoc(), ptrdiffCIRTy, binOpKind, adjField, offsetValue); - return builder.create(op->getLoc(), loweredSrc, 1, + return builder.create(op->getLoc(), src, 1, adjustedAdjField); } @@ -509,16 +521,14 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op, } mlir::Value ItaniumCXXABI::lowerBaseMethod(cir::BaseMethodOp op, - mlir::Value loweredSrc, mlir::OpBuilder &builder) const { - return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(), + return lowerMethodCast(op, op.getSrc(), op.getOffset().getSExtValue(), /*isDerivedToBase=*/true, LM, builder); } mlir::Value ItaniumCXXABI::lowerDerivedMethod(cir::DerivedMethodOp op, - mlir::Value loweredSrc, mlir::OpBuilder &builder) const { - return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(), + return lowerMethodCast(op, op.getSrc(), op.getOffset().getSExtValue(), /*isDerivedToBase=*/false, LM, builder); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index b7e144c4b65f..8e11752eff37 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1150,7 +1150,7 @@ mlir::LogicalResult CIRToLLVMBaseMethodOpLowering::matchAndRewrite( cir::BaseMethodOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { mlir::Value loweredResult = - lowerMod->getCXXABI().lowerBaseMethod(op, adaptor.getSrc(), rewriter); + lowerMod->getCXXABI().lowerBaseMethod(op, rewriter); rewriter.replaceOp(op, loweredResult); return mlir::success(); } @@ -1159,7 +1159,7 @@ mlir::LogicalResult CIRToLLVMDerivedMethodOpLowering::matchAndRewrite( cir::DerivedMethodOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { mlir::Value loweredResult = - lowerMod->getCXXABI().lowerDerivedMethod(op, adaptor.getSrc(), rewriter); + lowerMod->getCXXABI().lowerDerivedMethod(op, rewriter); rewriter.replaceOp(op, loweredResult); return mlir::success(); } @@ -3746,44 +3746,15 @@ mlir::LogicalResult CIRToLLVMExtractMemberOpLowering::matchAndRewrite( cir::ExtractMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { std::int64_t indecies[1] = {static_cast(op.getIndex())}; - - mlir::Type recordTy = op.getRecord().getType(); - if (auto llvmStructTy = - mlir::dyn_cast(recordTy)) { - rewriter.replaceOpWithNewOp( - op, adaptor.getRecord(), indecies); - return mlir::success(); - } - - auto cirRecordTy = mlir::cast(recordTy); - switch (cirRecordTy.getKind()) { - case cir::RecordType::Struct: - case cir::RecordType::Class: { - rewriter.replaceOpWithNewOp( - op, adaptor.getRecord(), indecies); - return mlir::success(); - } - - case cir::RecordType::Union: { - op.emitError("cir.extract_member cannot extract member from a union"); - return mlir::failure(); - } - } + rewriter.replaceOpWithNewOp( + op, adaptor.getRecord(), indecies); + return mlir::success(); } mlir::LogicalResult CIRToLLVMInsertMemberOpLowering::matchAndRewrite( cir::InsertMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { std::int64_t indecies[1] = {static_cast(op.getIndex())}; - mlir::Type recordTy = op.getRecord().getType(); - - if (auto cirRecordTy = mlir::dyn_cast(recordTy)) { - if (cirRecordTy.getKind() == cir::RecordType::Union) { - op.emitError("cir.update_member cannot update member of a union"); - return mlir::failure(); - } - } - rewriter.replaceOpWithNewOp( op, adaptor.getRecord(), adaptor.getValue(), indecies); return mlir::success();