Skip to content

Commit 8070e9b

Browse files
authored
[MLIR][XeGPU] Refactor xegpu-wg-to-sg tests (#149204)
This PR refactors the xegpu-wg-to-sg.mlir tests to use larger shapes that more closely resemble workgroup-level programming.
1 parent 56c93a4 commit 8070e9b

File tree

2 files changed

+178
-203
lines changed

2 files changed

+178
-203
lines changed

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 67 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,122 +2,117 @@
22

33
gpu.module @test_round_robin_assignment {
44
// CHECK-LABEL: create_nd_tdesc
5-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
6-
gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
7-
// CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
8-
// CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
5+
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
6+
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
7+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
8+
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
99
// CHECK-NOT: xegpu.create_nd_tdesc
10-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
11-
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
10+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
11+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
1212
gpu.return
1313
}
1414

1515
// CHECK-LABEL: load_nd_tdesc
16-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
17-
gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
18-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
19-
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
20-
// CHECK-COUNT-12: xegpu.load_nd %{{.*}}
21-
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
22-
// CHECK-SAME-COUNT-12: -> vector<2x2xf32>
16+
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
17+
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
18+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
19+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
20+
// CHECK-COUNT-4: xegpu.load_nd %{{.*}}
21+
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
22+
// CHECK-SAME-COUNT-4: -> vector<16x16xf32>
2323
// CHECK-NOT: xegpu.load_nd
2424
%load = xegpu.load_nd %tdesc
25-
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
26-
-> vector<24x32xf32>
25+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
26+
-> vector<256x128xf32>
2727
gpu.return
2828
}
2929

3030
// CHECK-LABEL: store_nd
31-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
32-
gpu.func @store_nd(%src: memref<24x32xf32>) {
33-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
34-
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
35-
// CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
36-
// CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
31+
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
32+
gpu.func @store_nd(%src: memref<256x128xf32>) {
33+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
34+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
35+
// CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
36+
// CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
3737
// CHECK-NOT: xegpu.store_nd
3838
%load = xegpu.load_nd %tdesc
39-
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
40-
-> vector<24x32xf32>
39+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
40+
-> vector<256x128xf32>
4141
xegpu.store_nd %load, %tdesc
42-
: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
42+
: vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
4343
gpu.return
4444
}
4545

4646
// CHECK-LABEL: update_nd
47-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
48-
gpu.func @update_nd(%src: memref<24x32xf32>){
49-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
50-
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
51-
// CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
52-
// CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
47+
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
48+
gpu.func @update_nd(%src: memref<256x128xf32>){
49+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
50+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
51+
// CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
52+
// CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
5353
// CHECK-NOT: xegpu.update_nd_offset
5454
%update = xegpu.update_nd_offset %tdesc, [0, 16]
55-
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
55+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
5656
gpu.return
5757
}
5858

5959
// CHECK-LABEL: dpas
60-
// CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
61-
gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
62-
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
63-
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
60+
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
61+
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
62+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
63+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
6464
// CHECK-NOT: xegpu.create_nd_tdesc
65-
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
66-
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
67-
// CHECK-NOT: xegpu.create_nd_tdesc
68-
// CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
69-
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
65+
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
66+
// CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
7067
// CHECK-NOT: xegpu.create_nd_tdesc
7168
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
72-
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
73-
// CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
69+
// CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
70+
// CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
7471
// CHECK-NOT: xegpu.dpas
75-
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
76-
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
72+
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
73+
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
7774
%load_a = xegpu.load_nd %tdesc_a
78-
: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
79-
-> vector<8x8xf32>
80-
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
81-
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
75+
: !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
76+
-> vector<256x128xf16>
77+
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
78+
-> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
8279
%load_b = xegpu.load_nd %tdesc_b
83-
: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
84-
-> vector<8x8xf32>
85-
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
86-
-> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
80+
: !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
81+
-> vector<128x256xf16>
8782
%dpas = xegpu.dpas %load_a, %load_b
88-
{layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
89-
: vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
83+
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
84+
: vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
9085
gpu.return
9186
}
9287

9388
// CHECK-LABEL: prefetch_nd_tdesc
94-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
95-
gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
96-
// CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
97-
// CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
89+
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
90+
gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
91+
// CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
92+
// CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
9893
// CHECK-NOT: xegpu.prefetch_nd
99-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
100-
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
94+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
95+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
10196
xegpu.prefetch_nd %tdesc
102-
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
97+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
10398
gpu.return
10499
}
105100

106101
// CHECK-LABEL: broadcast
107-
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
108-
gpu.func @broadcast(%src: memref<24x1xf32>) {
109-
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
110-
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
102+
// CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
103+
gpu.func @broadcast(%src: memref<128x1xf32>) {
104+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
105+
-> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
111106
%load = xegpu.load_nd %tdesc
112-
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
113-
-> vector<24x1xf32>
114-
// CHECK-COUNT-3: vector.broadcast {{.*}}
115-
// CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
116-
// CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
107+
: !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
108+
-> vector<128x1xf32>
109+
// CHECK-COUNT-2: vector.broadcast {{.*}}
110+
// CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
111+
// CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
117112
// CHECK-NOT: vector.broadcast
118113
%broadcast = vector.broadcast %load
119-
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
120-
: vector<24x1xf32> to vector<24x8xf32>
114+
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
115+
: vector<128x1xf32> to vector<128x64xf32>
121116
gpu.return
122117
}
123118

0 commit comments

Comments
 (0)