@@ -8154,53 +8154,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
8154
8154
if (Subtarget->hasCustomCallingConv())
8155
8155
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
8156
8156
8157
- // Create a 16 Byte TPIDR2 object. The dynamic buffer
8158
- // will be expanded and stored in the static object later using a pseudonode.
8159
- if (Attrs.hasZAState()) {
8160
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8161
- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8162
- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8163
- DAG.getConstant(1, DL, MVT::i32));
8164
-
8165
- SDValue Buffer;
8166
- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8167
- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8168
- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8169
- } else {
8170
- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8171
- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8172
- DAG.getVTList(MVT::i64, MVT::Other),
8173
- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8174
- MFI.CreateVariableSizedObject(Align(16), nullptr);
8175
- }
8176
- Chain = DAG.getNode(
8177
- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8178
- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8179
- } else if (Attrs.hasAgnosticZAInterface()) {
8180
- // Call __arm_sme_state_size().
8181
- SDValue BufferSize =
8182
- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8183
- DAG.getVTList(MVT::i64, MVT::Other), Chain);
8184
- Chain = BufferSize.getValue(1);
8185
-
8186
- SDValue Buffer;
8187
- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8188
- Buffer =
8189
- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8190
- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8191
- } else {
8192
- // Allocate space dynamically.
8193
- Buffer = DAG.getNode(
8194
- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8195
- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8196
- MFI.CreateVariableSizedObject(Align(16), nullptr);
8157
+ if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
8158
+ // Old SME ABI lowering (deprecated):
8159
+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
8160
+ // will be expanded and stored in the static object later using a
8161
+ // pseudonode.
8162
+ if (Attrs.hasZAState()) {
8163
+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8164
+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8165
+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8166
+ DAG.getConstant(1, DL, MVT::i32));
8167
+ SDValue Buffer;
8168
+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8169
+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8170
+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8171
+ } else {
8172
+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8173
+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8174
+ DAG.getVTList(MVT::i64, MVT::Other),
8175
+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8176
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8177
+ }
8178
+ Chain = DAG.getNode(
8179
+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8180
+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8181
+ } else if (Attrs.hasAgnosticZAInterface()) {
8182
+ // Call __arm_sme_state_size().
8183
+ SDValue BufferSize =
8184
+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8185
+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8186
+ Chain = BufferSize.getValue(1);
8187
+ SDValue Buffer;
8188
+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8189
+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8190
+ DAG.getVTList(MVT::i64, MVT::Other),
8191
+ {Chain, BufferSize});
8192
+ } else {
8193
+ // Allocate space dynamically.
8194
+ Buffer = DAG.getNode(
8195
+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8196
+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8197
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8198
+ }
8199
+ // Copy the value to a virtual register, and save that in FuncInfo.
8200
+ Register BufferPtr =
8201
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8202
+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8203
+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
8197
8204
}
8198
-
8199
- // Copy the value to a virtual register, and save that in FuncInfo.
8200
- Register BufferPtr =
8201
- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8202
- FuncInfo->setSMESaveBufferAddr(BufferPtr);
8203
- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
8204
8205
}
8205
8206
8206
8207
if (CallConv == CallingConv::PreserveNone) {
@@ -8217,6 +8218,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
8217
8218
}
8218
8219
}
8219
8220
8221
+ if (Subtarget->useNewSMEABILowering()) {
8222
+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8223
+ if (Attrs.isNewZT0())
8224
+ Chain = DAG.getNode(
8225
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8226
+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8227
+ DAG.getTargetConstant(0, DL, MVT::i32));
8228
+ }
8229
+
8220
8230
return Chain;
8221
8231
}
8222
8232
@@ -8781,14 +8791,12 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8781
8791
MachineFunction &MF = DAG.getMachineFunction();
8782
8792
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8783
8793
FuncInfo->setSMESaveBufferUsed();
8784
-
8785
8794
TargetLowering::ArgListTy Args;
8786
8795
TargetLowering::ArgListEntry Entry;
8787
8796
Entry.Ty = PointerType::getUnqual(*DAG.getContext());
8788
8797
Entry.Node =
8789
8798
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
8790
8799
Args.push_back(Entry);
8791
-
8792
8800
SDValue Callee =
8793
8801
DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
8794
8802
TLI.getPointerTy(DAG.getDataLayout()));
@@ -8906,6 +8914,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8906
8914
*DAG.getContext());
8907
8915
RetCCInfo.AnalyzeCallResult(Ins, RetCC);
8908
8916
8917
+ // Determine whether we need any streaming mode changes.
8918
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
8919
+
8909
8920
// Check callee args/returns for SVE registers and set calling convention
8910
8921
// accordingly.
8911
8922
if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -8919,14 +8930,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8919
8930
CallConv = CallingConv::AArch64_SVE_VectorCall;
8920
8931
}
8921
8932
8933
+ bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
8934
+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
8935
+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
8936
+ // TODO: Handle agnostic ZA functions.
8937
+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
8938
+ return std::nullopt;
8939
+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
8940
+ return std::nullopt;
8941
+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
8942
+ : AArch64ISD::INOUT_ZA_USE;
8943
+ }();
8944
+
8922
8945
if (IsTailCall) {
8923
8946
// Check if it's really possible to do a tail call.
8924
8947
IsTailCall = isEligibleForTailCallOptimization(CLI);
8925
8948
8926
8949
// A sibling call is one where we're under the usual C ABI and not planning
8927
8950
// to change that but can still do a tail call:
8928
- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
8929
- CallConv != CallingConv::SwiftTail)
8951
+ if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
8952
+ CallConv != CallingConv::Tail && CallConv != CallingConv:: SwiftTail)
8930
8953
IsSibCall = true;
8931
8954
8932
8955
if (IsTailCall)
@@ -8978,9 +9001,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8978
9001
assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
8979
9002
}
8980
9003
8981
- // Determine whether we need any streaming mode changes.
8982
- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
8983
-
8984
9004
auto DescribeCallsite =
8985
9005
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
8986
9006
R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -8994,7 +9014,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8994
9014
return R;
8995
9015
};
8996
9016
8997
- bool RequiresLazySave = CallAttrs.requiresLazySave();
9017
+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
8998
9018
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
8999
9019
if (RequiresLazySave) {
9000
9020
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9076,10 +9096,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9076
9096
AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
9077
9097
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
9078
9098
9079
- // Adjust the stack pointer for the new arguments...
9099
+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
9080
9100
// These operations are automatically eliminated by the prolog/epilog pass
9081
- if (!IsSibCall)
9101
+ assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
9102
+ "ZA markers require CALLSEQ_START");
9103
+ if (!IsSibCall) {
9082
9104
Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
9105
+ if (ZAMarkerNode) {
9106
+ // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to, simply
9107
+ // using a chain can result in incorrect scheduling. The markers referer
9108
+ // to the position just before the CALLSEQ_START (though occur after as
9109
+ // CALLSEQ_START lacks in-glue).
9110
+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9111
+ {Chain, Chain.getValue(1)});
9112
+ }
9113
+ }
9083
9114
9084
9115
SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
9085
9116
getPointerTy(DAG.getDataLayout()));
@@ -9551,7 +9582,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9551
9582
}
9552
9583
}
9553
9584
9554
- if (CallAttrs.requiresEnablingZAAfterCall())
9585
+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
9555
9586
// Unconditionally resume ZA.
9556
9587
Result = DAG.getNode(
9557
9588
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9572,7 +9603,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9572
9603
SDValue TPIDR2_EL0 = DAG.getNode(
9573
9604
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
9574
9605
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
9575
-
9576
9606
// Copy the address of the TPIDR2 block into X0 before 'calling' the
9577
9607
// RESTORE_ZA pseudo.
9578
9608
SDValue Glue;
@@ -9584,7 +9614,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9584
9614
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
9585
9615
{Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
9586
9616
RestoreRoutine, RegMask, Result.getValue(1)});
9587
-
9588
9617
// Finally reset the TPIDR2_EL0 register to 0.
9589
9618
Result = DAG.getNode(
9590
9619
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments