import triton
import triton.language as tl

- from torchao.prototype.moe_training.utils import _is_column_major
-
EPS = 1e-12

FP8_DTYPE_MAP = {
@@ -33,13 +31,20 @@
    torch.float64: tl.float64,
}

- block_sizes = [128, 256]
+ block_sizes = [1, 16, 32, 64]
+ block_sizes_iter = [32, 64, 128, 256]
+ num_warps = [1, 4]
+ num_stages = [2, 3]
kernel_configs_2D = [
    triton.Config(
-         {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
+         {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
+         num_warps=warps,
+         num_stages=stages,
    )
-     for block_size_rows in block_sizes
-     for block_size_cols in block_sizes
+     for block_size in block_sizes
+     for block_size_iter in block_sizes_iter
+     for warps in num_warps
+     for stages in num_stages
]

from torch.library import triton_op, wrap_triton
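For context on the change above: the sweep now crosses tile size, inner-loop step, warp count, and stage count. A standalone sketch mirroring the comprehension (nothing is launched here, it only counts the candidates) shows the autotuner now evaluates 4 * 4 * 2 * 2 = 64 configs per key value, up from the previous 2 * 2 = 4:

import triton

# Mirror of the sweep in the diff above, used only to count configs.
block_sizes = [1, 16, 32, 64]
block_sizes_iter = [32, 64, 128, 256]
num_warps = [1, 4]
num_stages = [2, 3]

kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE": bs, "BLOCK_SIZE_ITER": bsi},
        num_warps=w,
        num_stages=s,
    )
    for bs in block_sizes
    for bsi in block_sizes_iter
    for w in num_warps
    for s in num_stages
]

print(len(kernel_configs_2D))  # 64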
@@ -68,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales(
    - jagged rowwise scales (i.e., rowwise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
-     assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -81,16 +85,14 @@ def triton_fp8_row_major_jagged_rowwise_scales(
    n_groups = offsets.numel()

    # allocate on-device buffers for output and scales
-     output_buffer = torch.empty_like(
-         hp_tensor, dtype=output_dtype, device=hp_tensor.device
-     )
+     output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
    scales_buffer = torch.empty(
        (m * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across rows and groups (offsets)
    grid = lambda meta: (
-         triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
+         triton.cdiv(m, meta["BLOCK_SIZE"]),
        offsets.numel(),
    )
    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
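To make the launch geometry above concrete, here is a minimal sketch with made-up sizes (not from the PR): each program handles one BLOCK_SIZE-row tile of one expert group, so axis 0 has cdiv(m, BLOCK_SIZE) programs and axis 1 has one program per group.

import triton

m, n_groups = 512, 4   # illustrative sizes only
BLOCK_SIZE = 64        # one of the autotuned tile sizes

grid = (triton.cdiv(m, BLOCK_SIZE), n_groups)
print(grid)  # (8, 4) -> 32 programs in total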
@@ -115,7 +117,13 @@ def triton_fp8_row_major_jagged_rowwise_scales(
    return output_buffer, scales_buffer


- @triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+ # This kernel is used on grad_output.t() which has shape (K, M),
+ # before the calculation `grad_B = grad_output_t @ input`.
+ # However, in this code, we use the conventional dim names (M, K)
+ # so the kernel is easily interpretable in a standalone fashion.
+ # The tokens per expert will vary per iteration, so we don't want
+ # to recompile on `token` dim (K, in this case) changes.
+ @triton.autotune(configs=kernel_configs_2D, key=["M"])
@triton.jit
def _triton_fp8_row_major_jagged_rowwise_scales(
    input_ptr,
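The key=["M"] choice above relies on how Triton's autotune cache works: the tuner re-benchmarks (and the kernel recompiles) only when a value listed in key changes. A minimal, self-contained sketch of that behavior using a toy elementwise kernel (not the PR's kernel; names and sizes are illustrative):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": bs}) for bs in (128, 256)],
    key=["M"],  # retune only when M changes; K may vary freely between calls
)
@triton.jit
def _toy_double_kernel(x_ptr, out_ptr, M, K, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < M * K
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)


def toy_double(x: torch.Tensor) -> torch.Tensor:
    # x: (M, K), contiguous, on GPU
    M, K = x.shape
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    _toy_double_kernel[grid](x, out, M, K)
    return out

With this, repeated calls that keep M fixed but change K reuse the already-selected config, which is the property the comment above relies on for the varying tokens-per-expert dimension.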
@@ -134,8 +142,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
-     BLOCK_SIZE_ROWS: tl.constexpr,
-     BLOCK_SIZE_COLS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_SIZE_ITER: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallel across rows and groups (offsets)
@@ -147,12 +155,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_col_end_idx = tl.load(offsets_ptr + offset_idx)
-     block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
+     block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # compute rowwise amaxes for this group
-     amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype)
-     for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-         block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+     amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+     for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+         block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
@@ -180,12 +188,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
    # store rowwise scales for each group in contiguous memory:
    # [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
    scales_offs = block_row_offs + (M * offset_idx)
-     scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
+     scales_mask = tl.arange(0, BLOCK_SIZE) < M
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
-     for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-         block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+     for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+         block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
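As a sanity reference for what the two loops in the kernel compute, here is a rough pure-PyTorch sketch of jagged rowwise scaling (illustrative only, not the PR's code; it assumes the usual amax-based scale fp8_max / amax with an EPS clamp, non-empty groups, and ignores the round_scales_to_power_of_2 option):

import torch

EPS = 1e-12


def jagged_rowwise_scales_ref(x, offsets, float8_dtype=torch.float8_e4m3fn):
    # x: (M, K) high-precision tensor; offsets: end column index of each group.
    fp8_max = torch.finfo(float8_dtype).max
    out = torch.empty_like(x, dtype=float8_dtype)
    scales = []
    start = 0
    for end in offsets.tolist():
        group = x[:, start:end]
        amax = group.abs().amax(dim=1, keepdim=True).float().clamp(min=EPS)
        scale = fp8_max / amax  # one scale per row, per group
        out[:, start:end] = (group.float() * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
        scales.append(scale.squeeze(1))
        start = end
    # scales laid out as [group0_row0, ..., group0_row(M-1), group1_row0, ...],
    # matching the contiguous layout described in the kernel comment above.
    return out, torch.cat(scales)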
@@ -230,7 +238,6 @@ def triton_fp8_col_major_jagged_colwise_scales(
    - jagged column-wise scales (i.e., column-wise scales for each group)
    """
    assert hp_tensor.ndim == 2, "input tensor must be 2D"
-     assert _is_column_major(hp_tensor), "input tensor must be column-major"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -242,17 +249,18 @@ def triton_fp8_col_major_jagged_colwise_scales(
    k, n = hp_tensor.shape
    n_groups = offsets.numel()

-     # allocate on-device buffers for output and scales
+     # Output buffer in column major
    output_buffer = torch.empty_like(
        hp_tensor, dtype=output_dtype, device=hp_tensor.device
-     )
+     ).as_strided(hp_tensor.size(), (1, k))
+
    scales_buffer = torch.empty(
        (n * n_groups), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across columns and groups (offsets)
    grid = lambda meta: (
-         triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
+         triton.cdiv(n, meta["BLOCK_SIZE"]),
        offsets.numel(),
    )
    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
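The as_strided call above is what produces a column-major output buffer regardless of the input's layout, which is also why the _is_column_major assert could be dropped earlier in this file. A standalone sketch with made-up sizes, showing the resulting strides (assumes a PyTorch build with float8 dtypes):

import torch

k, n = 8, 6  # illustrative sizes only
hp_tensor = torch.randn(k, n, dtype=torch.bfloat16)

output_buffer = torch.empty_like(hp_tensor, dtype=torch.float8_e4m3fn).as_strided(
    hp_tensor.size(), (1, k)
)
print(output_buffer.shape)    # torch.Size([8, 6])
print(output_buffer.stride()) # (1, 8) -> columns are contiguous (column-major)

The buffer is just a reinterpretation of the same k * n-element storage, so no extra copy or transpose is needed to get the column-major layout the downstream grouped GEMM presumably expects.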
@@ -277,7 +285,11 @@ def triton_fp8_col_major_jagged_colwise_scales(
    return output_buffer, scales_buffer


- @triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+ # This kernel is used on `input` which has shape (M, K),
+ # before the calculation `grad_B = grad_output_t @ input`.
+ # The tokens per expert will vary per iteration, so we don't want
+ # to recompile on `token` dim (M) changes.
+ @triton.autotune(configs=kernel_configs_2D, key=["K"])
@triton.jit
def _triton_fp8_col_major_jagged_colwise_scales(
    input_ptr,
@@ -296,8 +308,8 @@ def _triton_fp8_col_major_jagged_colwise_scales(
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
-     BLOCK_SIZE_ROWS: tl.constexpr,
-     BLOCK_SIZE_COLS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_SIZE_ITER: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallel across columns and groups (offsets)
@@ -309,12 +321,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
        offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
    )
    group_row_end_idx = tl.load(offsets_ptr + offset_idx)
-     block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
+     block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # compute colwise amaxes for this group
-     amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype)
-     for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-         block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+     amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+     for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+         block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col
@@ -343,12 +355,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
    # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
    # note: input tensor is in col-major memory layout.
    scales_offs = block_col_offs + (N * offset_idx)
-     scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
+     scales_mask = tl.arange(0, BLOCK_SIZE) < N
    tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

    # perform float8 conversion for this group
-     for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-         block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+     for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+         block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
        block_offs = (
            block_row_offs[:, None] * stride_input_row
            + block_col_offs[None, :] * stride_input_col