Skip to content

Commit e47d5eb

Browse files
authored
[AMDGPU] Hazard handling for gfx1250 wmma instructions (#149865)
If both instructions are xdl WMMA, hazard exists when the first WMMA writes a register (D0) and the second WMMA reads it (A1/B1/Index1). If the first instruction is a xdl WMMA, and the second one is a VALU, three kinds of hazards exist: WMMA writes (D0), VALU reads (Use1); WMMA writes (D0), VALU writes (D1); WMMA reads (A0/B0.Index0), VALU writes (D1). The actual number of hazard slots depends on the categories of the first xdl WMMA as well as whether the second instruction is a xdl WMMA or VALU. If there is not enough unrelated VALUs in between the two instructions, appropriate number (to cover the missing) of V_NOPs will be inserted to satisfy the hazard handling requirements.
1 parent 9052a85 commit e47d5eb

File tree

4 files changed

+2513
-3
lines changed

4 files changed

+2513
-3
lines changed

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
520520
const MachineInstr *MI, IsExpiredFn IsExpired) {
521521
DenseSet<const MachineBasicBlock *> Visited;
522522
return getWaitStatesSince(IsHazard, MI->getParent(),
523-
std::next(MI->getReverseIterator()),
524-
0, IsExpired, Visited);
523+
std::next(MI->getReverseIterator()), 0, IsExpired,
524+
Visited, SIInstrInfo::getNumWaitStates);
525525
}
526526

