From 55a307209ce710ba18b2bff00353e483276b3cc7 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 7 Aug 2025 13:40:49 +0000 Subject: [PATCH 1/2] [AArch64][SME] Port all SME routines to RuntimeLibcalls This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention. Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so tweaked the output to avoid the need for variables. --- llvm/include/llvm/CodeGen/TargetLowering.h | 6 ++ llvm/include/llvm/IR/RuntimeLibcalls.td | 43 +++++++++++- llvm/include/llvm/IR/RuntimeLibcallsImpl.td | 3 + .../Target/AArch64/AArch64FrameLowering.cpp | 16 +++-- .../Target/AArch64/AArch64ISelLowering.cpp | 40 +++++------ .../AArch64/AArch64TargetTransformInfo.cpp | 18 ++--- llvm/lib/Target/AArch64/SMEABIPass.cpp | 31 ++++++--- .../AArch64/Utils/AArch64SMEAttributes.cpp | 39 ++++++++--- .../AArch64/Utils/AArch64SMEAttributes.h | 18 ++--- .../RuntimeLibcallEmitter-calling-conv.td | 64 +++++------------- .../RuntimeLibcallEmitter-conflict-warning.td | 14 ++-- llvm/test/TableGen/RuntimeLibcallEmitter.td | 66 +++++++------------ .../Target/AArch64/SMEAttributesTest.cpp | 2 +- .../TableGen/Basic/RuntimeLibcallsEmitter.cpp | 33 +++++----- 14 files changed, 213 insertions(+), 180 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 172b01a649810..a5bffe401f1e6 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3558,6 +3558,12 @@ class LLVM_ABI TargetLoweringBase { return Libcalls.getLibcallImplName(Call); } + /// Check if this is valid libcall for the current module, otherwise + /// RTLIB::Unsupported. + RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const { + return Libcalls.getSupportedLibcallImpl(FuncName); + } + const char *getMemcpyName() const { return Libcalls.getMemcpyName(); } /// Get the comparison predicate that's to be used to test the result of the diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.td b/llvm/include/llvm/IR/RuntimeLibcalls.td index eadf3eae38923..4395113377adc 100644 --- a/llvm/include/llvm/IR/RuntimeLibcalls.td +++ b/llvm/include/llvm/IR/RuntimeLibcalls.td @@ -406,6 +406,17 @@ multiclass LibmLongDoubleLibCall AArch64LibcallImpls = { def __arm_sc_memcpy : RuntimeLibcallImpl; def __arm_sc_memmove : RuntimeLibcallImpl; def __arm_sc_memset : RuntimeLibcallImpl; + def __arm_sc_memchr : RuntimeLibcallImpl; } // End AArch64LibcallImpls +def __arm_sme_state : RuntimeLibcallImpl; +def __arm_tpidr2_save : RuntimeLibcallImpl; +def __arm_za_disable : RuntimeLibcallImpl; +def __arm_tpidr2_restore : RuntimeLibcallImpl; +def __arm_get_current_vg : RuntimeLibcallImpl; +def __arm_sme_state_size : RuntimeLibcallImpl; +def __arm_sme_save : RuntimeLibcallImpl; +def __arm_sme_restore : RuntimeLibcallImpl; + +def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add + __arm_tpidr2_save, + __arm_za_disable, + __arm_tpidr2_restore), + SMEABI_PreserveMost_From_X0>; + +def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add + __arm_get_current_vg, + __arm_sme_state_size, + __arm_sme_save, + __arm_sme_restore), + SMEABI_PreserveMost_From_X1>; + +def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add + __arm_sme_state), + SMEABI_PreserveMost_From_X2>; + def isAArch64_ExceptArm64EC : RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">; def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">; @@ -1245,7 +1283,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary< LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128, DefaultLibmExp10, DefaultStackProtector, - SecurityCheckCookieIfWinMSVC) + SecurityCheckCookieIfWinMSVC, + SMEABI_LibCalls_PreserveMost_From_X0, + SMEABI_LibCalls_PreserveMost_From_X1, + SMEABI_LibCalls_PreserveMost_From_X2) >; // Prepend a # to every name diff --git a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td index 601c291daf89d..b5752c1b69ad8 100644 --- a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td +++ b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td @@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>; def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>; def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>; def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>; +def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>; +def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>; +def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>; /// Abstract definition for functionality the compiler may need to /// emit a call to. Emits the RTLIB::Libcall enum - This enum defines diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index 885f2a94f85f5..ba02c82b25aaf 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) { if (Opc == AArch64::BL) { auto Op1 = MBBI->getOperand(0); - return Op1.isSymbol() && - (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg"); + auto &TLI = + *MBBI->getMF()->getSubtarget().getTargetLowering(); + char const *GetCurrentVG = + TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG); + return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG; } } @@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, ArrayRef CSI, const TargetRegisterInfo *TRI) const { MachineFunction &MF = *MBB.getParent(); + auto &TLI = *MF.getSubtarget().getTargetLowering(); const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); AArch64FunctionInfo *AFI = MF.getInfo(); bool NeedsWinCFI = needsWinCFI(MF); @@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( .addReg(AArch64::X0, RegState::Implicit) .setMIFlag(MachineInstr::FrameSetup); - const uint32_t *RegMask = TRI->getCallPreservedMask( - MF, - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1); + RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG; + const uint32_t *RegMask = + TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC)); BuildMI(MBB, MI, DL, TII.get(AArch64::BL)) - .addExternalSymbol("__arm_get_current_vg") + .addExternalSymbol(TLI.getLibcallName(LC)) .addRegMask(RegMask) .addReg(AArch64::X0, RegState::ImplicitDefine) .setMIFlag(MachineInstr::FrameSetup); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 4681bfaa55476..cbffe40dec604 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, AArch64FunctionInfo *FuncInfo = MF->getInfo(); const TargetInstrInfo *TII = Subtarget->getInstrInfo(); if (FuncInfo->isSMESaveBufferUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state_size") + .addExternalSymbol(getLibcallName(LC)) .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg()) .addReg(AArch64::X0); @@ -5739,15 +5738,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) { SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL, EVT VT) const { - SDValue Callee = DAG.getExternalSymbol("__arm_sme_state", + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; + SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC), getPointerTy(DAG.getDataLayout())); Type *Int64Ty = Type::getInt64Ty(*DAG.getContext()); Type *RetTy = StructType::get(Int64Ty, Int64Ty); TargetLowering::CallLoweringInfo CLI(DAG); ArgListTy Args; CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2, - RetTy, Callee, std::move(Args)); + getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); std::pair CallResult = LowerCallTo(CLI); SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64); return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0), @@ -8600,12 +8599,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, } static SMECallAttrs -getSMECallAttrs(const Function &Caller, +getSMECallAttrs(const Function &Caller, const TargetLowering &TLI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) - return SMECallAttrs(*CLI.CB); + return SMECallAttrs(*CLI.CB, &TLI); if (auto *ES = dyn_cast(CLI.Callee)) - return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol())); + return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI)); return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); } @@ -8627,7 +8626,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // SME Streaming functions are not eligible for TCO as they may require // the streaming mode or ZA to be restored after returning from the call. - SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI); + SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || CallAttrs.caller().hasStreamingBody()) @@ -8921,14 +8920,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64); Args.push_back(Entry); - SDValue Callee = - DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore", - TLI.getPointerTy(DAG.getDataLayout())); + RTLIB::Libcall LC = + IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE; + SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC), + TLI.getPointerTy(DAG.getDataLayout())); auto *RetTy = Type::getVoidTy(*DAG.getContext()); TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy, - Callee, std::move(Args)); + TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); return TLI.LowerCallTo(CLI).second; } @@ -9116,7 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Determine whether we need any streaming mode changes. - SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI); + SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI); auto DescribeCallsite = [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & { @@ -9693,11 +9692,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (RequiresLazySave) { // Conditionally restore the lazy save using a pseudo node. + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE; TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); SDValue RegMask = DAG.getRegisterMask( - TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC))); SDValue RestoreRoutine = DAG.getTargetExternalSymbol( - "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout())); + getLibcallName(LC), getPointerTy(DAG.getDataLayout())); SDValue TPIDR2_EL0 = DAG.getNode( ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result, DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); @@ -29036,7 +29036,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { // Checks to allow the use of SME instructions if (auto *Base = dyn_cast(&Inst)) { - auto CallAttrs = SMECallAttrs(*Base); + auto CallAttrs = SMECallAttrs(*Base, this); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index c6233461be655..bcefad936d630 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -220,20 +220,16 @@ static cl::opt EnableFixedwidthAutovecInStreamingMode( static cl::opt EnableScalableAutovecInStreamingMode( "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden); -static bool isSMEABIRoutineCall(const CallInst &CI) { +static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); - return F && StringSwitch(F->getName()) - .Case("__arm_sme_state", true) - .Case("__arm_tpidr2_save", true) - .Case("__arm_tpidr2_restore", true) - .Case("__arm_za_disable", true) - .Default(false); + return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); } /// Returns true if the function has explicit operations that can only be /// lowered using incompatible instructions for the selected mode. This also /// returns true if the function F may use or modify ZA state. -static bool hasPossibleIncompatibleOps(const Function *F) { +static bool hasPossibleIncompatibleOps(const Function *F, + const TargetLowering &TLI) { for (const BasicBlock &BB : *F) { for (const Instruction &I : BB) { // Be conservative for now and assume that any call to inline asm or to @@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) { // all native LLVM instructions can be lowered to compatible instructions. if (isa(I) && !I.isDebugOrPseudoInst() && (cast(I).isInlineAsm() || isa(I) || - isSMEABIRoutineCall(cast(I)))) + isSMEABIRoutineCall(cast(I), TLI))) return true; } } @@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) { - if (hasPossibleIncompatibleOps(Callee)) + if (hasPossibleIncompatibleOps(Callee, *getTLI())) return false; } @@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, // change only once and avoid inlining of G into F. SMEAttrs FAttrs(*F); - SMECallAttrs CallAttrs(Call); + SMECallAttrs CallAttrs(Call, getTLI()); if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) { if (F == Call.getCaller()) // (1) diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp index 4af4d49306625..2008516885c35 100644 --- a/llvm/lib/Target/AArch64/SMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp @@ -15,11 +15,16 @@ #include "AArch64.h" #include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/RuntimeLibcalls.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; @@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass { bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + } + private: bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder, - SMEAttrs FnAttrs); + SMEAttrs FnAttrs, const TargetLowering &TLI); }; } // end anonymous namespace @@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); } //===----------------------------------------------------------------------===// // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0. -void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { +void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI, + bool ZT0IsUndef = false) { auto &Ctx = M->getContext(); auto *TPIDR2SaveTy = FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false); auto Attrs = AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible"); + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE; FunctionCallee Callee = - M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs); + M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs); CallInst *Call = Builder.CreateCall(Callee); // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark @@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { if (ZT0IsUndef) Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef")); - Call->setCallingConv( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0); + Call->setCallingConv(TLI.getLibcallCallingConv(LC)); // A save to TPIDR2 should be followed by clearing TPIDR2_EL0. Function *WriteIntr = @@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { /// interface if it does not share ZA or ZT0. /// bool SMEABI::updateNewStateFunctions(Module *M, Function *F, - IRBuilder<> &Builder, SMEAttrs FnAttrs) { + IRBuilder<> &Builder, SMEAttrs FnAttrs, + const TargetLowering &TLI) { LLVMContext &Context = F->getContext(); BasicBlock *OrigBB = &F->getEntryBlock(); Builder.SetInsertPoint(&OrigBB->front()); @@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F, // Create a call __arm_tpidr2_save, which commits the lazy save. Builder.SetInsertPoint(&SaveBB->back()); - emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); + emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); // Enable pstate.za at the start of the function. Builder.SetInsertPoint(&OrigBB->front()); @@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) { if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za")) return false; + const TargetMachine &TM = + getAnalysis().getTM(); + const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering(); + bool Changed = false; SMEAttrs FnAttrs(F); if (FnAttrs.isNewZA() || FnAttrs.isNewZT0()) - Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs); + Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI); return Changed; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 271094f935e0e..bb788fcebe4ae 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "AArch64SMEAttributes.h" +#include "llvm/CodeGen/TargetLowering.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/RuntimeLibcalls.h" #include using namespace llvm; @@ -77,19 +79,36 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= encodeZT0State(StateValue::New); } -void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) { +void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, + const TargetLowering &TLI) { + RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + if (Impl == RTLIB::Unsupported) + return; + RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl); unsigned KnownAttrs = SMEAttrs::Normal; - if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state") + switch (LC) { + case RTLIB::SMEABI_SME_STATE: + case RTLIB::SMEABI_TPIDR2_SAVE: + case RTLIB::SMEABI_GET_CURRENT_VG: + case RTLIB::SMEABI_SME_STATE_SIZE: + case RTLIB::SMEABI_SME_SAVE: + case RTLIB::SMEABI_SME_RESTORE: KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine); - if (FuncName == "__arm_tpidr2_restore") + break; + case RTLIB::SMEABI_ZA_DISABLE: + case RTLIB::SMEABI_TPIDR2_RESTORE: KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) | SMEAttrs::SME_ABI_Routine; - if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" || - FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr") + break; + case RTLIB::SC_MEMCPY: + case RTLIB::SC_MEMMOVE: + case RTLIB::SC_MEMSET: + case RTLIB::SC_MEMCHR: KnownAttrs |= SMEAttrs::SM_Compatible; - if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" || - FuncName == "__arm_sme_state_size") - KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + break; + default: + break; + } set(KnownAttrs); } @@ -110,11 +129,11 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB) +SMECallAttrs::SMECallAttrs(const CallBase &CB, const TargetLowering *TLI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) - CalledFn = SMEAttrs(*CalledFunction, SMEAttrs::InferAttrsFromName::Yes); + CalledFn = SMEAttrs(*CalledFunction, TLI); // FIXME: We probably should not allow SME attributes on direct calls but // clang duplicates streaming mode attributes at each callsite. diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index f1be0ecbee7ed..06376c74025f8 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -13,6 +13,8 @@ namespace llvm { +class TargetLowering; + class Function; class CallBase; class AttributeList; @@ -48,17 +50,17 @@ class SMEAttrs { CallSiteFlags_Mask = ZT0_Undef }; - enum class InferAttrsFromName { No, Yes }; - SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No) + SMEAttrs(const Function &F, const TargetLowering *TLI = nullptr) : SMEAttrs(F.getAttributes()) { - if (Infer == InferAttrsFromName::Yes) - addKnownFunctionAttrs(F.getName()); + if (TLI) + addKnownFunctionAttrs(F.getName(), *TLI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); }; + SMEAttrs(StringRef FuncName, const TargetLowering &TLI) { + addKnownFunctionAttrs(FuncName, TLI); + }; void set(unsigned M, bool Enable = true); @@ -146,7 +148,7 @@ class SMEAttrs { } private: - void addKnownFunctionAttrs(StringRef FuncName); + void addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI); }; /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has @@ -163,7 +165,7 @@ class SMECallAttrs { SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB); + SMECallAttrs(const CallBase &CB, const TargetLowering *TLI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } diff --git a/llvm/test/TableGen/RuntimeLibcallEmitter-calling-conv.td b/llvm/test/TableGen/RuntimeLibcallEmitter-calling-conv.td index 49d5ecaa0e5c5..feef07502eedb 100644 --- a/llvm/test/TableGen/RuntimeLibcallEmitter-calling-conv.td +++ b/llvm/test/TableGen/RuntimeLibcallEmitter-calling-conv.td @@ -48,79 +48,47 @@ def MSP430LibraryWithCondCC : SystemRuntimeLibrary; // func_a and func_b both provide SOME_FUNC. // CHECK: if (isTargetArchA()) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::SOME_FUNC, RTLIB::func_b}, // func_b -// CHECK-NEXT: }; +// CHECK-NEXT: }); // ERR: :[[@LINE+1]]:5: warning: conflicting implementations for libcall SOME_FUNC: func_b, func_a def TheSystemLibraryA : SystemRuntimeLibrary; // CHECK: if (isTargetArchB()) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::OTHER_FUNC, RTLIB::other_func}, // other_func -// CHECK-NEXT: {RTLIB::SOME_FUNC, RTLIB::func_a}, // func_a -// CHECK-NEXT: }; +// CHECK-NEXT: {RTLIB::SOME_FUNC, RTLIB::func_a}, // func_a +// CHECK-NEXT: }); // ERR: :[[@LINE+1]]:5: warning: conflicting implementations for libcall SOME_FUNC: func_a, func_b def TheSystemLibraryB : SystemRuntimeLibrary; // CHECK: if (isTargetArchC()) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::ANOTHER_DUP, RTLIB::dup1}, // dup1 // CHECK-NEXT: {RTLIB::OTHER_FUNC, RTLIB::other_func}, // other_func // CHECK-NEXT: {RTLIB::SOME_FUNC, RTLIB::func_a}, // func_a -// CHECK-NEXT: }; +// CHECK-NEXT: }); // ERR: :[[@LINE+3]]:5: warning: conflicting implementations for libcall ANOTHER_DUP: dup1, dup0 // ERR: :[[@LINE+2]]:5: warning: conflicting implementations for libcall SOME_FUNC: func_a, func_b diff --git a/llvm/test/TableGen/RuntimeLibcallEmitter.td b/llvm/test/TableGen/RuntimeLibcallEmitter.td index 642f8b85a89c6..59ccd2341c54c 100644 --- a/llvm/test/TableGen/RuntimeLibcallEmitter.td +++ b/llvm/test/TableGen/RuntimeLibcallEmitter.td @@ -155,38 +155,36 @@ def BlahLibrary : SystemRuntimeLibrary Libcalls, +// CHECK-NEXT: std::optional CC = {}) +// CHECK-NEXT: { +// CHECK-NEXT: for (const auto [Func, Impl] : Libcalls) { +// CHECK-NEXT: setLibcallImpl(Func, Impl); +// CHECK-NEXT: if (CC) +// CHECK-NEXT: setLibcallImplCallingConv(Impl, *CC); +// CHECK-NEXT: } +// CHECK-NEXT: }; // CHECK-EMPTY: // CHECK-NEXT: if (TT.getArch() == Triple::blah) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::BZERO, RTLIB::bzero}, // bzero // CHECK-NEXT: {RTLIB::CALLOC, RTLIB::calloc}, // calloc // CHECK-NEXT: {RTLIB::SQRT_F128, RTLIB::sqrtl_f128}, // sqrtl -// CHECK-NEXT: }; -// CHECK-EMPTY: -// CHECK-NEXT: for (const auto [Func, Impl] : LibraryCalls) { -// CHECK-NEXT: setLibcallImpl(Func, Impl); -// CHECK-NEXT: } +// CHECK-NEXT: }); // CHECK-EMPTY: // CHECK-NEXT: if (TT.hasCompilerRT()) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls_hasCompilerRT[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::SHL_I32, RTLIB::__ashlsi3}, // __ashlsi3 // CHECK-NEXT: {RTLIB::SRL_I64, RTLIB::__lshrdi3}, // __lshrdi3 -// CHECK-NEXT: }; -// CHECK-EMPTY: -// CHECK-NEXT: for (const auto [Func, Impl] : LibraryCalls_hasCompilerRT) { -// CHECK-NEXT: setLibcallImpl(Func, Impl); -// CHECK-NEXT: } +// CHECK-NEXT: }); // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-EMPTY: // CHECK-NEXT: if (TT.getOS() == Triple::bar) { -// CHECK-NEXT: static const LibcallImplPair LibraryCalls_isBarOS[] = { +// CHECK-NEXT: setLibcallsImpl({ // CHECK-NEXT: {RTLIB::MEMSET, RTLIB::___memset}, // ___memset -// CHECK-NEXT: }; -// CHECK-EMPTY: -// CHECK-NEXT: for (const auto [Func, Impl] : LibraryCalls_isBarOS) { -// CHECK-NEXT: setLibcallImpl(Func, Impl); -// CHECK-NEXT: } +// CHECK-NEXT: }); // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-EMPTY: @@ -194,37 +192,25 @@ def BlahLibrary : SystemRuntimeLibrary((CallModule->getFunction("foo")->begin()->front())); - ASSERT_TRUE(SMECallAttrs(Call).callsite().hasUndefZT0()); + ASSERT_TRUE(SMECallAttrs(Call, nullptr).callsite().hasUndefZT0()); // Invalid combinations. EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible), diff --git a/llvm/utils/TableGen/Basic/RuntimeLibcallsEmitter.cpp b/llvm/utils/TableGen/Basic/RuntimeLibcallsEmitter.cpp index 412431b96d030..6c35f60b07be7 100644 --- a/llvm/utils/TableGen/Basic/RuntimeLibcallsEmitter.cpp +++ b/llvm/utils/TableGen/Basic/RuntimeLibcallsEmitter.cpp @@ -360,6 +360,16 @@ void RuntimeLibcallEmitter::emitSystemRuntimeLibrarySetCalls( " struct LibcallImplPair {\n" " RTLIB::Libcall Func;\n" " RTLIB::LibcallImpl Impl;\n" + " };\n" + " auto setLibcallsImpl = [this](\n" + " ArrayRef Libcalls,\n" + " std::optional CC = {})\n" + " {\n" + " for (const auto [Func, Impl] : Libcalls) {\n" + " setLibcallImpl(Func, Impl);\n" + " if (CC)\n" + " setLibcallImplCallingConv(Impl, *CC);\n" + " }\n" " };\n"; ArrayRef AllLibs = Records.getAllDerivedDefinitions("SystemRuntimeLibrary"); @@ -485,31 +495,18 @@ void RuntimeLibcallEmitter::emitSystemRuntimeLibrarySetCalls( Funcs.erase(UniqueI, Funcs.end()); - OS << indent(IndentDepth + 2) - << "static const LibcallImplPair LibraryCalls"; - SubsetPredicate.emitTableVariableNameSuffix(OS); - OS << "[] = {\n"; + OS << indent(IndentDepth + 2) << "setLibcallsImpl({\n"; for (const RuntimeLibcallImpl *LibCallImpl : Funcs) { - OS << indent(IndentDepth + 6); + OS << indent(IndentDepth + 4); LibCallImpl->emitTableEntry(OS); } - - OS << indent(IndentDepth + 2) << "};\n\n" - << indent(IndentDepth + 2) - << "for (const auto [Func, Impl] : LibraryCalls"; - SubsetPredicate.emitTableVariableNameSuffix(OS); - OS << ") {\n" - << indent(IndentDepth + 4) << "setLibcallImpl(Func, Impl);\n"; - + OS << indent(IndentDepth + 2) << "}"; if (FuncsWithCC.CallingConv) { StringRef CCEnum = FuncsWithCC.CallingConv->getValueAsString("CallingConv"); - OS << indent(IndentDepth + 4) << "setLibcallImplCallingConv(Impl, " - << CCEnum << ");\n"; + OS << ", " << CCEnum; } - - OS << indent(IndentDepth + 2) << "}\n"; - OS << '\n'; + OS << ");\n\n"; if (!SubsetPredicate.isAlwaysAvailable()) { OS << indent(IndentDepth); From ecb2dcf89d0c35737d9a27d9ebf119591074106c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 12 Aug 2025 08:47:01 +0000 Subject: [PATCH 2/2] Rebase: Use RuntimeLibcalls for EmitEntryPStateSM --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 7 +++---- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index cbffe40dec604..c85d3a646b9d5 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3108,13 +3108,12 @@ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI, const TargetInstrInfo *TII = Subtarget->getInstrInfo(); Register ResultReg = MI.getOperand(0).getReg(); if (FuncInfo->isPStateSMRegUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state") + .addExternalSymbol(getLibcallName(LC)) .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)); + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg) .addReg(AArch64::X0); } else { diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index bb788fcebe4ae..934f68b29922a 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -93,7 +93,7 @@ void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, case RTLIB::SMEABI_SME_STATE_SIZE: case RTLIB::SMEABI_SME_SAVE: case RTLIB::SMEABI_SME_RESTORE: - KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine); + KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; break; case RTLIB::SMEABI_ZA_DISABLE: case RTLIB::SMEABI_TPIDR2_RESTORE: