AMDGPU: Support v_wmma_f32_16x16x128_f8f6f4 on gfx1250 #149684


Open · wants to merge 1 commit into main

Conversation

changpeng (Contributor)

No description provided.

@llvmbot added the clang, backend:AMDGPU, clang:frontend, clang:codegen, mc, llvm:instcombine, llvm:ir, llvm:analysis, and llvm:transforms labels on Jul 19, 2025.
llvmbot (Member) commented Jul 19, 2025

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-llvm-ir

Author: Changpeng Fang (changpeng)

Changes

Patch is 132.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149684.diff

31 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsAMDGPU.def (+1)
  • (modified) clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp (+5)
  • (modified) clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl (+12)
  • (modified) clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl (+7)
  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+15)
  • (modified) llvm/lib/IR/Verifier.cpp (+48)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+41)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+79)
  • (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+46-1)
  • (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+42)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+6)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+2)
  • (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+10)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+1)
  • (modified) llvm/lib/Target/AMDGPU/SISchedule.td (+15)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp (+23)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+8)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+89-33)
  • (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+2)
  • (modified) llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll (+9)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.gfx1250.w32.ll (+552)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imm.gfx1250.w32.ll (+105)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma.imod.gfx1250.w32.ll (+67)
  • (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32.s (+65)
  • (modified) llvm/test/MC/AMDGPU/gfx1250_asm_wmma_w32_err.s (+76)
  • (modified) llvm/test/MC/Disassembler/AMDGPU/gfx1250_dasm_wmma_w32.txt (+39)
  • (added) llvm/test/Transforms/InstCombine/AMDGPU/wmma-f8f6f4.ll (+158)
  • (added) llvm/test/Verifier/AMDGPU/wmma-f8f6f4.ll (+165)
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index d4fef5d46af73..878543566f0e3 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
 TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
diff --git a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
index ee736a2816218..7dccf82b1a7a3 100644
--- a/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
   case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
   case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
   case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
   case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
       ArgsForMatchingMatrixTypes = {4, 1};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
       break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
+      ArgsForMatchingMatrixTypes = {5, 1, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
+      break;
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
       ArgsForMatchingMatrixTypes = {3, 0, 1};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
index e4ef3defdb341..86c27d48ab0d4 100644
--- a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
   *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
 }
 
+// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
+// CHECK-GFX1250-NEXT:  entry:
+// CHECK-GFX1250-NEXT:    [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+// CHECK-GFX1250-NEXT:    [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
+// CHECK-GFX1250-NEXT:    store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
+// CHECK-GFX1250-NEXT:    ret void
+//
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
+{
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
+}
+
 // CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
 // CHECK-GFX1250-NEXT:  entry:
 // CHECK-GFX1250-NEXT:    [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)
diff --git a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
index 55d705e6ad238..8aa7c34672783 100644
--- a/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
+++ b/clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
   *out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
 }
 
+void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
+{
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+  *out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
+}
+
 void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
 {
   *out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ecda6c4efefe3..8bfa34584c3a4 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
      IntrWillReturn, IntrNoCallback, IntrNoFree]
 >;
 
+class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
+  Intrinsic<
+    [llvm_anyfloat_ty], // %D
+    [
+      llvm_i32_ty,      // matrix_a_fmt
+      llvm_anyint_ty,   // %A
+      llvm_i32_ty,      // matrix_b_fmt
+      llvm_anyint_ty,   // %B
+      llvm_i16_ty,      // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+      LLVMMatchType<0>, // %C
+    ],
+    [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
 defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
 def int_amdgcn_wmma_f32_16x16x4_f32       : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
 def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_i32_16x16x64_iu8      : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_16x16x128_f8f6f4  : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
 def int_amdgcn_wmma_f32_32x16x128_f4       : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
 }
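
For reference, a minimal IR-level use of the new intrinsic could look like the sketch below (function and value names are illustrative; the operand types mirror the clang codegen test earlier in the patch). The two i32 immediates select the A and B matrix formats (0..4 = FP8, BF8, FP6, BF6, FP4), and the i16 immediate carries the C modifier.

  declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 immarg, <16 x i32>, i32 immarg, <12 x i32>, i16 immarg, <8 x float>)

  ; A in MATRIX_FMT_BF8 (16 registers), B in MATRIX_FMT_FP6 (12 registers), no C modifier.
  define <8 x float> @wmma_f8f6f4_sketch(<16 x i32> %A, <12 x i32> %B, <8 x float> %C) {
    %D = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> %A, i32 2, <12 x i32> %B, i16 0, <8 x float> %C)
    ret <8 x float> %D
  }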
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8c8ed3c5e47ba..40d20ee92e4af 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           "invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
     break;
   }
+  case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+    Value *Src0 = Call.getArgOperand(1);
+    Value *Src1 = Call.getArgOperand(3);
+
+    unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
+    unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
+    Check(FmtA <= 4, "invalid value for matrix format", Call,
+          Call.getArgOperand(0));
+    Check(FmtB <= 4, "invalid value for matrix format", Call,
+          Call.getArgOperand(2));
+
+    // AMDGPU::MatrixFMT values
+    auto getFormatNumRegs = [](unsigned FormatVal) {
+      switch (FormatVal) {
+      case 0:
+      case 1:
+        return 16u;
+      case 2:
+      case 3:
+        return 12u;
+      case 4:
+        return 8u;
+      default:
+        llvm_unreachable("invalid format value");
+      }
+    };
+
+    auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
+      if (!Ty || !Ty->getElementType()->isIntegerTy(32))
+        return false;
+      unsigned NumElts = Ty->getNumElements();
+      return NumElts == 16 || NumElts == 12 || NumElts == 8;
+    };
+
+    auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
+    Check(isValidSrcASrcBVector(Src0Ty),
+          "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
+    Check(isValidSrcASrcBVector(Src1Ty),
+          "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);
+
+    // Permit excess registers for the format.
+    Check(Src0Ty->getNumElements() >= getFormatNumRegs(FmtA),
+          "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
+    Check(Src1Ty->getNumElements() >= getFormatNumRegs(FmtB),
+          "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
+    break;
+  }
   case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
   case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
     Value *V = Call.getArgOperand(0);
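
A hedged sketch of IR these new checks reject (fragments only, names illustrative): the first call uses a format value outside 0..4, and the second pairs MATRIX_FMT_FP8 on A, which needs 16 registers, with only an 8-element operand.

  ; Rejected: "invalid value for matrix format" (format immediates must be in 0..4).
  %bad.fmt = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v16i32(i32 5, <16 x i32> %A, i32 0, <16 x i32> %B, i16 0, <8 x float> %C)
  ; Rejected: "invalid vector type for format" (FP8 requires at least 16 i32 elements for A).
  %bad.size = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v8i32.v16i32(i32 0, <8 x i32> %A8, i32 0, <16 x i32> %B, i16 0, <8 x float> %C)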
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index e2c2e8912c715..f2207ff4cb1c4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     NewII->takeName(&II);
     return IC.replaceInstUsesWith(II, NewII);
   }
+  case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
+    Value *Src0 = II.getArgOperand(1);
+    Value *Src1 = II.getArgOperand(3);
+    unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
+    uint64_t FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+    auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
+    auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
+
+    bool MadeChange = false;
+    unsigned Src0NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtA);
+    unsigned Src1NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(FmtB);
+
+    // Depending on the used format, fewer registers are required so shrink the
+    // vector type.
+    if (Src0Ty->getNumElements() > Src0NumElts) {
+      Src0 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (Src1Ty->getNumElements() > Src1NumElts) {
+      Src1 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (!MadeChange)
+      return std::nullopt;
+
+    SmallVector<Value *, 13> Args(II.args());
+    Args[1] = Src0;
+    Args[3] = Src1;
+
+    CallInst *NewII = IC.Builder.CreateIntrinsic(
+        IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+        Args, &II);
+    NewII->takeName(&II);
+    return IC.replaceInstUsesWith(II, NewII);
+  }
   }
   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
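
A hedged before/after sketch of this combine (names illustrative): with MATRIX_FMT_FP6 on B, only 12 of the 16 supplied registers are used, so the operand is narrowed with a subvector extract and the call is recreated with the smaller overload.

  ; Before: B is passed as <16 x i32> even though format 2 (FP6) only needs 12 registers.
  %D = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v16i32(i32 0, <16 x i32> %A, i32 2, <16 x i32> %B, i16 0, <8 x float> %C)
  ; After: the low 12 elements of B are extracted and the intrinsic is retargeted.
  %B.lo = call <12 x i32> @llvm.vector.extract.v12i32.v16i32(<16 x i32> %B, i64 0)
  %D2 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 0, <16 x i32> %A, i32 2, <12 x i32> %B.lo, i16 0, <8 x float> %C)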
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
     case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
     case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
     case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
     case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyWaitVAVDst,
     ImmTyWaitVMVSrc,
     ImmTyBitOp3,
+    ImmTyMatrixAFMT,
+    ImmTyMatrixBFMT,
     ImmTyMatrixAReuse,
     ImmTyMatrixBReuse,
     ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
   bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
   bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+  bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+  bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
   bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
   bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
     case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
     case ImmTyBitOp3: OS << "BitOp3"; break;
+    case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+    case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
     case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
     case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
     case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
   ParseStatus parseIndexKey8bit(OperandVector &Operands);
   ParseStatus parseIndexKey16bit(OperandVector &Operands);
   ParseStatus parseIndexKey32bit(OperandVector &Operands);
+  ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+                                AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAFMT(OperandVector &Operands);
+  ParseStatus parseMatrixBFMT(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                               const unsigned CPol);
   bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
   std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+  bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
   unsigned getConstantBusLimit(unsigned Opcode) const;
   bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
   bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
   return true;
 }
 
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+                                   const OperandVector &Operands) {
+  unsigned Opc = Inst.getOpcode();
+  const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+  const MCInstrDesc &Desc = MII.get(Opc);
+
+  auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+    int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+    if (FmtIdx == -1)
+      return true;
+    unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+    int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+    unsigned RegSize =
+        TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+    if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+      return true;
+
+    static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"};
+
+    Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+          "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+    return false;
+  };
+
+  return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+         validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
 bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
                                           const SMLoc &IDLoc,
                                           const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
   if (!validateTFE(Inst, Operands)) {
     return false;
   }
+  if (!validateWMMA(Inst, Operands)) {
+    return false;
+  }
 
   return true;
 }
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
   return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
 }
 
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+                                               StringRef Name,
+                                               AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(Operands, Name,
+                                    {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"},
+                                    Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+                           AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+                           AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
 // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
 // values to live in a joint format operand in the MCInst encoding.
 ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
                           DefaultVal);
   }
 
+  int MatrixAFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+  if (MatrixAFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAFMT, 0);
+  }
+
+  int MatrixBFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+  if (MatrixBFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBFMT, 0);
+  }
+
   if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
     addOptionalImmOperand(Inst, Operands, OptIdx,
                           AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
   if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
     convertMAIInst(MI);
 
+  if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+    convertWMMAInst(MI);
+
   int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
                                               AMDGPU::OpName::vdst_in);
   if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
     return MO.setReg(
         MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
   case 8:
+    if (MCRegister NewReg = MRI.getSubReg(
+            MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]

llvmbot (Member) commented Jul 19, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Changpeng Fang (changpeng)


llvmbot (Member) commented Jul 19, 2025

@llvm/pr-subscribers-clang

Author: Changpeng Fang (changpeng)

+          FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (Src1Ty->getNumElements() > Src1NumElts) {
+      Src1 = IC.Builder.CreateExtractVector(
+          FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
+          IC.Builder.getInt64(0));
+      MadeChange = true;
+    }
+
+    if (!MadeChange)
+      return std::nullopt;
+
+    SmallVector<Value *, 13> Args(II.args());
+    Args[1] = Src0;
+    Args[3] = Src1;
+
+    CallInst *NewII = IC.Builder.CreateIntrinsic(
+        IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
+        Args, &II);
+    NewII->takeName(&II);
+    return IC.replaceInstUsesWith(II, NewII);
+  }
   }
   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index bf2f37bddb9ed..8ef4745bd0b6b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
     case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
     case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
+    case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
     case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
     case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
     case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 43d4e8db791b0..b83d88b55e72a 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     ImmTyWaitVAVDst,
     ImmTyWaitVMVSrc,
     ImmTyBitOp3,
+    ImmTyMatrixAFMT,
+    ImmTyMatrixBFMT,
     ImmTyMatrixAReuse,
     ImmTyMatrixBReuse,
     ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
   bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
   bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
   bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
+  bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
+  bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
   bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
   bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
   bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
     case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
     case ImmTyBitOp3: OS << "BitOp3"; break;
+    case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
+    case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
     case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
     case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
     case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
   ParseStatus parseIndexKey8bit(OperandVector &Operands);
   ParseStatus parseIndexKey16bit(OperandVector &Operands);
   ParseStatus parseIndexKey32bit(OperandVector &Operands);
+  ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
+                                AMDGPUOperand::ImmTy Type);
+  ParseStatus parseMatrixAFMT(OperandVector &Operands);
+  ParseStatus parseMatrixBFMT(OperandVector &Operands);
 
   ParseStatus parseDfmtNfmt(int64_t &Format);
   ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
                               const unsigned CPol);
   bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
   std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
+  bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
   unsigned getConstantBusLimit(unsigned Opcode) const;
   bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
   bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
   return true;
 }
 
+bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
+                                   const OperandVector &Operands) {
+  unsigned Opc = Inst.getOpcode();
+  const MCRegisterInfo *TRI = getContext().getRegisterInfo();
+  const MCInstrDesc &Desc = MII.get(Opc);
+
+  auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
+    int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
+    if (FmtIdx == -1)
+      return true;
+    unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
+    int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
+    unsigned RegSize =
+        TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
+
+    if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
+      return true;
+
+    static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"};
+
+    Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
+          "wrong register tuple size for " + Twine(FmtNames[Fmt]));
+    return false;
+  };
+
+  return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
+         validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
+}
+
 bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
                                           const SMLoc &IDLoc,
                                           const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
   if (!validateTFE(Inst, Operands)) {
     return false;
   }
+  if (!validateWMMA(Inst, Operands)) {
+    return false;
+  }
 
   return true;
 }
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
   return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
 }
 
+ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
+                                               StringRef Name,
+                                               AMDGPUOperand::ImmTy Type) {
+  return parseStringOrIntWithPrefix(Operands, Name,
+                                    {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
+                                     "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
+                                     "MATRIX_FMT_FP4"},
+                                    Type);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_a_fmt",
+                           AMDGPUOperand::ImmTyMatrixAFMT);
+}
+
+ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
+  return tryParseMatrixFMT(Operands, "matrix_b_fmt",
+                           AMDGPUOperand::ImmTyMatrixBFMT);
+}
+
 // dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
 // values to live in a joint format operand in the MCInst encoding.
 ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
                           DefaultVal);
   }
 
+  int MatrixAFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
+  if (MatrixAFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixAFMT, 0);
+  }
+
+  int MatrixBFMTIdx =
+      AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
+  if (MatrixBFMTIdx != -1) {
+    addOptionalImmOperand(Inst, Operands, OptIdx,
+                          AMDGPUOperand::ImmTyMatrixBFMT, 0);
+  }
+
   if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
     addOptionalImmOperand(Inst, Operands, OptIdx,
                           AMDGPUOperand::ImmTyMatrixAReuse, 0);
diff --git a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
index 98f7e17e9528c..5c1989b345bdc 100644
--- a/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
+++ b/llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -877,6 +877,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
   if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
     convertMAIInst(MI);
 
+  if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsWMMA)
+    convertWMMAInst(MI);
+
   int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
                                               AMDGPU::OpName::vdst_in);
   if (VDstIn_Idx != -1) {
@@ -974,10 +977,23 @@ static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
     return MO.setReg(
         MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
   case 8:
+    if (MCRegister NewReg = MRI.getSubReg(
+            MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5_su...
[truncated]
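
For readers of the truncated diff, a minimal LLVM IR sketch of the new intrinsic and of the operand shrinking performed by the InstCombine hunk above (function and value names are illustrative, not copied from the patch's tests). Format value 2 selects MATRIX_FMT_FP6, which reads only 12 of the 16 i32 registers of the B operand:

; Before instcombine: B is passed as a full <16 x i32> even though
; matrix_b_fmt = 2 (MATRIX_FMT_FP6) only consumes 12 registers. The verifier
; accepts this because excess registers are permitted for the format.
define <8 x float> @wmma_f8f6f4_fp8_fp6(<16 x i32> %a, <16 x i32> %b, <8 x float> %c) {
  %d = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v16i32(i32 0, <16 x i32> %a, i32 2, <16 x i32> %b, i16 0, <8 x float> %c)
  ret <8 x float> %d
}

; Expected shape after instcombine: the B operand is narrowed to the 12
; registers the FP6 format actually reads.
;   %b.sub = call <12 x i32> @llvm.vector.extract.v12i32.v16i32(<16 x i32> %b, i64 0)
;   %d = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 0, <16 x i32> %a, i32 2, <12 x i32> %b.sub, i16 0, <8 x float> %c)

declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v16i32(i32, <16 x i32>, i32, <16 x i32>, i16, <8 x float>)
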

