Skip to content

[AArch64][SME] Port all SME routines to RuntimeLibcalls #152505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3558,6 +3558,12 @@ class LLVM_ABI TargetLoweringBase {
return Libcalls.getLibcallImplName(Call);
}

/// Look up the libcall implementation named \p FuncName. Returns it if this
/// is a valid libcall for the current module, otherwise RTLIB::Unsupported.
/// Forwards to Libcalls.getSupportedLibcallImpl().
RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const {
return Libcalls.getSupportedLibcallImpl(FuncName);
}

/// Return the memcpy name reported by Libcalls.getMemcpyName().
const char *getMemcpyName() const { return Libcalls.getMemcpyName(); }

/// Get the comparison predicate that's to be used to test the result of the
Expand Down
43 changes: 42 additions & 1 deletion llvm/include/llvm/IR/RuntimeLibcalls.td
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,17 @@ multiclass LibmLongDoubleLibCall<string libcall_basename = !toupper(NAME),
def SC_MEMCPY : RuntimeLibcall;  // Implemented by __arm_sc_memcpy (AArch64).
def SC_MEMMOVE : RuntimeLibcall; // Implemented by __arm_sc_memmove (AArch64).
def SC_MEMSET : RuntimeLibcall;  // Implemented by __arm_sc_memset (AArch64).
def SC_MEMCHR : RuntimeLibcall;  // Implemented by __arm_sc_memchr (AArch64).

// AArch64 SME ABI calls. Each libcall here is implemented by the like-named
// `__arm_*` routine (e.g. SMEABI_SME_STATE -> __arm_sme_state).
def SMEABI_SME_STATE : RuntimeLibcall;
def SMEABI_TPIDR2_SAVE : RuntimeLibcall;
def SMEABI_ZA_DISABLE : RuntimeLibcall;
def SMEABI_TPIDR2_RESTORE : RuntimeLibcall;
def SMEABI_GET_CURRENT_VG : RuntimeLibcall;
def SMEABI_SME_STATE_SIZE : RuntimeLibcall;
def SMEABI_SME_SAVE : RuntimeLibcall;
def SMEABI_SME_RESTORE : RuntimeLibcall;

// ARM EABI calls
def AEABI_MEMCPY4 : RuntimeLibcall; // Align 4
Expand Down Expand Up @@ -1224,8 +1235,35 @@ defset list<RuntimeLibcallImpl> AArch64LibcallImpls = {
def __arm_sc_memcpy : RuntimeLibcallImpl<SC_MEMCPY>;
def __arm_sc_memmove : RuntimeLibcallImpl<SC_MEMMOVE>;
def __arm_sc_memset : RuntimeLibcallImpl<SC_MEMSET>;
def __arm_sc_memchr : RuntimeLibcallImpl<SC_MEMCHR>;
} // End AArch64LibcallImpls

// Implementations of the SMEABI_* libcalls: one `__arm_*` routine per
// libcall. These sit outside the AArch64LibcallImpls defset; they are
// grouped by calling convention in the LibcallsWithCC records that follow.
def __arm_sme_state : RuntimeLibcallImpl<SMEABI_SME_STATE>;
def __arm_tpidr2_save : RuntimeLibcallImpl<SMEABI_TPIDR2_SAVE>;
def __arm_za_disable : RuntimeLibcallImpl<SMEABI_ZA_DISABLE>;
def __arm_tpidr2_restore : RuntimeLibcallImpl<SMEABI_TPIDR2_RESTORE>;
def __arm_get_current_vg : RuntimeLibcallImpl<SMEABI_GET_CURRENT_VG>;
def __arm_sme_state_size : RuntimeLibcallImpl<SMEABI_SME_STATE_SIZE>;
def __arm_sme_save : RuntimeLibcallImpl<SMEABI_SME_SAVE>;
def __arm_sme_restore : RuntimeLibcallImpl<SMEABI_SME_RESTORE>;

// SME ABI routines that use the SMEABI_PreserveMost_From_X0 calling
// convention.
def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add
__arm_tpidr2_save,
__arm_za_disable,
__arm_tpidr2_restore),
SMEABI_PreserveMost_From_X0>;

// SME ABI routines that use the SMEABI_PreserveMost_From_X1 calling
// convention.
def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add
__arm_get_current_vg,
__arm_sme_state_size,
__arm_sme_save,
__arm_sme_restore),
SMEABI_PreserveMost_From_X1>;

// SME ABI routine that uses the SMEABI_PreserveMost_From_X2 calling
// convention.
def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add
__arm_sme_state),
SMEABI_PreserveMost_From_X2>;

// Triple predicates: the first matches any AArch64 triple that is not
// Windows Arm64EC; the second matches Windows Arm64EC only.
def isAArch64_ExceptArm64EC
: RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">;
def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">;
Expand All @@ -1245,7 +1283,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary<
LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128,
DefaultLibmExp10,
DefaultStackProtector,
SecurityCheckCookieIfWinMSVC)
SecurityCheckCookieIfWinMSVC,
SMEABI_LibCalls_PreserveMost_From_X0,
SMEABI_LibCalls_PreserveMost_From_X1,
SMEABI_LibCalls_PreserveMost_From_X2)
>;

// Prepend a # to every name
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/RuntimeLibcallsImpl.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>;
def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>;
def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>;
def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>;
// Calling conventions for the AArch64 SME ABI support routines; each record
// wraps the matching CallingConv::AArch64_SME_ABI_Support_Routines_* value.
def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>;
def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>;
def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>;

