-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][amdgpu] Add rocdl.s.waitcnt
wrapper
#149670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering. Only gfx9 bitpacking support is added as part of this commit.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu Author: Ivan Butygin (Hardcode84) Changes: The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly (similar to the asm format) and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering. Only gfx9 support is added as part of this commit. Full diff: https://github.com/llvm/llvm-project/pull/149670.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 80959ffbaf426..cecb936e18ae3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -717,6 +717,26 @@ def AMDGPU_SchedBarrierOp :
}];
}
+def AMDGPU_WaitcntOp :
+ AMDGPU_Op<"waitcnt">,
+ Arguments<(ins
+ OptionalAttr<I32Attr>:$vmcnt,
+ OptionalAttr<I32Attr>:$expcnt,
+ OptionalAttr<I32Attr>:$lgkmcnt
+ )>
+ {
+ let summary = "Wrapper on ROCDL SWaitcntOp";
+ let description = [{
+ Convenience wrapper on `rocdl.s.waitcnt`. Hides the architecture-specific
+ bitpacking from the user. Missing values are assumed to be the maximum values
+ supported by the architecture. Large values are also clamped to the maximum
+ supported values.
+ }];
+ let assemblyFormat = [{
+ (`vmcnt` `(` $vmcnt^ `)` )? (`expcnt` `(` $expcnt^ `)` )? (`lgkmcnt` `(` $lgkmcnt^ `)`)? attr-dict
+ }];
+}
+
def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
"The possible permutations of the lanes storing B available in an MFMA",
[
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..af588d5b70a45 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -419,6 +419,52 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
}
};
+// TODO: The AMDGPU backend already has all this bitpacking logic; we should
+// move it to some common place.
+static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
+ unsigned expcnt, unsigned lgkmcnt) {
+ if (chipset.majorVersion == 9) {
+ vmcnt = std::min(63u, vmcnt);
+ expcnt = std::min(7u, expcnt);
+ lgkmcnt = std::min(15u, lgkmcnt);
+ unsigned lowBits = vmcnt & 0xF;
+ unsigned highBits = (vmcnt >> 4) << 14;
+ unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
+ return lowBits | highBits | otherCnts;
+ }
+ return failure();
+}
+
+struct WaitcntOpLowering : public ConvertOpToLLVMPattern<WaitcntOp> {
+ WaitcntOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<WaitcntOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(WaitcntOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto getVal = [](Attribute attr) -> unsigned {
+ if (attr)
+ return cast<IntegerAttr>(attr).getInt();
+
+ // This value will be clamped to the maximum value for the chipset.
+ return 1024 * 1024;
+ };
+ unsigned vmcnt = getVal(adaptor.getVmcntAttr());
+ unsigned expcnt = getVal(adaptor.getExpcntAttr());
+ unsigned lgkmcnt = getVal(adaptor.getLgkmcntAttr());
+
+ FailureOr<unsigned> waitcnt =
+ encodeWaitcnt(chipset, vmcnt, expcnt, lgkmcnt);
+ if (failed(waitcnt))
+ return op.emitOpError("unsupported chipset");
+
+ rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
+ return success();
+ }
+};
+
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1871,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicUminOp>,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
- AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
- MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
- ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
+ AMDGPUDPPLowering, WaitcntOpLowering, LDSBarrierOpLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir
new file mode 100644
index 0000000000000..9c785670198ae
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
+// TODO: Add support for more chipsets
+
+
+// CHECK-LABEL: func @waitcnt
+func.func @waitcnt() {
+ // GFX9: rocdl.s.waitcnt 53119
+ amdgpu.waitcnt
+
+ // GFX9: rocdl.s.waitcnt 3952
+ amdgpu.waitcnt vmcnt(0)
+
+ // GFX9: rocdl.s.waitcnt 53007
+ amdgpu.waitcnt expcnt(0)
+
+ // GFX9: rocdl.s.waitcnt 49279
+ amdgpu.waitcnt lgkmcnt(0)
+
+ return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 5559ac8f1a5c3..b126b23cb8156 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -504,3 +504,16 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
func.return
}
+
+// CHECK-LABEL: func @waitcnt
+func.func @waitcnt() {
+ // CHECK: amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
+ // CHECK: amdgpu.waitcnt vmcnt(1)
+ // CHECK: amdgpu.waitcnt expcnt(2)
+ // CHECK: amdgpu.waitcnt lgkmcnt(3)
+ amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3)
+ amdgpu.waitcnt vmcnt(1)
+ amdgpu.waitcnt expcnt(2)
+ amdgpu.waitcnt lgkmcnt(3)
+ func.return
+}
|
// it to some common place. | ||
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt, | ||
unsigned expcnt, unsigned lgkmcnt) { | ||
if (chipset.majorVersion == 9) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: Thoughts on adding some doc based on
llvm-project/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
Lines 1128 to 1135 in 54492c2
/// \p Vmcnt = \p Waitcnt[3:0] (pre-gfx9) | |
/// \p Vmcnt = \p Waitcnt[15:14,3:0] (gfx9,10) | |
/// \p Vmcnt = \p Waitcnt[15:10] (gfx11) | |
/// \p Expcnt = \p Waitcnt[6:4] (pre-gfx11) | |
/// \p Expcnt = \p Waitcnt[2:0] (gfx11) | |
/// \p Lgkmcnt = \p Waitcnt[11:8] (pre-gfx10) | |
/// \p Lgkmcnt = \p Waitcnt[13:8] (gfx10) | |
/// \p Lgkmcnt = \p Waitcnt[9:4] (gfx11) |
/// \p Vmcnt = \p Waitcnt[15:14,3:0]
/// \p Expcnt = \p Waitcnt[6:4]
/// \p Lgkmcnt = \p Waitcnt[11:8]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done (and also added all other chipsets)
Signed-off-by: Ivan Butygin <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but let's wait for an approval from @krzysz00
The main motivation is to pass vmcnt/expcnt/lgkmcnt values directly (similar to the asm format) and delegate architecture-dependent bitpacking to the amdgpu->rocdl lowering.