-
Notifications
You must be signed in to change notification settings - Fork 14.5k
AMDGPU: Support v_wmma_f32_16x16x128_f8f6f4 on gfx1250 #149684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-Authored-by: Stanislav Mekhanoshin <[email protected]>
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-ir Author: Changpeng Fang (changpeng) ChangesPatch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff 31 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
ArgsForMatchingMatrixTypes = {4, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
break;
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+ ArgsForMatchingMatrixTypes = {5, 1, 3};
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+ break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
ArgsForMatchingMatrixTypes = {3, 0, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
}
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT: entry:
+// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT: ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
// CHECK-GFX1250-NEXT: entry:
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
}
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
{
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
IntrWillReturn, IntrNoCallback, IntrNoFree]
>;
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+ Intrinsic<
+ [llvm_anyfloat_ty], // %D
+ [
+ llvm_i32_ty, // matrix_a_fmt
+ llvm_anyint_ty, // %A
+ llvm_i32_ty, // matrix_b_fmt
+ llvm_anyint_ty, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
break;
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = Call.getArgOperand(1);
+ Value *Src1 = Call.getArgOperand(3);
+
+ unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+ unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+ Check(FmtA <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(0));
+ Check(FmtB <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(2));
+
+ // AMDGPU::MatrixFMT values
+ auto getFormatNumRegs = [](unsigned FormatVal) {
+ switch (FormatVal) {
+ case 0:
+ case 1:
+ return 16u;
+ case 2:
+ case 3:
+ return 12u;
+ case 4:
+ return 8u;
+ default:
+ llvm_unreachable("invalid format value");
+ }
+ };
+
+ auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+ if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+ return false;
+ unsigned NumElts = Ty->getNumElements();
+ return NumElts == 16 || NumElts == 12 || NumElts == 8;
+ };
+
+ auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+ Check(isValidSrcASrcBVector(Src0Ty),
+ "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+ Check(isValidSrcASrcBVector(Src1Ty),
+ "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+ // Permit excess registers for the format.
+ Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+ "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+ Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+ "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+ break;
+ }
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
Value *V = Call.getArgOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
NewII->takeName(&II);
return IC.replaceInstUsesWith(II, NewII);
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = II.getArgOperand(1);
+ Value *Src1 = II.getArgOperand(3);
+ unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+ uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+ auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+ bool MadeChange = false;
+ unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+ unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+ // Depending on the used format, fewer registers are required so shrink the
+ // vector type.
+ if (Src0Ty->getNumElements() > Src0NumElts) {
+ Src0 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (Src1Ty->getNumElements() > Src1NumElts) {
+ Src1 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (!MadeChange)
+ return std::nullopt;
+
+ SmallVector<Value *, 13> Args(II.args());
+ Args[1] = Src0;
+ Args[3] = Src1;
+
+ CallInst *NewII = IC.Builder.CreateIntrinsic(
+ IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+ Args, &II);
+ NewII->takeName(&II);
+ return IC.replaceInstUsesWith(II, NewII);
+ }
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyWaitVAVDst,
ImmTyWaitVMVSrc,
ImmTyBitOp3,
+ ImmTyMatrixAFMT,
+ ImmTyMatrixBFMT,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+ bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+ bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
case ImmTyBitOp3: OS << "BitOp3"; break;
+ case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+ case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
ParseStatus parseIndexKey8bit(OperandVector &Operands);
ParseStatus parseIndexKey16bit(OperandVector &Operands);
ParseStatus parseIndexKey32bit(OperandVector &Operands);
+ ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAFMT(OperandVector &Operands);
+ ParseStatus parseMatrixBFMT(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
const unsigned CPol);
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+ bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
unsigned getConstantBusLimit(unsigned Opcode) const;
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
return true;
}
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+ const OperandVector &Operands) {
+ unsigned Opc = Inst.getOpcode();
+ const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+ const MCInstrDesc &Desc = MII.get(Opc);
+
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+ int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+ if (FmtIdx == -1)
+ return true;
+ unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+ int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+ unsigned RegSize =
+ TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+ return true;
+
+ static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"};
+
+ Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+ "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+ return false;
+ };
+
+ return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+ validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
const SMLoc &IDLoc,
const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
if (!validateTFE(Inst, Operands)) {
return false;
}
+ if (!validateWMMA(Inst, Operands)) {
+ return false;
+ }
return true;
}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(Operands, Name,
+ {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+ AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+ AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
DefaultVal);
}
+ int MatrixAFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+ if (MatrixAFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAFMT, 0);
+ }
+
+ int MatrixBFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+ if (MatrixBFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBFMT, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
convertMAIInst(MI);
+ if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+ convertWMMAInst(MI);
+
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::vdst_in);
if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
return MO.setReg(
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
case 8:
+ if (MCRegister NewReg = MRI.getSubReg(
+ MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]
|
@llvm/pr-subscribers-llvm-transforms Author: Changpeng Fang (changpeng) ChangesPatch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff 31 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
ArgsForMatchingMatrixTypes = {4, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
break;
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+ ArgsForMatchingMatrixTypes = {5, 1, 3};
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+ break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
ArgsForMatchingMatrixTypes = {3, 0, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
}
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT: entry:
+// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT: ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
// CHECK-GFX1250-NEXT: entry:
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
}
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
{
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
IntrWillReturn, IntrNoCallback, IntrNoFree]
>;
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+ Intrinsic<
+ [llvm_anyfloat_ty], // %D
+ [
+ llvm_i32_ty, // matrix_a_fmt
+ llvm_anyint_ty, // %A
+ llvm_i32_ty, // matrix_b_fmt
+ llvm_anyint_ty, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
break;
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = Call.getArgOperand(1);
+ Value *Src1 = Call.getArgOperand(3);
+
+ unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+ unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+ Check(FmtA <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(0));
+ Check(FmtB <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(2));
+
+ // AMDGPU::MatrixFMT values
+ auto getFormatNumRegs = [](unsigned FormatVal) {
+ switch (FormatVal) {
+ case 0:
+ case 1:
+ return 16u;
+ case 2:
+ case 3:
+ return 12u;
+ case 4:
+ return 8u;
+ default:
+ llvm_unreachable("invalid format value");
+ }
+ };
+
+ auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+ if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+ return false;
+ unsigned NumElts = Ty->getNumElements();
+ return NumElts == 16 || NumElts == 12 || NumElts == 8;
+ };
+
+ auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+ Check(isValidSrcASrcBVector(Src0Ty),
+ "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+ Check(isValidSrcASrcBVector(Src1Ty),
+ "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+ // Permit excess registers for the format.
+ Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+ "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+ Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+ "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+ break;
+ }
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
Value *V = Call.getArgOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
NewII->takeName(&II);
return IC.replaceInstUsesWith(II, NewII);
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = II.getArgOperand(1);
+ Value *Src1 = II.getArgOperand(3);
+ unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+ uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+ auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+ bool MadeChange = false;
+ unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+ unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+ // Depending on the used format, fewer registers are required so shrink the
+ // vector type.
+ if (Src0Ty->getNumElements() > Src0NumElts) {
+ Src0 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (Src1Ty->getNumElements() > Src1NumElts) {
+ Src1 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (!MadeChange)
+ return std::nullopt;
+
+ SmallVector<Value *, 13> Args(II.args());
+ Args[1] = Src0;
+ Args[3] = Src1;
+
+ CallInst *NewII = IC.Builder.CreateIntrinsic(
+ IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+ Args, &II);
+ NewII->takeName(&II);
+ return IC.replaceInstUsesWith(II, NewII);
+ }
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyWaitVAVDst,
ImmTyWaitVMVSrc,
ImmTyBitOp3,
+ ImmTyMatrixAFMT,
+ ImmTyMatrixBFMT,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+ bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+ bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
case ImmTyBitOp3: OS << "BitOp3"; break;
+ case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+ case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
ParseStatus parseIndexKey8bit(OperandVector &Operands);
ParseStatus parseIndexKey16bit(OperandVector &Operands);
ParseStatus parseIndexKey32bit(OperandVector &Operands);
+ ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAFMT(OperandVector &Operands);
+ ParseStatus parseMatrixBFMT(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
const unsigned CPol);
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+ bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
unsigned getConstantBusLimit(unsigned Opcode) const;
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
return true;
}
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+ const OperandVector &Operands) {
+ unsigned Opc = Inst.getOpcode();
+ const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+ const MCInstrDesc &Desc = MII.get(Opc);
+
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+ int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+ if (FmtIdx == -1)
+ return true;
+ unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+ int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+ unsigned RegSize =
+ TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+ return true;
+
+ static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"};
+
+ Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+ "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+ return false;
+ };
+
+ return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+ validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
const SMLoc &IDLoc,
const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
if (!validateTFE(Inst, Operands)) {
return false;
}
+ if (!validateWMMA(Inst, Operands)) {
+ return false;
+ }
return true;
}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(Operands, Name,
+ {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+ AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+ AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
DefaultVal);
}
+ int MatrixAFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+ if (MatrixAFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAFMT, 0);
+ }
+
+ int MatrixBFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+ if (MatrixBFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBFMT, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
convertMAIInst(MI);
+ if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+ convertWMMAInst(MI);
+
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::vdst_in);
if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
return MO.setReg(
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
case 8:
+ if (MCRegister NewReg = MRI.getSubReg(
+ MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]
|
@llvm/pr-subscribers-clang Author: Changpeng Fang (changpeng) ChangesPatch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff 31 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
ArgsForMatchingMatrixTypes = {4, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
break;
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+ ArgsForMatchingMatrixTypes = {5, 1, 3};
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+ break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
ArgsForMatchingMatrixTypes = {3, 0, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
}
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT: entry:
+// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT: ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
// CHECK-GFX1250-NEXT: entry:
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
}
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
{
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
IntrWillReturn, IntrNoCallback, IntrNoFree]
>;
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+ Intrinsic<
+ [llvm_anyfloat_ty], // %D
+ [
+ llvm_i32_ty, // matrix_a_fmt
+ llvm_anyint_ty, // %A
+ llvm_i32_ty, // matrix_b_fmt
+ llvm_anyint_ty, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
break;
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = Call.getArgOperand(1);
+ Value *Src1 = Call.getArgOperand(3);
+
+ unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+ unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+ Check(FmtA <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(0));
+ Check(FmtB <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(2));
+
+ // AMDGPU::MatrixFMT values
+ auto getFormatNumRegs = [](unsigned FormatVal) {
+ switch (FormatVal) {
+ case 0:
+ case 1:
+ return 16u;
+ case 2:
+ case 3:
+ return 12u;
+ case 4:
+ return 8u;
+ default:
+ llvm_unreachable("invalid format value");
+ }
+ };
+
+ auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+ if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+ return false;
+ unsigned NumElts = Ty->getNumElements();
+ return NumElts == 16 || NumElts == 12 || NumElts == 8;
+ };
+
+ auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+ Check(isValidSrcASrcBVector(Src0Ty),
+ "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+ Check(isValidSrcASrcBVector(Src1Ty),
+ "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+ // Permit excess registers for the format.
+ Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+ "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+ Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+ "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+ break;
+ }
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
Value *V = Call.getArgOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
NewII->takeName(&II);
return IC.replaceInstUsesWith(II, NewII);
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = II.getArgOperand(1);
+ Value *Src1 = II.getArgOperand(3);
+ unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+ uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+ auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+ bool MadeChange = false;
+ unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+ unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+ // Depending on the used format, fewer registers are required so shrink the
+ // vector type.
+ if (Src0Ty->getNumElements() > Src0NumElts) {
+ Src0 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (Src1Ty->getNumElements() > Src1NumElts) {
+ Src1 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (!MadeChange)
+ return std::nullopt;
+
+ SmallVector<Value *, 13> Args(II.args());
+ Args[1] = Src0;
+ Args[3] = Src1;
+
+ CallInst *NewII = IC.Builder.CreateIntrinsic(
+ IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+ Args, &II);
+ NewII->takeName(&II);
+ return IC.replaceInstUsesWith(II, NewII);
+ }
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyWaitVAVDst,
ImmTyWaitVMVSrc,
ImmTyBitOp3,
+ ImmTyMatrixAFMT,
+ ImmTyMatrixBFMT,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+ bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+ bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
case ImmTyBitOp3: OS << "BitOp3"; break;
+ case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+ case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
ParseStatus parseIndexKey8bit(OperandVector &Operands);
ParseStatus parseIndexKey16bit(OperandVector &Operands);
ParseStatus parseIndexKey32bit(OperandVector &Operands);
+ ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAFMT(OperandVector &Operands);
+ ParseStatus parseMatrixBFMT(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
const unsigned CPol);
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+ bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
unsigned getConstantBusLimit(unsigned Opcode) const;
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
return true;
}
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+ const OperandVector &Operands) {
+ unsigned Opc = Inst.getOpcode();
+ const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+ const MCInstrDesc &Desc = MII.get(Opc);
+
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+ int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+ if (FmtIdx == -1)
+ return true;
+ unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+ int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+ unsigned RegSize =
+ TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+ return true;
+
+ static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"};
+
+ Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+ "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+ return false;
+ };
+
+ return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+ validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
const SMLoc &IDLoc,
const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
if (!validateTFE(Inst, Operands)) {
return false;
}
+ if (!validateWMMA(Inst, Operands)) {
+ return false;
+ }
return true;
}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(Operands, Name,
+ {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+ AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+ AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
DefaultVal);
}
+ int MatrixAFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+ if (MatrixAFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAFMT, 0);
+ }
+
+ int MatrixBFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+ if (MatrixBFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBFMT, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
convertMAIInst(MI);
+ if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+ convertWMMAInst(MI);
+
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::vdst_in);
if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
return MO.setReg(
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
case 8:
+ if (MCRegister NewReg = MRI.getSubReg(
+ MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]
|
@llvm/pr-subscribers-mc Author: Changpeng Fang (changpeng) ChangesPatch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff 31 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
ArgsForMatchingMatrixTypes = {4, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
break;
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+ ArgsForMatchingMatrixTypes = {5, 1, 3};
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+ break;
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
ArgsForMatchingMatrixTypes = {3, 0, 1};
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
}
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT: entry:
+// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT: ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
// CHECK-GFX1250-NEXT: entry:
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
}
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+ *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
{
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
IntrWillReturn, IntrNoCallback, IntrNoFree]
>;
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+ Intrinsic<
+ [llvm_anyfloat_ty], // %D
+ [
+ llvm_i32_ty, // matrix_a_fmt
+ llvm_anyint_ty, // %A
+ llvm_i32_ty, // matrix_b_fmt
+ llvm_anyint_ty, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
break;
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = Call.getArgOperand(1);
+ Value *Src1 = Call.getArgOperand(3);
+
+ unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+ unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+ Check(FmtA <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(0));
+ Check(FmtB <= 4, "invalid value for matrix format", Call,
+ Call.getArgOperand(2));
+
+ // AMDGPU::MatrixFMT values
+ auto getFormatNumRegs = [](unsigned FormatVal) {
+ switch (FormatVal) {
+ case 0:
+ case 1:
+ return 16u;
+ case 2:
+ case 3:
+ return 12u;
+ case 4:
+ return 8u;
+ default:
+ llvm_unreachable("invalid format value");
+ }
+ };
+
+ auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+ if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+ return false;
+ unsigned NumElts = Ty->getNumElements();
+ return NumElts == 16 || NumElts == 12 || NumElts == 8;
+ };
+
+ auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+ Check(isValidSrcASrcBVector(Src0Ty),
+ "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+ Check(isValidSrcASrcBVector(Src1Ty),
+ "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+ // Permit excess registers for the format.
+ Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+ "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+ Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+ "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+ break;
+ }
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
Value *V = Call.getArgOperand(0);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
NewII->takeName(&II);
return IC.replaceInstUsesWith(II, NewII);
}
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+ Value *Src0 = II.getArgOperand(1);
+ Value *Src1 = II.getArgOperand(3);
+ unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+ uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+ auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+ auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+ bool MadeChange = false;
+ unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+ unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+ // Depending on the used format, fewer registers are required so shrink the
+ // vector type.
+ if (Src0Ty->getNumElements() > Src0NumElts) {
+ Src0 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (Src1Ty->getNumElements() > Src1NumElts) {
+ Src1 = IC.Builder.CreateExtractVector(
+ FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+ IC.Builder.getInt64(0));
+ MadeChange = true;
+ }
+
+ if (!MadeChange)
+ return std::nullopt;
+
+ SmallVector<Value *, 13> Args(II.args());
+ Args[1] = Src0;
+ Args[3] = Src1;
+
+ CallInst *NewII = IC.Builder.CreateIntrinsic(
+ IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+ Args, &II);
+ NewII->takeName(&II);
+ return IC.replaceInstUsesWith(II, NewII);
+ }
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+ case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
ImmTyWaitVAVDst,
ImmTyWaitVMVSrc,
ImmTyBitOp3,
+ ImmTyMatrixAFMT,
+ ImmTyMatrixBFMT,
ImmTyMatrixAReuse,
ImmTyMatrixBReuse,
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+ bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+ bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
case ImmTyBitOp3: OS << "BitOp3"; break;
+ case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+ case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
ParseStatus parseIndexKey8bit(OperandVector &Operands);
ParseStatus parseIndexKey16bit(OperandVector &Operands);
ParseStatus parseIndexKey32bit(OperandVector &Operands);
+ ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+ AMDGPUOperand::ImmTy Type);
+ ParseStatus parseMatrixAFMT(OperandVector &Operands);
+ ParseStatus parseMatrixBFMT(OperandVector &Operands);
ParseStatus parseDfmtNfmt(int64_t &Format);
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
const unsigned CPol);
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+ bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
unsigned getConstantBusLimit(unsigned Opcode) const;
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
return true;
}
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+ const OperandVector &Operands) {
+ unsigned Opc = Inst.getOpcode();
+ const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+ const MCInstrDesc &Desc = MII.get(Opc);
+
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+ int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+ if (FmtIdx == -1)
+ return true;
+ unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+ int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+ unsigned RegSize =
+ TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+ return true;
+
+ static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"};
+
+ Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+ "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+ return false;
+ };
+
+ return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+ validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
const SMLoc &IDLoc,
const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
if (!validateTFE(Inst, Operands)) {
return false;
}
+ if (!validateWMMA(Inst, Operands)) {
+ return false;
+ }
return true;
}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
}
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+ StringRef Name,
+ AMDGPUOperand::ImmTy Type) {
+ return parseStringOrIntWithPrefix(Operands, Name,
+ {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+ "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+ "MATRIX_FMT_FP4"},
+ Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+ AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+ return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+ AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
// values to live in a joint format operand in the MCInst encoding.
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
DefaultVal);
}
+ int MatrixAFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+ if (MatrixAFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixAFMT, 0);
+ }
+
+ int MatrixBFMTIdx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+ if (MatrixBFMTIdx != -1) {
+ addOptionalImmOperand(Inst, Operands, OptIdx,
+ AMDGPUOperand::ImmTyMatrixBFMT, 0);
+ }
+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
addOptionalImmOperand(Inst, Operands, OptIdx,
AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
convertMAIInst(MI);
+ if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+ convertWMMAInst(MI);
+
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::vdst_in);
if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
return MO.setReg(
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
case 8:
+ if (MCRegister NewReg = MRI.getSubReg(
+ MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]
|
No description provided.