/// Abstract definition for functionality the compiler may need to
/// emit a call to. Emits the RTLIB::Libcall enum - This enum defines
Expand Down
16 changes: 10 additions & 6 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) {

if (Opc == AArch64::BL) {
auto Op1 = MBBI->getOperand(0);
return Op1.isSymbol() &&
(StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
auto &TLI =
*MBBI->getMF()->getSubtarget<AArch64Subtarget>().getTargetLowering();
char const *GetCurrentVG =
TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG);
return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG;
}
}

Expand Down Expand Up @@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
MachineFunction &MF = *MBB.getParent();
auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering();
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
bool NeedsWinCFI = needsWinCFI(MF);
Expand Down Expand Up @@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
.addReg(AArch64::X0, RegState::Implicit)
.setMIFlag(MachineInstr::FrameSetup);

const uint32_t *RegMask = TRI->getCallPreservedMask(
MF,
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG;
const uint32_t *RegMask =
TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC));
BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
.addExternalSymbol("__arm_get_current_vg")
.addExternalSymbol(TLI.getLibcallName(LC))
.addRegMask(RegMask)
.addReg(AArch64::X0, RegState::ImplicitDefine)
.setMIFlag(MachineInstr::FrameSetup);
Expand Down
47 changes: 23 additions & 24 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->isSMESaveBufferUsed()) {
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
.addExternalSymbol("__arm_sme_state_size")
.addExternalSymbol(getLibcallName(LC))
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
.addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::X0);
Expand All @@ -3109,13 +3108,12 @@ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
Register ResultReg = MI.getOperand(0).getReg();
if (FuncInfo->isPStateSMRegUsed()) {
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
.addExternalSymbol("__arm_sme_state")
.addExternalSymbol(getLibcallName(LC))
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
.addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
.addReg(AArch64::X0);
} else {
Expand Down Expand Up @@ -5739,15 +5737,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
SDValue Chain, SDLoc DL,
EVT VT) const {
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
getPointerTy(DAG.getDataLayout()));
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
Type *RetTy = StructType::get(Int64Ty, Int64Ty);
TargetLowering::CallLoweringInfo CLI(DAG);
ArgListTy Args;
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
RetTy, Callee, std::move(Args));
getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
Expand Down Expand Up @@ -8600,12 +8598,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
}

/// Compute the SME call attributes for the call described by \p CLI, made
/// from \p Caller:
///  - if CLI.CB is set, derive them from the call instruction;
///  - else if the callee is an ExternalSymbolSDNode, pair the caller's
///    SMEAttrs with attributes derived from the symbol name;
///  - otherwise pair the caller's SMEAttrs with SMEAttrs::Normal.
static SMECallAttrs
getSMECallAttrs(const Function &Caller,
getSMECallAttrs(const Function &Caller, const TargetLowering &TLI,
const TargetLowering::CallLoweringInfo &CLI) {
if (CLI.CB)
return SMECallAttrs(*CLI.CB);
return SMECallAttrs(*CLI.CB, &TLI);
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
}

Expand All @@ -8627,7 +8625,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(

// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState() ||
CallAttrs.caller().hasStreamingBody())
Expand Down Expand Up @@ -8921,14 +8919,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
Args.push_back(Entry);

SDValue Callee =
DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
TLI.getPointerTy(DAG.getDataLayout()));
RTLIB::Libcall LC =
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
TLI.getPointerTy(DAG.getDataLayout()));
auto *RetTy = Type::getVoidTy(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
Callee, std::move(Args));
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
return TLI.LowerCallTo(CLI).second;
}

Expand Down Expand Up @@ -9116,7 +9114,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);

auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
Expand Down Expand Up @@ -9693,11 +9691,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

if (RequiresLazySave) {
// Conditionally restore the lazy save using a pseudo node.
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
"__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
Expand Down Expand Up @@ -29036,7 +29035,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {

// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
auto CallAttrs = SMECallAttrs(*Base);
auto CallAttrs = SMECallAttrs(*Base, this);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState())
Expand Down
18 changes: 7 additions & 11 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,16 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
// Hidden command-line flag "enable-scalable-autovec-in-streaming-mode";
// off by default (cl::init(false)).
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
"enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

/// Returns true if \p CI calls a function whose name identifies it as an SME
/// ABI support routine. Returns false when CI.getCalledFunction() yields no
/// function.
static bool isSMEABIRoutineCall(const CallInst &CI) {
static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) {
const auto *F = CI.getCalledFunction();
return F && StringSwitch<bool>(F->getName())
.Case("__arm_sme_state", true)
.Case("__arm_tpidr2_save", true)
.Case("__arm_tpidr2_restore", true)
.Case("__arm_za_disable", true)
.Default(false);
return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
}

/// Returns true if the function has explicit operations that can only be
/// lowered using incompatible instructions for the selected mode. This also
/// returns true if the function F may use or modify ZA state.
static bool hasPossibleIncompatibleOps(const Function *F) {
static bool hasPossibleIncompatibleOps(const Function *F,
const TargetLowering &TLI) {
for (const BasicBlock &BB : *F) {
for (const Instruction &I : BB) {
// Be conservative for now and assume that any call to inline asm or to
Expand All @@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
// all native LLVM instructions can be lowered to compatible instructions.
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
isSMEABIRoutineCall(cast<CallInst>(I))))
isSMEABIRoutineCall(cast<CallInst>(I), TLI)))
return true;
}
}
Expand Down Expand Up @@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState()) {
if (hasPossibleIncompatibleOps(Callee))
if (hasPossibleIncompatibleOps(Callee, *getTLI()))
return false;
}

Expand Down Expand Up @@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// change only once and avoid inlining of G into F.

SMEAttrs FAttrs(*F);
SMECallAttrs CallAttrs(Call);
SMECallAttrs CallAttrs(Call, getTLI());

if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
Expand Down
Loading