Skip to content

Commit 95c3873

Browse files
committed
[AArch64][SME] Port all SME routines to RuntimeLibcalls
This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention. Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so tweaked the output to avoid the need for variables.
1 parent fac7453 commit 95c3873

14 files changed

+213
-180
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,12 @@ class LLVM_ABI TargetLoweringBase {
35583558
return Libcalls.getLibcallImplName(Call);
35593559
}
35603560

3561+
/// Check if this is valid libcall for the current module, otherwise
3562+
/// RTLIB::Unsupported.
3563+
RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const {
3564+
return Libcalls.getSupportedLibcallImpl(FuncName);
3565+
}
3566+
35613567
const char *getMemcpyName() const { return Libcalls.getMemcpyName(); }
35623568

35633569
/// Get the comparison predicate that's to be used to test the result of the

llvm/include/llvm/IR/RuntimeLibcalls.td

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,17 @@ multiclass LibmLongDoubleLibCall<string libcall_basename = !toupper(NAME),
405405
def SC_MEMCPY : RuntimeLibcall;
406406
def SC_MEMMOVE : RuntimeLibcall;
407407
def SC_MEMSET : RuntimeLibcall;
408+
def SC_MEMCHR: RuntimeLibcall;
409+
410+
// AArch64 SME ABI calls
411+
def SMEABI_SME_STATE : RuntimeLibcall;
412+
def SMEABI_TPIDR2_SAVE : RuntimeLibcall;
413+
def SMEABI_ZA_DISABLE : RuntimeLibcall;
414+
def SMEABI_TPIDR2_RESTORE : RuntimeLibcall;
415+
def SMEABI_GET_CURRENT_VG : RuntimeLibcall;
416+
def SMEABI_SME_STATE_SIZE : RuntimeLibcall;
417+
def SMEABI_SME_SAVE : RuntimeLibcall;
418+
def SMEABI_SME_RESTORE : RuntimeLibcall;
408419

409420
// ARM EABI calls
410421
def AEABI_MEMCPY4 : RuntimeLibcall; // Align 4
@@ -1223,8 +1234,35 @@ defset list<RuntimeLibcallImpl> AArch64LibcallImpls = {
12231234
def __arm_sc_memcpy : RuntimeLibcallImpl<SC_MEMCPY>;
12241235
def __arm_sc_memmove : RuntimeLibcallImpl<SC_MEMMOVE>;
12251236
def __arm_sc_memset : RuntimeLibcallImpl<SC_MEMSET>;
1237+
def __arm_sc_memchr : RuntimeLibcallImpl<SC_MEMCHR>;
12261238
} // End AArch64LibcallImpls
12271239

1240+
def __arm_sme_state : RuntimeLibcallImpl<SMEABI_SME_STATE>;
1241+
def __arm_tpidr2_save : RuntimeLibcallImpl<SMEABI_TPIDR2_SAVE>;
1242+
def __arm_za_disable : RuntimeLibcallImpl<SMEABI_ZA_DISABLE>;
1243+
def __arm_tpidr2_restore : RuntimeLibcallImpl<SMEABI_TPIDR2_RESTORE>;
1244+
def __arm_get_current_vg : RuntimeLibcallImpl<SMEABI_GET_CURRENT_VG>;
1245+
def __arm_sme_state_size : RuntimeLibcallImpl<SMEABI_SME_STATE_SIZE>;
1246+
def __arm_sme_save : RuntimeLibcallImpl<SMEABI_SME_SAVE>;
1247+
def __arm_sme_restore : RuntimeLibcallImpl<SMEABI_SME_RESTORE>;
1248+
1249+
def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add
1250+
__arm_tpidr2_save,
1251+
__arm_za_disable,
1252+
__arm_tpidr2_restore),
1253+
SMEABI_PreserveMost_From_X0>;
1254+
1255+
def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add
1256+
__arm_get_current_vg,
1257+
__arm_sme_state_size,
1258+
__arm_sme_save,
1259+
__arm_sme_restore),
1260+
SMEABI_PreserveMost_From_X1>;
1261+
1262+
def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add
1263+
__arm_sme_state),
1264+
SMEABI_PreserveMost_From_X2>;
1265+
12281266
def isAArch64_ExceptArm64EC
12291267
: RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">;
12301268
def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">;
@@ -1244,7 +1282,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary<
12441282
LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128,
12451283
DefaultLibmExp10,
12461284
DefaultStackProtector,
1247-
SecurityCheckCookieIfWinMSVC)
1285+
SecurityCheckCookieIfWinMSVC,
1286+
SMEABI_LibCalls_PreserveMost_From_X0,
1287+
SMEABI_LibCalls_PreserveMost_From_X1,
1288+
SMEABI_LibCalls_PreserveMost_From_X2)
12481289
>;
12491290

