Skip to content

[CIR] Record type constraints #1834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3010,7 +3010,7 @@ def CIR_ExtractMemberOp : CIR_Op<"extract_member", [Pure]> {
```
}];

let arguments = (ins CIRRecordType:$record, IndexAttr:$index_attr);
let arguments = (ins CIR_AnyRecordType:$record, IndexAttr:$index_attr);
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -3075,9 +3075,13 @@ def CIR_InsertMemberOp : CIR_Op<"insert_member", [
```
}];

let arguments = (ins CIRRecordType:$record, IndexAttr:$index_attr,
CIR_AnyType:$value);
let results = (outs CIRRecordType:$result);
let arguments = (ins
CIR_AnyRecordType:$record,
IndexAttr:$index_attr,
CIR_AnyType:$value
);

let results = (outs CIR_AnyRecordType:$result);

let builders = [
OpBuilder<(ins "mlir::Value":$record, "uint64_t":$index,
Expand Down
5 changes: 0 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -770,11 +770,6 @@ def CIR_RecordType : CIR_Type<"Record", "record", [
let hasCustomAssemblyFormat = 1;
}

// Note CIRRecordType is used instead of CIR_RecordType
// because of tablegen conflicts.
def CIRRecordType : Type<
CPred<"::mlir::isa<::cir::RecordType>($_self)">, "CIR record type">;

//===----------------------------------------------------------------------===//
// Global type constraints
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3746,7 +3746,7 @@ LogicalResult cir::GetMemberOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult cir::ExtractMemberOp::verify() {
auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
cir::RecordType recordTy = getRecord().getType();
if (recordTy.getKind() == cir::RecordType::Union)
return emitError()
<< "cir.extract_member currently does not work on unions";
Expand All @@ -3762,7 +3762,7 @@ LogicalResult cir::ExtractMemberOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult cir::InsertMemberOp::verify() {
auto recordTy = mlir::cast<cir::RecordType>(getRecord().getType());
cir::RecordType recordTy = getRecord().getType();
if (recordTy.getKind() == cir::RecordType::Union)
return emitError() << "cir.update_member currently does not work on unions";
if (recordTy.getMembers().size() <= getIndex())
Expand Down
2 changes: 0 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,11 @@ class CIRCXXABI {
/// Lower the given cir.base_method op to a sequence of more "primitive" CIR
/// operations that act on the ABI types.
virtual mlir::Value lowerBaseMethod(cir::BaseMethodOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

/// Lower the given cir.derived_method op to a sequence of more "primitive"
/// CIR operations that act on the ABI types.
virtual mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
Expand Down
32 changes: 21 additions & 11 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "CIRCXXABI.h"
#include "LowerModule.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "llvm/Support/ErrorHandling.h"

namespace cir {
Expand Down Expand Up @@ -99,11 +101,10 @@ class ItaniumCXXABI : public CIRCXXABI {
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerBaseMethod(cir::BaseMethodOp op, mlir::Value loweredSrc,
mlir::Value lowerBaseMethod(cir::BaseMethodOp op,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
Expand Down Expand Up @@ -472,24 +473,35 @@ static mlir::Value lowerDataMemberCast(mlir::Operation *op,
isNull, nullValue, adjustedPtr);
}

static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value loweredSrc,
static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value src,
std::int64_t offset, bool isDerivedToBase,
LowerModule &lowerMod,
mlir::OpBuilder &builder) {
if (offset == 0)
return loweredSrc;
return src;

if (auto load = mlir::dyn_cast<cir::LoadOp>(src.getDefiningOp())) {
// If the source is a load of method, we can just adjust the base pointer.
load->dump();
load->getParentOp()->dump();
// src = load.getAddr();
llvm_unreachable("NYI: ItaniumCXXABI::lowerMethodCast for cir::LoadOp");
}

if (!mlir::isa<cir::RecordType>(src.getType()))
llvm_unreachable("Expected a record type for method pointer");

cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lowerMod);
auto adjField = builder.create<cir::ExtractMemberOp>(
op->getLoc(), ptrdiffCIRTy, loweredSrc, 1);
auto adjField =
builder.create<cir::ExtractMemberOp>(op->getLoc(), ptrdiffCIRTy, src, 1);

auto offsetValue = builder.create<cir::ConstantOp>(
op->getLoc(), cir::IntAttr::get(ptrdiffCIRTy, offset));
auto binOpKind = isDerivedToBase ? cir::BinOpKind::Sub : cir::BinOpKind::Add;
auto adjustedAdjField = builder.create<cir::BinOp>(
op->getLoc(), ptrdiffCIRTy, binOpKind, adjField, offsetValue);

return builder.create<cir::InsertMemberOp>(op->getLoc(), loweredSrc, 1,
return builder.create<cir::InsertMemberOp>(op->getLoc(), src, 1,
adjustedAdjField);
}

Expand All @@ -509,16 +521,14 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
}

mlir::Value ItaniumCXXABI::lowerBaseMethod(cir::BaseMethodOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
return lowerMethodCast(op, op.getSrc(), op.getOffset().getSExtValue(),
/*isDerivedToBase=*/true, LM, builder);
}

mlir::Value ItaniumCXXABI::lowerDerivedMethod(cir::DerivedMethodOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
return lowerMethodCast(op, op.getSrc(), op.getOffset().getSExtValue(),
/*isDerivedToBase=*/false, LM, builder);
}

Expand Down
39 changes: 5 additions & 34 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ mlir::LogicalResult CIRToLLVMBaseMethodOpLowering::matchAndRewrite(
cir::BaseMethodOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Value loweredResult =
lowerMod->getCXXABI().lowerBaseMethod(op, adaptor.getSrc(), rewriter);
lowerMod->getCXXABI().lowerBaseMethod(op, rewriter);
rewriter.replaceOp(op, loweredResult);
return mlir::success();
}
Expand All @@ -1159,7 +1159,7 @@ mlir::LogicalResult CIRToLLVMDerivedMethodOpLowering::matchAndRewrite(
cir::DerivedMethodOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Value loweredResult =
lowerMod->getCXXABI().lowerDerivedMethod(op, adaptor.getSrc(), rewriter);
lowerMod->getCXXABI().lowerDerivedMethod(op, rewriter);
rewriter.replaceOp(op, loweredResult);
return mlir::success();
}
Expand Down Expand Up @@ -3746,44 +3746,15 @@ mlir::LogicalResult CIRToLLVMExtractMemberOpLowering::matchAndRewrite(
cir::ExtractMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
std::int64_t indecies[1] = {static_cast<std::int64_t>(op.getIndex())};

mlir::Type recordTy = op.getRecord().getType();
if (auto llvmStructTy =
mlir::dyn_cast<mlir::LLVM::LLVMStructType>(recordTy)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
op, adaptor.getRecord(), indecies);
return mlir::success();
}

auto cirRecordTy = mlir::cast<cir::RecordType>(recordTy);
switch (cirRecordTy.getKind()) {
case cir::RecordType::Struct:
case cir::RecordType::Class: {
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
op, adaptor.getRecord(), indecies);
return mlir::success();
}

case cir::RecordType::Union: {
op.emitError("cir.extract_member cannot extract member from a union");
return mlir::failure();
}
}
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
op, adaptor.getRecord(), indecies);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMInsertMemberOpLowering::matchAndRewrite(
cir::InsertMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
std::int64_t indecies[1] = {static_cast<std::int64_t>(op.getIndex())};
mlir::Type recordTy = op.getRecord().getType();

if (auto cirRecordTy = mlir::dyn_cast<cir::RecordType>(recordTy)) {
if (cirRecordTy.getKind() == cir::RecordType::Union) {
op.emitError("cir.update_member cannot update member of a union");
return mlir::failure();
}
}

rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
op, adaptor.getRecord(), adaptor.getValue(), indecies);
return mlir::success();
Expand Down
Loading