527527
int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1190,7 +1190,8 @@ void GCNHazardRecognizer::fixHazards(MachineInstr *MI) {
11901190
fixVALUPartialForwardingHazard(MI);
11911191
fixVALUTransUseHazard(MI);
11921192
fixVALUTransCoexecutionHazards(MI);
1193-
fixWMMAHazards(MI);
1193+
fixWMMAHazards(MI); // fall-through if co-execution is enabled.
1194+
fixWMMACoexecutionHazards(MI);
11941195
fixShift64HighRegBug(MI);
11951196
fixVALUMaskWriteHazard(MI);
11961197
fixRequiredExportPriority(MI);
@@ -1909,6 +1910,182 @@ bool GCNHazardRecognizer::fixWMMAHazards(MachineInstr *MI) {
19091910
return true;
19101911
}
19111912

1913+
static bool isCoexecutableVALUInst(const MachineInstr &MI) {
1914+
return SIInstrInfo::isVALU(MI) && !SIInstrInfo::isTRANS(MI) &&
1915+
!SIInstrInfo::isWMMA(MI) && !SIInstrInfo::isSWMMAC(MI); // What else?
1916+
}
1917+
1918+
static bool IsWMMAHazardInstInCategory(const MachineInstr &MI,
1919+
const SIInstrInfo *TII, unsigned Latency,
1920+
unsigned Category) {
1921+
assert(TII->isXDLWMMA(MI) && (Latency == 8 || Latency == 16) &&
1922+
"Handle me if the xdl wmma instruction latency changes");
1923+
1924+
switch (Category) {
1925+
case 0: // Dense WMMA Instructions:
1926+
// WMMA_*F16, WMMA_*BF16
1927+
// WMMA_*FP8FP8
1928+
// WMMA_*FP8BF8
1929+
// WMMA_*BF8FP8
1930+
// WMMA_*BF8BF8
1931+
// WMMA_*F8F6F4 if SRCA & SRCB != F8
1932+
return Latency == 8 && SIInstrInfo::isWMMA(MI);
1933+
1934+
case 1: // Dense WMMA Instructions:
1935+
// WMMA_IU8
1936+
// WMMA_IU4
1937+
// WMMA_*F8F6F4 if SRCA OR SRCB == F8
1938+
return Latency == 16 && SIInstrInfo::isWMMA(MI);
1939+
1940+
case 2: // Dense SWMMAC Instructions
1941+
// SWMMAC_*F16, SWMMAC_*BF16,
1942+
// SWMMAC_*FP8FP8
1943+
// SWMMAC_*BF8FP8
1944+
// SWMMAC_*FP8BF8
1945+
// SWMMAC_*BF8BF8
1946+
return Latency == 8 && SIInstrInfo::isSWMMAC(MI);
1947+
1948+
case 3: // Sparse WMMA Instructions:
1949+
// SWMMAC_IU8
1950+
// SWMMAC_IU4
1951+
return Latency == 16 && SIInstrInfo::isSWMMAC(MI);
1952+
default:
1953+
break;
1954+
} // end switch.
1955+
1956+
return false;
1957+
}
1958+
1959+
bool GCNHazardRecognizer::fixWMMACoexecutionHazards(MachineInstr *MI) {
1960+
if (!AMDGPU::isGFX1250(ST))
1961+
return false;
1962+
1963+
const SIInstrInfo *TII = ST.getInstrInfo();
1964+
if (!TII->isXDLWMMA(*MI) && !isCoexecutableVALUInst(*MI))
1965+
return false;
1966+
1967+
const SIRegisterInfo *TRI = ST.getRegisterInfo();
1968+
1969+
// WaitStates here is the number of V_NOPs or unrelated VALU instructions must
1970+
// be in between the first WMMA and the second instruction to cover the hazard
1971+
// (WMMAWaitStates if the second is also a WMMA, VALUWaitStates if the second
1972+
// is a VALU). Refer to SPG 4.6.12.1. "Requirements for WMMA data hazards" for
1973+
// numbers, which depends on the category of the first WMMA.
1974+
const int WMMAWaitStates[] = {5, 9, 3, 5};
1975+
const int VALUWaitStates[] = {4, 8, 2, 4};
1976+
unsigned Category = 0;
1977+
1978+
auto IsWMMAHazardFn = [MI, TII, TRI, &Category, this](const MachineInstr &I) {
1979+
if (!TII->isXDLWMMA(I))
1980+
return false;
1981+
1982+
unsigned Latency = TSchedModel.computeInstrLatency(&I);
1983+
if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
1984+
return false;
1985+
1986+
Register D0 = TII->getNamedOperand(I, AMDGPU::OpName::vdst)->getReg();
1987+
Register A1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src0)->getReg();
1988+
Register B1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src1)->getReg();
1989+
1990+
// WMMA0 wrires (D0), WMMA1 reads (A1/B1/Idx1).
1991+
if (TRI->regsOverlap(D0, A1) || TRI->regsOverlap(D0, B1))
1992+
return true;
1993+
1994+
if (SIInstrInfo::isSWMMAC(*MI)) {
1995+
Register Idx1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2)->getReg();
1996+
if (TRI->regsOverlap(D0, Idx1))
1997+
return true;
1998+
}
1999+
2000+
return false;
2001+
};
2002+
2003+
auto IsVALUHazardFn = [MI, TII, TRI, &Category, this](const MachineInstr &I) {
2004+
if (!TII->isXDLWMMA(I))
2005+
return false;
2006+
2007+
unsigned Latency = TSchedModel.computeInstrLatency(&I);
2008+
if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
2009+
return false;
2010+
2011+
// WMMA writes, VALU reads.
2012+
Register D0 = TII->getNamedOperand(I, AMDGPU::OpName::vdst)->getReg();
2013+
for (const MachineOperand &ValuUse : MI->explicit_uses()) {
2014+
if (ValuUse.isReg() && TRI->regsOverlap(D0, ValuUse.getReg()))
2015+
return true;
2016+
}
2017+
2018+
auto *ValuDst = TII->getNamedOperand(*MI, AMDGPU::OpName::vdst);
2019+
if (!ValuDst || !ValuDst->isReg())
2020+
return false;
2021+
Register D1 = ValuDst->getReg();
2022+
2023+
// WMMA writes, VALU writes.
2024+
if (TRI->regsOverlap(D0, D1))
2025+
return true;
2026+
2027+
// WMMA reads, VALU writes.
2028+
Register A0 = TII->getNamedOperand(I, AMDGPU::OpName::src0)->getReg();
2029+
Register B0 = TII->getNamedOperand(I, AMDGPU::OpName::src1)->getReg();
2030+
if (TRI->regsOverlap(A0, D1) || TRI->regsOverlap(B0, D1))
2031+
return true;
2032+
2033+
if (SIInstrInfo::isSWMMAC(I)) {
2034+
Register Idx0 = TII->getNamedOperand(I, AMDGPU::OpName::src2)->getReg();
2035+
if (TRI->regsOverlap(D1, Idx0))
2036+
return true;
2037+
}
2038+
2039+
return false;
2040+
};
2041+
2042+
int Limit = 0;
2043+
auto IsExpiredFn = [&Limit](const MachineInstr &, int WaitStates) {
2044+
return WaitStates >= Limit;
2045+
};
2046+
2047+
auto GetWaitStatesFn = [](const MachineInstr &I) {
2048+
return SIInstrInfo::isVALU(I) ? 1 : 0;
2049+
};
2050+
2051+
int WaitStatesNeeded = -1;
2052+
if (TII->isXDLWMMA(*MI)) {
2053+
for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
2054+
Limit = WMMAWaitStates[Category]; // for IsExpiredFn.
2055+
DenseSet<const MachineBasicBlock *> Visited;
2056+
// '::getWaitStatesSince' returns the number of VALUs in between if hazard
2057+
// exists, and INT_MAX if there is no hazard. As a result, a negative
2058+
// WaitStatesNeeded here means no hazard, and we will continue to search
2059+
// for other categories.
2060+
WaitStatesNeeded =
2061+
Limit - ::getWaitStatesSince(IsWMMAHazardFn, MI->getParent(),
2062+
std::next(MI->getReverseIterator()), 0,
2063+
IsExpiredFn, Visited, GetWaitStatesFn);
2064+
}
2065+
} else { // Must be a co-executable VALU.
2066+
for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
2067+
Limit = VALUWaitStates[Category]; // for IsExpiredFn.
2068+
DenseSet<const MachineBasicBlock *> Visited;
2069+
// '::getWaitStatesSince' returns the number of VALUs in between if hazard
2070+
// exists, and INT_MAX if there is no hazard. As a result, a negative
2071+
// WaitStatesNeeded here means no hazard, and we will continue to search
2072+
// for other categories.
2073+
WaitStatesNeeded =
2074+
Limit - ::getWaitStatesSince(IsVALUHazardFn, MI->getParent(),
2075+
std::next(MI->getReverseIterator()), 0,
2076+
IsExpiredFn, Visited, GetWaitStatesFn);
2077+
}
2078+
}
2079+
2080+
// WaitStatesNeeded now is the number of V_NOPs we need to insert, negative
2081+
// means not needed.
2082+
for (int i = 0; i < WaitStatesNeeded; i++)
2083+
BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
2084+
TII->get(AMDGPU::V_NOP_e32));
2085+
2086+
return true;
2087+
}
2088+
19122089
bool GCNHazardRecognizer::fixShift64HighRegBug(MachineInstr *MI) {
19132090
if (!ST.hasShift64HighRegBug())
19142091
return false;

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class GCNHazardRecognizer final : public ScheduleHazardRecognizer {
106106
bool fixVALUTransUseHazard(MachineInstr *MI);
107107
bool fixVALUTransCoexecutionHazards(MachineInstr *MI);
108108
bool fixWMMAHazards(MachineInstr *MI);
109+
bool fixWMMACoexecutionHazards(MachineInstr *MI);
109110
bool fixShift64HighRegBug(MachineInstr *MI);
110111
bool fixVALUMaskWriteHazard(MachineInstr *MI);
111112
bool fixRequiredExportPriority(MachineInstr *MI);

0 commit comments

Comments
 (0)