12501291
// Prepend a # to every name

llvm/include/llvm/IR/RuntimeLibcallsImpl.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>;
3636
def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>;
3737
def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>;
3838
def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>;
39+
def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>;
40+
def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>;
41+
def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>;
3942

4043
/// Abstract definition for functionality the compiler may need to
4144
/// emit a call to. Emits the RTLIB::Libcall enum - This enum defines

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
14871487

14881488
if (Opc == AArch64::BL) {
14891489
auto Op1 = MBBI->getOperand(0);
1490-
return Op1.isSymbol() &&
1491-
(StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
1490+
auto &TLI =
1491+
*MBBI->getMF()->getSubtarget<AArch64Subtarget>().getTargetLowering();
1492+
char const *GetCurrentVG =
1493+
TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG);
1494+
return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG;
14921495
}
14931496
}
14941497

@@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
34683471
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
34693472
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
34703473
MachineFunction &MF = *MBB.getParent();
3474+
auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering();
34713475
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
34723476
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
34733477
bool NeedsWinCFI = needsWinCFI(MF);
@@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
35813585
.addReg(AArch64::X0, RegState::Implicit)
35823586
.setMIFlag(MachineInstr::FrameSetup);
35833587

3584-
const uint32_t *RegMask = TRI->getCallPreservedMask(
3585-
MF,
3586-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
3588+
RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG;
3589+
const uint32_t *RegMask =
3590+
TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC));
35873591
BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
3588-
.addExternalSymbol("__arm_get_current_vg")
3592+
.addExternalSymbol(TLI.getLibcallName(LC))
35893593
.addRegMask(RegMask)
35903594
.addReg(AArch64::X0, RegState::ImplicitDefine)
35913595
.setMIFlag(MachineInstr::FrameSetup);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
30833083
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
30843084
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
30853085
if (FuncInfo->isSMESaveBufferUsed()) {
3086+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
30863087
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
30873088
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3088-
.addExternalSymbol("__arm_sme_state_size")
3089+
.addExternalSymbol(getLibcallName(LC))
30893090
.addReg(AArch64::X0, RegState::ImplicitDefine)
3090-
.addRegMask(TRI->getCallPreservedMask(
3091-
*MF, CallingConv::
3092-
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
3091+
.addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
30933092
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
30943093
MI.getOperand(0).getReg())
30953094
.addReg(AArch64::X0);
@@ -5711,15 +5710,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
57115710
SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
57125711
SDValue Chain, SDLoc DL,
57135712
EVT VT) const {
5714-
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
5713+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
5714+
SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
57155715
getPointerTy(DAG.getDataLayout()));
57165716
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
57175717
Type *RetTy = StructType::get(Int64Ty, Int64Ty);
57185718
TargetLowering::CallLoweringInfo CLI(DAG);
57195719
ArgListTy Args;
57205720
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
5721-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
5722-
RetTy, Callee, std::move(Args));
5721+
getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
57235722
std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
57245723
SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
57255724
return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
@@ -8564,12 +8563,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
85648563
}
85658564

