Skip to content

[AMDGPU] Move common fields out of WaitcntBrackets. NFC. #148864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 59 additions & 63 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,13 @@ class WaitcntGeneratorGFX12Plus : public WaitcntGenerator {
};

class SIInsertWaitcnts {
public:
const GCNSubtarget *ST;
InstCounterType SmemAccessCounter;
InstCounterType MaxCounter;
const unsigned *WaitEventMaskForInst;

private:
const GCNSubtarget *ST = nullptr;
const SIInstrInfo *TII = nullptr;
const SIRegisterInfo *TRI = nullptr;
const MachineRegisterInfo *MRI = nullptr;
Expand All @@ -424,8 +429,6 @@ class SIInsertWaitcnts {
bool Dirty = true;
};

InstCounterType SmemAccessCounter;

MapVector<MachineBasicBlock *, BlockInfo> BlockInfos;

bool ForceEmitWaitcnt[NUM_INST_CNTS];
Expand All @@ -442,7 +445,7 @@ class SIInsertWaitcnts {
// message.
DenseSet<MachineInstr *> ReleaseVGPRInsts;

InstCounterType MaxCounter = NUM_NORMAL_INST_CNTS;
HardwareLimits Limits;

public:
SIInsertWaitcnts(MachineLoopInfo *MLI, MachinePostDominatorTree *PDT,
Expand All @@ -453,6 +456,30 @@ class SIInsertWaitcnts {
(void)ForceVMCounter;
}

unsigned getWaitCountMax(InstCounterType T) const {
switch (T) {
case LOAD_CNT:
return Limits.LoadcntMax;
case DS_CNT:
return Limits.DscntMax;
case EXP_CNT:
return Limits.ExpcntMax;
case STORE_CNT:
return Limits.StorecntMax;
case SAMPLE_CNT:
return Limits.SamplecntMax;
case BVH_CNT:
return Limits.BvhcntMax;
case KM_CNT:
return Limits.KmcntMax;
case X_CNT:
return Limits.XcntMax;
default:
break;
}
return 0;
}

bool shouldFlushVmCnt(MachineLoop *ML, const WaitcntBrackets &Brackets);
bool isPreheaderToFlush(MachineBasicBlock &MBB,
const WaitcntBrackets &ScoreBrackets);
Expand Down Expand Up @@ -568,39 +595,10 @@ class SIInsertWaitcnts {
// "s_waitcnt 0" before use.
class WaitcntBrackets {
public:
WaitcntBrackets(const GCNSubtarget *SubTarget, InstCounterType MaxCounter,
HardwareLimits Limits, const unsigned *WaitEventMaskForInst,
InstCounterType SmemAccessCounter)
: ST(SubTarget), MaxCounter(MaxCounter), Limits(Limits),
WaitEventMaskForInst(WaitEventMaskForInst),
SmemAccessCounter(SmemAccessCounter) {}

unsigned getWaitCountMax(InstCounterType T) const {
switch (T) {
case LOAD_CNT:
return Limits.LoadcntMax;
case DS_CNT:
return Limits.DscntMax;
case EXP_CNT:
return Limits.ExpcntMax;
case STORE_CNT:
return Limits.StorecntMax;
case SAMPLE_CNT:
return Limits.SamplecntMax;
case BVH_CNT:
return Limits.BvhcntMax;
case KM_CNT:
return Limits.KmcntMax;
case X_CNT:
return Limits.XcntMax;
default:
break;
}
return 0;
}
WaitcntBrackets(const SIInsertWaitcnts *Context) : Context(Context) {}

bool isSmemCounter(InstCounterType T) const {
return T == SmemAccessCounter || T == X_CNT;
return T == Context->SmemAccessCounter || T == X_CNT;
}

unsigned getSgprScoresIdx(InstCounterType T) const {
Expand Down Expand Up @@ -658,7 +656,7 @@ class WaitcntBrackets {
return PendingEvents & (1 << E);
}
unsigned hasPendingEvent(InstCounterType T) const {
unsigned HasPending = PendingEvents & WaitEventMaskForInst[T];
unsigned HasPending = PendingEvents & Context->WaitEventMaskForInst[T];
assert((HasPending != 0) == (getScoreRange(T) != 0));
return HasPending;
}
Expand Down Expand Up @@ -686,7 +684,8 @@ class WaitcntBrackets {
}

unsigned getPendingGDSWait() const {
return std::min(getScoreUB(DS_CNT) - LastGDS, getWaitCountMax(DS_CNT) - 1);
return std::min(getScoreUB(DS_CNT) - LastGDS,
Context->getWaitCountMax(DS_CNT) - 1);
}

void setPendingGDS() { LastGDS = ScoreUBs[DS_CNT]; }
Expand All @@ -710,8 +709,9 @@ class WaitcntBrackets {
}

void setStateOnFunctionEntryOrReturn() {
setScoreUB(STORE_CNT, getScoreUB(STORE_CNT) + getWaitCountMax(STORE_CNT));
PendingEvents |= WaitEventMaskForInst[STORE_CNT];
setScoreUB(STORE_CNT,
getScoreUB(STORE_CNT) + Context->getWaitCountMax(STORE_CNT));
PendingEvents |= Context->WaitEventMaskForInst[STORE_CNT];
}

ArrayRef<const MachineInstr *> getLDSDMAStores() const {
Expand Down Expand Up @@ -747,8 +747,8 @@ class WaitcntBrackets {
if (T != EXP_CNT)
return;

if (getScoreRange(EXP_CNT) > getWaitCountMax(EXP_CNT))
ScoreLBs[EXP_CNT] = ScoreUBs[EXP_CNT] - getWaitCountMax(EXP_CNT);
if (getScoreRange(EXP_CNT) > Context->getWaitCountMax(EXP_CNT))
ScoreLBs[EXP_CNT] = ScoreUBs[EXP_CNT] - Context->getWaitCountMax(EXP_CNT);
}

void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
Expand All @@ -763,11 +763,8 @@ class WaitcntBrackets {
const MachineOperand &Op, InstCounterType CntTy,
unsigned Val);

const GCNSubtarget *ST = nullptr;
InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
HardwareLimits Limits = {};
const unsigned *WaitEventMaskForInst;
InstCounterType SmemAccessCounter;
const SIInsertWaitcnts *Context;

unsigned ScoreLBs[NUM_INST_CNTS] = {0};
unsigned ScoreUBs[NUM_INST_CNTS] = {0};
unsigned PendingEvents = 0;
Expand Down Expand Up @@ -829,7 +826,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,

RegInterval Result;

MCRegister MCReg = AMDGPU::getMCReg(Op.getReg(), *ST);
MCRegister MCReg = AMDGPU::getMCReg(Op.getReg(), *Context->ST);
unsigned RegIdx = TRI->getHWRegIndex(MCReg);
assert(isUInt<8>(RegIdx));

Expand Down Expand Up @@ -887,7 +884,7 @@ void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
// this at compile time, so we have to assume it might be applied if the
// instruction supports it).
bool WaitcntBrackets::hasPointSampleAccel(const MachineInstr &MI) const {
if (!ST->hasPointSampleAccel() || !SIInstrInfo::isMIMG(MI))
if (!Context->ST->hasPointSampleAccel() || !SIInstrInfo::isMIMG(MI))
return false;

const AMDGPU::MIMGInfo *Info = AMDGPU::getMIMGInfo(MI.getOpcode());
Expand All @@ -913,7 +910,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
WaitEventType E, MachineInstr &Inst) {
InstCounterType T = eventCounter(WaitEventMaskForInst, E);
InstCounterType T = eventCounter(Context->WaitEventMaskForInst, E);

unsigned UB = getScoreUB(T);
unsigned CurrScore = UB + 1;
Expand Down Expand Up @@ -1082,8 +1079,10 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
}

void WaitcntBrackets::print(raw_ostream &OS) const {
const GCNSubtarget *ST = Context->ST;

OS << '\n';
for (auto T : inst_counter_types(MaxCounter)) {
for (auto T : inst_counter_types(Context->MaxCounter)) {
unsigned SR = getScoreRange(T);

switch (T) {
Expand Down Expand Up @@ -1197,7 +1196,7 @@ void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
// s_waitcnt instruction.
if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
!ST->hasFlatLgkmVMemCountInOrder()) {
!Context->ST->hasFlatLgkmVMemCountInOrder()) {
// If there is a pending FLAT operation, and this is a VMem or LGKM
// waitcnt and the target can report early completion, then we need
// to force a waitcnt 0.
Expand All @@ -1211,7 +1210,7 @@ void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
// If a counter has been maxed out avoid overflow by waiting for
// MAX(CounterType) - 1 instead.
unsigned NeededWait =
std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
std::min(UB - ScoreToWait, Context->getWaitCountMax(T) - 1);
addWait(Wait, T, NeededWait);
}
}
Expand Down Expand Up @@ -1239,7 +1238,7 @@ void WaitcntBrackets::applyWaitcnt(InstCounterType T, unsigned Count) {
setScoreLB(T, std::max(getScoreLB(T), UB - Count));
} else {
setScoreLB(T, UB);
PendingEvents &= ~WaitEventMaskForInst[T];
PendingEvents &= ~Context->WaitEventMaskForInst[T];
}
}

Expand All @@ -1264,7 +1263,7 @@ void WaitcntBrackets::applyXcnt(const AMDGPU::Waitcnt &Wait) {
// the decrement may go out of order.
bool WaitcntBrackets::counterOutOfOrder(InstCounterType T) const {
// Scalar memory read always can go out of order.
if ((T == SmemAccessCounter && hasPendingEvent(SMEM_ACCESS)) ||
if ((T == Context->SmemAccessCounter && hasPendingEvent(SMEM_ACCESS)) ||
(T == X_CNT && hasPendingEvent(SMEM_GROUP)))
return true;
return hasMixedPendingEvents(T);
Expand Down Expand Up @@ -2388,8 +2387,9 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
VgprUB = std::max(VgprUB, Other.VgprUB);
SgprUB = std::max(SgprUB, Other.SgprUB);

for (auto T : inst_counter_types(MaxCounter)) {
for (auto T : inst_counter_types(Context->MaxCounter)) {
// Merge event flags for this counter
const unsigned *WaitEventMaskForInst = Context->WaitEventMaskForInst;
const unsigned OldEvents = PendingEvents & WaitEventMaskForInst[T];
const unsigned OtherEvents = Other.PendingEvents & WaitEventMaskForInst[T];
if (OtherEvents & ~OldEvents)
Expand Down Expand Up @@ -2748,11 +2748,10 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
for (auto T : inst_counter_types())
ForceEmitWaitcnt[T] = false;

const unsigned *WaitEventMaskForInst = WCG->getWaitEventMask();
WaitEventMaskForInst = WCG->getWaitEventMask();

SmemAccessCounter = eventCounter(WaitEventMaskForInst, SMEM_ACCESS);

HardwareLimits Limits = {};
if (ST->hasExtendedWaitCounts()) {
Limits.LoadcntMax = AMDGPU::getLoadcntBitMask(IV);
Limits.DscntMax = AMDGPU::getDscntBitMask(IV);
Expand Down Expand Up @@ -2809,8 +2808,7 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
BuildMI(EntryBB, I, DebugLoc(), TII->get(AMDGPU::S_WAITCNT)).addImm(0);
}

auto NonKernelInitialState = std::make_unique<WaitcntBrackets>(
ST, MaxCounter, Limits, WaitEventMaskForInst, SmemAccessCounter);
auto NonKernelInitialState = std::make_unique<WaitcntBrackets>(this);
NonKernelInitialState->setStateOnFunctionEntryOrReturn();
BlockInfos[&EntryBB].Incoming = std::move(NonKernelInitialState);

Expand Down Expand Up @@ -2841,15 +2839,13 @@ bool SIInsertWaitcnts::run(MachineFunction &MF) {
*Brackets = *BI.Incoming;
} else {
if (!Brackets) {
Brackets = std::make_unique<WaitcntBrackets>(
ST, MaxCounter, Limits, WaitEventMaskForInst, SmemAccessCounter);
Brackets = std::make_unique<WaitcntBrackets>(this);
} else {
// Reinitialize in-place. N.B. do not do this by assigning from a
// temporary because the WaitcntBrackets class is large and it could
// cause this function to use an unreasonable amount of stack space.
Brackets->~WaitcntBrackets();
new (Brackets.get()) WaitcntBrackets(
ST, MaxCounter, Limits, WaitEventMaskForInst, SmemAccessCounter);
new (Brackets.get()) WaitcntBrackets(this);
}
}

Expand Down
Loading