|
11 | 11 | import iris
|
12 | 12 |
|
13 | 13 |
|
14 |
| -@triton.jit |
15 |
| -def tile_id_to_index_range( |
16 |
| - tile_id, |
17 |
| - M, |
18 |
| - N, |
19 |
| - BLOCK_SIZE_M: tl.constexpr, |
20 |
| - BLOCK_SIZE_N: tl.constexpr, |
21 |
| - GROUP_SIZE_M: tl.constexpr, |
22 |
| -): |
23 |
| - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
24 |
| - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
25 |
| - num_pid_in_group = GROUP_SIZE_M * num_pid_n |
26 |
| - |
27 |
| - group_id = tile_id // num_pid_in_group |
28 |
| - first_pid_m = group_id * GROUP_SIZE_M |
29 |
| - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
30 |
| - |
31 |
| - tile_in_group = tile_id % num_pid_in_group |
32 |
| - pid_m = first_pid_m + (tile_in_group % group_size_m) |
33 |
| - pid_n = tile_in_group // group_size_m |
34 |
| - |
35 |
| - rm_start = pid_m * BLOCK_SIZE_M |
36 |
| - rn_start = pid_n * BLOCK_SIZE_N |
37 |
| - |
38 |
| - # clamp to the maximum valid index (M-1, N-1) |
39 |
| - max_m = M - 1 |
40 |
| - max_n = N - 1 |
41 |
| - |
42 |
| - # generate indices |
43 |
| - rm = rm_start + tl.arange(0, BLOCK_SIZE_M) |
44 |
| - rn = rn_start + tl.arange(0, BLOCK_SIZE_N) |
45 |
| - |
46 |
| - rm = tl.minimum(rm, max_m) |
47 |
| - rn = tl.minimum(rn, max_n) |
48 |
| - |
49 |
| - # rm_mod = rm % M |
50 |
| - # rm = tl.max_contiguous(tl.multiple_of(rm_mod, BLOCK_SIZE_M), BLOCK_SIZE_M) |
51 |
| - |
52 |
| - return rm, rn, rm_start, rn_start |
53 |
| - |
54 |
| - |
55 |
| -@triton.jit |
56 |
| -def offset_for_tile(local_tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M_local, N_local): |
57 |
| - rm, rn, rm_start, rn_start = tile_id_to_index_range( |
58 |
| - local_tile_id, M_local, N_local, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M |
59 |
| - ) |
60 |
| - c_mask = (rm[:, None] < M_local) & (rn[None, :] < N_local) |
61 |
| - return rm, rn, c_mask, rm_start, rn_start |
62 |
| - |
63 |
| - |
64 |
| -@triton.jit |
65 |
| -def extract_submask_and_offset( |
66 |
| - rm, |
67 |
| - rn, |
68 |
| - mask, |
69 |
| - rm_start, |
70 |
| - rn_start, |
71 |
| - start_row, |
72 |
| - start_col, |
73 |
| - SUB_BLOCK_SIZE_M: tl.constexpr, |
74 |
| - SUB_BLOCK_SIZE_N: tl.constexpr, |
75 |
| - BLOCK_SIZE_M: tl.constexpr, |
76 |
| - BLOCK_SIZE_N: tl.constexpr, |
77 |
| - stride_cm_local: tl.constexpr, |
78 |
| - stride_cn_local: tl.constexpr, |
79 |
| -): |
80 |
| - # Create indices for the sub-block |
81 |
| - sub_rm = tl.arange(0, SUB_BLOCK_SIZE_M) + start_row |
82 |
| - sub_rn = tl.arange(0, SUB_BLOCK_SIZE_N) + start_col |
83 |
| - |
84 |
| - # Create a 2D grid of indices for the sub-block |
85 |
| - sub_rm_2d = sub_rm[:, None] # Shape: (SUB_BLOCK_SIZE_M, 1) |
86 |
| - sub_rn_2d = sub_rn[None, :] # Shape: (1, SUB_BLOCK_SIZE_N) |
87 |
| - |
88 |
| - # Compute the sub-mask |
89 |
| - sub_mask = (sub_rm_2d < BLOCK_SIZE_M) & (sub_rn_2d < BLOCK_SIZE_N) |
90 |
| - |
91 |
| - # Compute the sub-offset relative to the start of the tile |
92 |
| - sub_offset = ((rm_start + sub_rm_2d) * stride_cm_local) + ((rn_start + sub_rn_2d) * stride_cn_local) |
93 |
| - |
94 |
| - return sub_mask, sub_offset |
95 |
| - |
96 |
| - |
97 | 14 | @triton.jit()
|
98 | 15 | def persistent_gemm_all_scatter(
|
99 | 16 | A,
|
@@ -166,8 +83,8 @@ def persistent_gemm_all_scatter(
|
166 | 83 | A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
|
167 | 84 | B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
|
168 | 85 |
|
169 |
| - tl.assume(pid_m > 0) |
170 |
| - tl.assume(pid_n > 0) |
| 86 | + tl.assume(pid_m >= 0) |
| 87 | + tl.assume(pid_n >= 0) |
171 | 88 |
|
172 | 89 | loop_k = tl.cdiv(K, BLOCK_SIZE_K)
|
173 | 90 | if not EVEN_K:
|
@@ -195,51 +112,39 @@ def persistent_gemm_all_scatter(
|
195 | 112 | # Accumulator registers with C results
|
196 | 113 | c = acc.to(C.type.element_ty)
|
197 | 114 |
|
198 |
| - rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N) |
199 |
| - |
200 |
| - # Calculate the number of sub-tiles in each dimension |
201 |
| - num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M) |
202 |
| - num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N) |
203 |
| - total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n |
204 |
| - |
205 |
| - for sub_tile_idx in range(0, total_sub_tiles): |
206 |
| - # Calculate start_row and start_col for the current sub-tile |
207 |
| - start_row = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M |
208 |
| - start_col = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N |
209 |
| - |
210 |
| - # Translate to global |
211 |
| - sub_mask, global_offset = extract_submask_and_offset( |
212 |
| - rm, |
213 |
| - rn + cur_rank * N, |
214 |
| - mask, |
215 |
| - rm_start, |
216 |
| - rn_start + cur_rank * N, |
217 |
| - start_row, |
218 |
| - start_col, |
219 |
| - BLOCK_SIZE_M, |
220 |
| - BLOCK_SIZE_N, |
221 |
| - BLOCK_SIZE_M, |
222 |
| - BLOCK_SIZE_N, |
223 |
| - stride_cm_global, |
224 |
| - stride_cn_global, |
225 |
| - ) |
226 |
| - |
227 |
| - # Timestamp for GEMM before store |
228 |
| - if COLLECT_TIMESTAMPS: |
229 |
| - timestamp = read_realtime() |
230 |
| - tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) |
231 |
| - |
232 |
| - # Store data to the global result using puts |
233 |
| - for remote_rank in range(world_size): |
234 |
| - if remote_rank == cur_rank: |
235 |
| - # For the current rank, we can use store |
236 |
| - tl.store(c_global + global_offset, c, mask=sub_mask) |
237 |
| - else: |
238 |
| - iris.store( |
239 |
| - c_global + global_offset, |
240 |
| - c, |
241 |
| - cur_rank, |
242 |
| - remote_rank, |
243 |
| - heap_bases, |
244 |
| - mask=sub_mask, |
245 |
| - ) |
| 115 | + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M |
| 116 | + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N |
| 117 | + |
| 118 | + # Add compiler hints |
| 119 | + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) |
| 120 | + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) |
| 121 | + |
| 122 | + # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) |
| 123 | + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) |
| 124 | + |
| 125 | + # Calculate the "global" offset of C based on the rank. |
| 126 | + # Note how the N-dimension is being multiplied by current rank. |
| 127 | + # This is because each rank is computing a portion of the N-dimension |
| 128 | + # locally and then scattering it to all other ranks to complete |
| 129 | + # the global N-dimension. |
| 130 | + global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global |
| 131 | + |
| 132 | + # Timestamp for GEMM before store |
| 133 | + if COLLECT_TIMESTAMPS: |
| 134 | + timestamp = read_realtime() |
| 135 | + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) |
| 136 | + |
| 137 | + # Store data to the global result using puts |
| 138 | + for remote_rank in range(world_size): |
| 139 | + if remote_rank == cur_rank: |
| 140 | + # For the current rank, we can use store |
| 141 | + tl.store(c_global + global_offset, c, mask=sub_mask) |
| 142 | + else: |
| 143 | + iris.store( |
| 144 | + c_global + global_offset, |
| 145 | + c, |
| 146 | + cur_rank, |
| 147 | + remote_rank, |
| 148 | + heap_bases, |
| 149 | + mask=sub_mask, |
| 150 | + ) |
0 commit comments