85668565
static SMECallAttrs
8567-
getSMECallAttrs(const Function &Caller,
8566+
getSMECallAttrs(const Function &Caller, const TargetLowering &TLI,
85688567
const TargetLowering::CallLoweringInfo &CLI) {
85698568
if (CLI.CB)
8570-
return SMECallAttrs(*CLI.CB);
8569+
return SMECallAttrs(*CLI.CB, &TLI);
85718570
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8572-
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
8571+
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
85738572
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
85748573
}
85758574

@@ -8591,7 +8590,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
85918590

85928591
// SME Streaming functions are not eligible for TCO as they may require
85938592
// the streaming mode or ZA to be restored after returning from the call.
8594-
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8593+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
85958594
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
85968595
CallAttrs.requiresPreservingAllZAState() ||
85978596
CallAttrs.caller().hasStreamingBody())
@@ -8879,14 +8878,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
88798878
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
88808879
Args.push_back(Entry);
88818880

8882-
SDValue Callee =
8883-
DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
8884-
TLI.getPointerTy(DAG.getDataLayout()));
8881+
RTLIB::Libcall LC =
8882+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8883+
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8884+
TLI.getPointerTy(DAG.getDataLayout()));
88858885
auto *RetTy = Type::getVoidTy(*DAG.getContext());
88868886
TargetLowering::CallLoweringInfo CLI(DAG);
88878887
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8888-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
8889-
Callee, std::move(Args));
8888+
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
88908889
return TLI.LowerCallTo(CLI).second;
88918890
}
88928891

@@ -9074,7 +9073,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90749073
}
90759074

90769075
// Determine whether we need any streaming mode changes.
9077-
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9076+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
90789077

90799078
auto DescribeCallsite =
90809079
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9659,11 +9658,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96599658

96609659
if (RequiresLazySave) {
96619660
// Conditionally restore the lazy save using a pseudo node.
9661+
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
96629662
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
96639663
SDValue RegMask = DAG.getRegisterMask(
9664-
TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
9664+
TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
96659665
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
9666-
"__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
9666+
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
96679667
SDValue TPIDR2_EL0 = DAG.getNode(
96689668
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
96699669
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
@@ -29004,7 +29004,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2900429004

2900529005
// Checks to allow the use of SME instructions
2900629006
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
29007-
auto CallAttrs = SMECallAttrs(*Base);
29007+
auto CallAttrs = SMECallAttrs(*Base, this);
2900829008
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
2900929009
CallAttrs.requiresPreservingZT0() ||
2901029010
CallAttrs.requiresPreservingAllZAState())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,16 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
220220
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
221221
"enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
222222

223-
static bool isSMEABIRoutineCall(const CallInst &CI) {
223+
static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) {
224224
const auto *F = CI.getCalledFunction();
225-
return F && StringSwitch<bool>(F->getName())
226-
.Case("__arm_sme_state", true)
227-
.Case("__arm_tpidr2_save", true)
228-
.Case("__arm_tpidr2_restore", true)
229-
.Case("__arm_za_disable", true)
230-
.Default(false);
225+
return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
231226
}
232227

233228
/// Returns true if the function has explicit operations that can only be
234229
/// lowered using incompatible instructions for the selected mode. This also
235230
/// returns true if the function F may use or modify ZA state.
236-
static bool hasPossibleIncompatibleOps(const Function *F) {
231+
static bool hasPossibleIncompatibleOps(const Function *F,
232+
const TargetLowering &TLI) {
237233
for (const BasicBlock &BB : *F) {
238234
for (const Instruction &I : BB) {
239235
// Be conservative for now and assume that any call to inline asm or to
@@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
242238
// all native LLVM instructions can be lowered to compatible instructions.
243239
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
244240
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
245-
isSMEABIRoutineCall(cast<CallInst>(I))))
241+
isSMEABIRoutineCall(cast<CallInst>(I), TLI)))
246242
return true;
247243
}
248244
}
@@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
290286
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
291287
CallAttrs.requiresPreservingZT0() ||
292288
CallAttrs.requiresPreservingAllZAState()) {
293-
if (hasPossibleIncompatibleOps(Callee))
289+
if (hasPossibleIncompatibleOps(Callee, *getTLI()))
294290
return false;
295291
}
296292

