Skip to content

Commit 9052a85

Browse files
authored
[mlir][AMDGPU] Infer canonical layouts for fat_raw_buffer_cast resetOffset (#149867)
When inferring the return type of amdgpu.fat_raw_buffer_cast with the offset reset, we would sometimes use a strided layout, like strided<[1]>, in cases where, after stripping the offset, the memref had the identity layout. This would cause issues with EmulateNarrowTypes, which does perform this layout canonicalization. Now, the return type inference will put in an identity layout after offset stripping for 1. Statically-shaped memrefs of any rank where the strides match the suffix product of the shape, and 2. Memrefs of rank <= 1 whose strides are [1] (or []) that just had their offset removed by resetOffset.
1 parent 860ff87 commit 9052a85

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1717
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1818
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/Diagnostics.h"
@@ -89,7 +90,22 @@ static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
8990
auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
9091
if (!stridedLayout)
9192
return failure();
92-
mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
93+
MemRefLayoutAttrInterface newLayout =
94+
StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
95+
// Special case: if resetting the offset causes the strided layout to become
96+
// the identity layout, then reset to the identity layout.
97+
// TODO: this'll get a lot simpler when we have the contiguous layout.
98+
SmallVector<int64_t> stridesIfIdentity;
99+
if (source.hasStaticShape()) {
100+
stridesIfIdentity = computeSuffixProduct(source.getShape());
101+
} else if (source.getRank() <= 1) {
102+
stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
103+
}
104+
if (stridesIfIdentity == stridedLayout.getStrides()) {
105+
newLayout = AffineMapAttr::get(
106+
AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
107+
}
108+
mb.setLayout(newLayout);
93109
}
94110
return (MemRefType)(mb);
95111
}

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
6868
}
6969

7070
// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
71-
func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
71+
func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
7272
// CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
7373
// CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
7474
// CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
@@ -77,8 +77,8 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], off
7777
// CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
7878
// CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
7979
// CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
80-
%ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
81-
return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
80+
%ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
81+
return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
8282
}
8383

8484
// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,54 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
360360
// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
361361
// CHECK-SAME: boundsCheck(false)
362362
// CHECK-SAME: resetOffset
363-
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
363+
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
364364
%ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
365-
: memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
366-
func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
365+
: memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
366+
func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
367+
}
368+
369+
// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_1d_reset_offset
370+
// CHECK: amdgpu.fat_raw_buffer_cast
371+
func.func @fat_raw_buffer_cast_dynamic_1d_reset_offset(%m: memref<?xi32, strided<[1], offset: ?>>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
372+
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
373+
: memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
374+
func.return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
375+
}
376+
377+
// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_0d_reset_offset
378+
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
379+
// CHECK: return %[[ret]]
380+
func.func @fat_raw_buffer_cast_dynamic_0d_reset_offset(%m: memref<i32, strided<[], offset: ?>>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
381+
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
382+
: memref<i32, strided<[], offset: ?>> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
383+
func.return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
384+
}
385+
386+
// CHECK-LABEL: func @fat_raw_buffer_cast_static_shape_2d_reset_offset
387+
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
388+
// CHECK: return %[[ret]]
389+
func.func @fat_raw_buffer_cast_static_shape_2d_reset_offset(%m: memref<4x4xi32, strided<[4, 1], offset: ?>>) -> memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>> {
390+
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
391+
: memref<4x4xi32, strided<[4, 1], offset: ?>> to memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
392+
func.return %ret : memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
393+
}
394+
395+
// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_2d_reset_offset
396+
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
397+
// CHECK: return %[[ret]]
398+
func.func @fat_raw_buffer_cast_dynamic_2d_reset_offset(%m: memref<?x?xi32, strided<[?, 1], offset: ?>>) -> memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
399+
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
400+
: memref<?x?xi32, strided<[?, 1], offset: ?>> to memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
401+
func.return %ret : memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
402+
}
403+
404+
// CHECK-LABEL: func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset
405+
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
406+
// CHECK: return %[[ret]]
407+
func.func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset(%m: memref<4x4xi32, strided<[8, 1], offset: ?>>) -> memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
408+
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
409+
: memref<4x4xi32, strided<[8, 1], offset: ?>> to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
410+
func.return %ret : memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
367411
}
368412

369413
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1

0 commit comments

Comments
 (0)