@@ -520,8 +520,8 @@ static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
520
520
const MachineInstr *MI, IsExpiredFn IsExpired) {
521
521
DenseSet<const MachineBasicBlock *> Visited;
522
522
return getWaitStatesSince (IsHazard, MI->getParent (),
523
- std::next (MI->getReverseIterator ()),
524
- 0 , IsExpired, Visited );
523
+ std::next (MI->getReverseIterator ()), 0 , IsExpired,
524
+ Visited, SIInstrInfo::getNumWaitStates );
525
525
}
526
526
527
527
int GCNHazardRecognizer::getWaitStatesSince (IsHazardFn IsHazard, int Limit) {
@@ -1190,7 +1190,8 @@ void GCNHazardRecognizer::fixHazards(MachineInstr *MI) {
1190
1190
fixVALUPartialForwardingHazard (MI);
1191
1191
fixVALUTransUseHazard (MI);
1192
1192
fixVALUTransCoexecutionHazards (MI);
1193
- fixWMMAHazards (MI);
1193
+ fixWMMAHazards (MI); // fall-through if co-execution is enabled.
1194
+ fixWMMACoexecutionHazards (MI);
1194
1195
fixShift64HighRegBug (MI);
1195
1196
fixVALUMaskWriteHazard (MI);
1196
1197
fixRequiredExportPriority (MI);
@@ -1909,6 +1910,182 @@ bool GCNHazardRecognizer::fixWMMAHazards(MachineInstr *MI) {
1909
1910
return true ;
1910
1911
}
1911
1912
1913
+ static bool isCoexecutableVALUInst (const MachineInstr &MI) {
1914
+ return SIInstrInfo::isVALU (MI) && !SIInstrInfo::isTRANS (MI) &&
1915
+ !SIInstrInfo::isWMMA (MI) && !SIInstrInfo::isSWMMAC (MI); // What else?
1916
+ }
1917
+
1918
+ static bool IsWMMAHazardInstInCategory (const MachineInstr &MI,
1919
+ const SIInstrInfo *TII, unsigned Latency,
1920
+ unsigned Category) {
1921
+ assert (TII->isXDLWMMA (MI) && (Latency == 8 || Latency == 16 ) &&
1922
+ " Handle me if the xdl wmma instruction latency changes" );
1923
+
1924
+ switch (Category) {
1925
+ case 0 : // Dense WMMA Instructions:
1926
+ // WMMA_*F16, WMMA_*BF16
1927
+ // WMMA_*FP8FP8
1928
+ // WMMA_*FP8BF8
1929
+ // WMMA_*BF8FP8
1930
+ // WMMA_*BF8BF8
1931
+ // WMMA_*F8F6F4 if SRCA & SRCB != F8
1932
+ return Latency == 8 && SIInstrInfo::isWMMA (MI);
1933
+
1934
+ case 1 : // Dense WMMA Instructions:
1935
+ // WMMA_IU8
1936
+ // WMMA_IU4
1937
+ // WMMA_*F8F6F4 if SRCA OR SRCB == F8
1938
+ return Latency == 16 && SIInstrInfo::isWMMA (MI);
1939
+
1940
+ case 2 : // Dense SWMMAC Instructions
1941
+ // SWMMAC_*F16, SWMMAC_*BF16,
1942
+ // SWMMAC_*FP8FP8
1943
+ // SWMMAC_*BF8FP8
1944
+ // SWMMAC_*FP8BF8
1945
+ // SWMMAC_*BF8BF8
1946
+ return Latency == 8 && SIInstrInfo::isSWMMAC (MI);
1947
+
1948
+ case 3 : // Sparse WMMA Instructions:
1949
+ // SWMMAC_IU8
1950
+ // SWMMAC_IU4
1951
+ return Latency == 16 && SIInstrInfo::isSWMMAC (MI);
1952
+ default :
1953
+ break ;
1954
+ } // end switch.
1955
+
1956
+ return false ;
1957
+ }
1958
+
1959
+ bool GCNHazardRecognizer::fixWMMACoexecutionHazards (MachineInstr *MI) {
1960
+ if (!AMDGPU::isGFX1250 (ST))
1961
+ return false ;
1962
+
1963
+ const SIInstrInfo *TII = ST.getInstrInfo ();
1964
+ if (!TII->isXDLWMMA (*MI) && !isCoexecutableVALUInst (*MI))
1965
+ return false ;
1966
+
1967
+ const SIRegisterInfo *TRI = ST.getRegisterInfo ();
1968
+
1969
+ // WaitStates here is the number of V_NOPs or unrelated VALU instructions must
1970
+ // be in between the first WMMA and the second instruction to cover the hazard
1971
+ // (WMMAWaitStates if the second is also a WMMA, VALUWaitStates if the second
1972
+ // is a VALU). Refer to SPG 4.6.12.1. "Requirements for WMMA data hazards" for
1973
+ // numbers, which depends on the category of the first WMMA.
1974
+ const int WMMAWaitStates[] = {5 , 9 , 3 , 5 };
1975
+ const int VALUWaitStates[] = {4 , 8 , 2 , 4 };
1976
+ unsigned Category = 0 ;
1977
+
1978
+ auto IsWMMAHazardFn = [MI, TII, TRI, &Category, this ](const MachineInstr &I) {
1979
+ if (!TII->isXDLWMMA (I))
1980
+ return false ;
1981
+
1982
+ unsigned Latency = TSchedModel.computeInstrLatency (&I);
1983
+ if (!IsWMMAHazardInstInCategory (I, TII, Latency, Category))
1984
+ return false ;
1985
+
1986
+ Register D0 = TII->getNamedOperand (I, AMDGPU::OpName::vdst)->getReg ();
1987
+ Register A1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src0)->getReg ();
1988
+ Register B1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src1)->getReg ();
1989
+
1990
+ // WMMA0 wrires (D0), WMMA1 reads (A1/B1/Idx1).
1991
+ if (TRI->regsOverlap (D0, A1) || TRI->regsOverlap (D0, B1))
1992
+ return true ;
1993
+
1994
+ if (SIInstrInfo::isSWMMAC (*MI)) {
1995
+ Register Idx1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src2)->getReg ();
1996
+ if (TRI->regsOverlap (D0, Idx1))
1997
+ return true ;
1998
+ }
1999
+
2000
+ return false ;
2001
+ };
2002
+
2003
+ auto IsVALUHazardFn = [MI, TII, TRI, &Category, this ](const MachineInstr &I) {
2004
+ if (!TII->isXDLWMMA (I))
2005
+ return false ;
2006
+
2007
+ unsigned Latency = TSchedModel.computeInstrLatency (&I);
2008
+ if (!IsWMMAHazardInstInCategory (I, TII, Latency, Category))
2009
+ return false ;
2010
+
2011
+ // WMMA writes, VALU reads.
2012
+ Register D0 = TII->getNamedOperand (I, AMDGPU::OpName::vdst)->getReg ();
2013
+ for (const MachineOperand &ValuUse : MI->explicit_uses ()) {
2014
+ if (ValuUse.isReg () && TRI->regsOverlap (D0, ValuUse.getReg ()))
2015
+ return true ;
2016
+ }
2017
+
2018
+ auto *ValuDst = TII->getNamedOperand (*MI, AMDGPU::OpName::vdst);
2019
+ if (!ValuDst || !ValuDst->isReg ())
2020
+ return false ;
2021
+ Register D1 = ValuDst->getReg ();
2022
+
2023
+ // WMMA writes, VALU writes.
2024
+ if (TRI->regsOverlap (D0, D1))
2025
+ return true ;
2026
+
2027
+ // WMMA reads, VALU writes.
2028
+ Register A0 = TII->getNamedOperand (I, AMDGPU::OpName::src0)->getReg ();
2029
+ Register B0 = TII->getNamedOperand (I, AMDGPU::OpName::src1)->getReg ();
2030
+ if (TRI->regsOverlap (A0, D1) || TRI->regsOverlap (B0, D1))
2031
+ return true ;
2032
+
2033
+ if (SIInstrInfo::isSWMMAC (I)) {
2034
+ Register Idx0 = TII->getNamedOperand (I, AMDGPU::OpName::src2)->getReg ();
2035
+ if (TRI->regsOverlap (D1, Idx0))
2036
+ return true ;
2037
+ }
2038
+
2039
+ return false ;
2040
+ };
2041
+
2042
+ int Limit = 0 ;
2043
+ auto IsExpiredFn = [&Limit](const MachineInstr &, int WaitStates) {
2044
+ return WaitStates >= Limit;
2045
+ };
2046
+
2047
+ auto GetWaitStatesFn = [](const MachineInstr &I) {
2048
+ return SIInstrInfo::isVALU (I) ? 1 : 0 ;
2049
+ };
2050
+
2051
+ int WaitStatesNeeded = -1 ;
2052
+ if (TII->isXDLWMMA (*MI)) {
2053
+ for (Category = 0 ; WaitStatesNeeded < 0 && Category < 4 ; Category++) {
2054
+ Limit = WMMAWaitStates[Category]; // for IsExpiredFn.
2055
+ DenseSet<const MachineBasicBlock *> Visited;
2056
+ // '::getWaitStatesSince' returns the number of VALUs in between if hazard
2057
+ // exists, and INT_MAX if there is no hazard. As a result, a negative
2058
+ // WaitStatesNeeded here means no hazard, and we will continue to search
2059
+ // for other categories.
2060
+ WaitStatesNeeded =
2061
+ Limit - ::getWaitStatesSince (IsWMMAHazardFn, MI->getParent (),
2062
+ std::next (MI->getReverseIterator ()), 0 ,
2063
+ IsExpiredFn, Visited, GetWaitStatesFn);
2064
+ }
2065
+ } else { // Must be a co-executable VALU.
2066
+ for (Category = 0 ; WaitStatesNeeded < 0 && Category < 4 ; Category++) {
2067
+ Limit = VALUWaitStates[Category]; // for IsExpiredFn.
2068
+ DenseSet<const MachineBasicBlock *> Visited;
2069
+ // '::getWaitStatesSince' returns the number of VALUs in between if hazard
2070
+ // exists, and INT_MAX if there is no hazard. As a result, a negative
2071
+ // WaitStatesNeeded here means no hazard, and we will continue to search
2072
+ // for other categories.
2073
+ WaitStatesNeeded =
2074
+ Limit - ::getWaitStatesSince (IsVALUHazardFn, MI->getParent (),
2075
+ std::next (MI->getReverseIterator ()), 0 ,
2076
+ IsExpiredFn, Visited, GetWaitStatesFn);
2077
+ }
2078
+ }
2079
+
2080
+ // WaitStatesNeeded now is the number of V_NOPs we need to insert, negative
2081
+ // means not needed.
2082
+ for (int i = 0 ; i < WaitStatesNeeded; i++)
2083
+ BuildMI (*MI->getParent (), MI, MI->getDebugLoc (),
2084
+ TII->get (AMDGPU::V_NOP_e32));
2085
+
2086
+ return true ;
2087
+ }
2088
+
1912
2089
bool GCNHazardRecognizer::fixShift64HighRegBug (MachineInstr *MI) {
1913
2090
if (!ST.hasShift64HighRegBug ())
1914
2091
return false ;
0 commit comments