@@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
357353
// change only once and avoid inlining of G into F.
358354

359355
SMEAttrs FAttrs(*F);
360-
SMECallAttrs CallAttrs(Call);
356+
SMECallAttrs CallAttrs(Call, getTLI());
361357

362358
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
363359
if (F == Call.getCaller()) // (1)

llvm/lib/Target/AArch64/SMEABIPass.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@
1515
#include "AArch64.h"
1616
#include "Utils/AArch64SMEAttributes.h"
1717
#include "llvm/ADT/StringRef.h"
18+
#include "llvm/CodeGen/TargetLowering.h"
19+
#include "llvm/CodeGen/TargetPassConfig.h"
20+
#include "llvm/CodeGen/TargetSubtargetInfo.h"
1821
#include "llvm/IR/IRBuilder.h"
1922
#include "llvm/IR/Instructions.h"
2023
#include "llvm/IR/IntrinsicsAArch64.h"
2124
#include "llvm/IR/LLVMContext.h"
2225
#include "llvm/IR/Module.h"
26+
#include "llvm/IR/RuntimeLibcalls.h"
27+
#include "llvm/Target/TargetMachine.h"
2328
#include "llvm/Transforms/Utils/Cloning.h"
2429

2530
using namespace llvm;
@@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass {
3338

3439
bool runOnFunction(Function &F) override;
3540

41+
void getAnalysisUsage(AnalysisUsage &AU) const override {
42+
AU.addRequired<TargetPassConfig>();
43+
}
44+
3645
private:
3746
bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
38-
SMEAttrs FnAttrs);
47+
SMEAttrs FnAttrs, const TargetLowering &TLI);
3948
};
4049
} // end anonymous namespace
4150

@@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
5160
//===----------------------------------------------------------------------===//
5261

5362
// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
54-
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
63+
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI,
64+
bool ZT0IsUndef = false) {
5565
auto &Ctx = M->getContext();
5666
auto *TPIDR2SaveTy =
5767
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
5868
auto Attrs =
5969
AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
70+
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE;
6071
FunctionCallee Callee =
61-
M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
72+
M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs);
6273
CallInst *Call = Builder.CreateCall(Callee);
6374

6475
// If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
@@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
6778
if (ZT0IsUndef)
6879
Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
6980

70-
Call->setCallingConv(
71-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
81+
Call->setCallingConv(TLI.getLibcallCallingConv(LC));
7282

7383
// A save to TPIDR2 should be followed by clearing TPIDR2_EL0.
7484
Function *WriteIntr =
@@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
98108
/// interface if it does not share ZA or ZT0.
99109
///
100110
bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
101-
IRBuilder<> &Builder, SMEAttrs FnAttrs) {
111+
IRBuilder<> &Builder, SMEAttrs FnAttrs,
112+
const TargetLowering &TLI) {
102113
LLVMContext &Context = F->getContext();
103114
BasicBlock *OrigBB = &F->getEntryBlock();
104115
Builder.SetInsertPoint(&OrigBB->front());
@@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
124135

125136
// Create a call __arm_tpidr2_save, which commits the lazy save.
126137
Builder.SetInsertPoint(&SaveBB->back());
127-
emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
138+
emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
128139

129140
// Enable pstate.za at the start of the function.
130141
Builder.SetInsertPoint(&OrigBB->front());
@@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) {
172183
if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za"))
173184
return false;
174185

186+
const TargetMachine &TM =
187+
getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
188+
const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering();
189+
175190
bool Changed = false;
176191
SMEAttrs FnAttrs(F);
177192
if (FnAttrs.isNewZA() || FnAttrs.isNewZT0())
178-
Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
193+
Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI);
179194

180195
return Changed;
181196
}

0 commit comments

Comments
 (0)