diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 5a53b15a9c679..b237f7b5749e7 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -719,6 +719,29 @@ def AMDGPU_SchedBarrierOp :
   }];
 }
 
+def AMDGPU_MemoryCounterWaitOp :
+    AMDGPU_Op<"memory_counter_wait">,
+    Arguments<(ins
+      OptionalAttr<I32Attr>:$load,
+      OptionalAttr<I32Attr>:$store,
+      OptionalAttr<I32Attr>:$ds,
+      OptionalAttr<I32Attr>:$exp
+    )>
+  {
+  let summary = "Wait for specified hardware counters";
+  let description = [{
+    Wait for the specified counters to be less-than or equal-to the provided
+    values before continuing.
+
+    Counters can lower to different instructions on different architectures,
+    including clamping to some HW-supported max value or combining multiple
+    counters into one.
+  }];
+  let assemblyFormat = [{
+    oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
+  }];
+}
+
 def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
     "The possible permutations of the lanes storing B available in an MFMA",
     [
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..309476ca7136a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -419,6 +419,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   }
 };
 
+// TODO: The AMDGPU backend already has all this bitpacking logic; we should
+// move it to some common place.
+/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
+///   Vmcnt = Waitcnt[3:0]       (pre-gfx9)
+///   Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
+///   Vmcnt = Waitcnt[15:10]     (gfx11)
+///   Expcnt = Waitcnt[6:4]      (pre-gfx11)
+///   Expcnt = Waitcnt[2:0]      (gfx11)
+///   Lgkmcnt = Waitcnt[11:8]    (pre-gfx10)
+///   Lgkmcnt = Waitcnt[13:8]    (gfx10)
+///   Lgkmcnt = Waitcnt[9:4]     (gfx11)
+static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
+                                         unsigned expcnt, unsigned lgkmcnt) {
+  if (chipset.majorVersion < 9) {
+    vmcnt = std::min(15u, vmcnt);
+    expcnt = std::min(7u, expcnt);
+    lgkmcnt = std::min(15u, lgkmcnt);
+    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
+  }
+  if (chipset.majorVersion == 9) {
+    vmcnt = std::min(63u, vmcnt);
+    expcnt = std::min(7u, expcnt);
+    lgkmcnt = std::min(15u, lgkmcnt);
+    unsigned lowBits = vmcnt & 0xF;
+    unsigned highBits = (vmcnt >> 4) << 14;
+    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+    return lowBits | highBits | otherCnts;
+  }
+  if (chipset.majorVersion == 10) {
+    vmcnt = std::min(63u, vmcnt);
+    expcnt = std::min(7u, expcnt);
+    lgkmcnt = std::min(63u, lgkmcnt);
+    unsigned lowBits = vmcnt & 0xF;
+    unsigned highBits = (vmcnt >> 4) << 14;
+    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+    return lowBits | highBits | otherCnts;
+  }
+  if (chipset.majorVersion == 11) {
+    vmcnt = std::min(63u, vmcnt);
+    expcnt = std::min(7u, expcnt);
+    lgkmcnt = std::min(63u, lgkmcnt);
+    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
+  }
+  return failure();
+}
+
+struct MemoryCounterWaitOpLowering
+    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
+  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
+                              Chipset chipset)
+      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset.majorVersion >= 12) {
+      Location loc = op.getLoc();
+      if (std::optional<int> ds = adaptor.getDs())
+        rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
+
+      if (std::optional<int> load = adaptor.getLoad())
+        rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
+
+      if (std::optional<int> store = adaptor.getStore())
+        rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
+
+      if (std::optional<int> exp = adaptor.getExp())
+        rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
+
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    auto getVal = [](Attribute attr) -> unsigned {
+      if (attr)
+        return cast<IntegerAttr>(attr).getInt();
+
+      // This value will be clamped to the maximum value for the chipset.
+      return 1024;
+    };
+    unsigned ds = getVal(adaptor.getDsAttr());
+    unsigned exp = getVal(adaptor.getExpAttr());
+
+    unsigned vmcnt = 1024;
+    Attribute load = adaptor.getLoadAttr();
+    Attribute store = adaptor.getStoreAttr();
+    if (load && store) {
+      vmcnt = getVal(load) + getVal(store);
+    } else if (load) {
+      vmcnt = getVal(load);
+    } else if (store) {
+      vmcnt = getVal(store);
+    }
+
+    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
+    if (failed(waitcnt))
+      return op.emitOpError("unsupported chipset");
+
+    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
+    return success();
+  }
+};
+
 struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
   LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1931,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                ROCDL::RawPtrBufferAtomicUminOp>,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
-      AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
-      MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
-      ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
+      AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+      WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
       PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
       PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
       TransposeLoadOpLowering>(converter, chipset);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir
new file mode 100644
index 0000000000000..1016ee859e462
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12
+
+// CHECK-LABEL: func @memory_counter_wait
+func.func @memory_counter_wait() {
+  // GFX9: rocdl.s.waitcnt 53119
+  // GFX10: rocdl.s.waitcnt 65407
+  // GFX11: rocdl.s.waitcnt 65527
+  // GFX12-NOT: rocdl.s.wait.loadcnt
+  // GFX12-NOT: rocdl.s.wait.storecnt
+  // GFX12-NOT: rocdl.s.wait.expcnt
+  // GFX12-NOT: rocdl.s.wait.dscnt
+  amdgpu.memory_counter_wait
+
+  // GFX9: rocdl.s.waitcnt 3952
+  // GFX10: rocdl.s.waitcnt 16240
+  // GFX11: rocdl.s.waitcnt 1015
+  // GFX12: rocdl.s.wait.loadcnt 0
+  amdgpu.memory_counter_wait load(0)
+
+  // GFX9: rocdl.s.waitcnt 3952
+  // GFX10: rocdl.s.waitcnt 16240
+  // GFX11: rocdl.s.waitcnt 1015
+  // GFX12: rocdl.s.wait.storecnt 0
+  amdgpu.memory_counter_wait store(0)
+
+  // GFX9: rocdl.s.waitcnt 53007
+  // GFX10: rocdl.s.waitcnt 65295
+  // GFX11: rocdl.s.waitcnt 65520
+  // GFX12: rocdl.s.wait.expcnt 0
+  amdgpu.memory_counter_wait exp(0)
+
+  // GFX9: rocdl.s.waitcnt 49279
+  // GFX10: rocdl.s.waitcnt 49279
+  // GFX11: rocdl.s.waitcnt 64519
+  // GFX12: rocdl.s.wait.dscnt 0
+  amdgpu.memory_counter_wait ds(0)
+
+  return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index fe2b32be04de4..fe78b5365745a 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -548,3 +548,20 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
   amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
   func.return
 }
+
+// CHECK-LABEL: func @memory_counter_wait
+func.func @memory_counter_wait() {
+  // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1)
+  // CHECK: amdgpu.memory_counter_wait load(1)
+  // CHECK: amdgpu.memory_counter_wait store(2)
+  // CHECK: amdgpu.memory_counter_wait ds(3)
+  // CHECK: amdgpu.memory_counter_wait exp(4)
+  amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4)
+  amdgpu.memory_counter_wait load(1)
+  amdgpu.memory_counter_wait store(2)
+  amdgpu.memory_counter_wait ds(3)
+  amdgpu.memory_counter_wait exp(4)
+  func.return
+}