 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    _to_mxfp8_per_group_rowwise,
+    _to_mxfp8_per_group_colwise,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -298,6 +300,7 @@ def forward(
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
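For context on the `ctx` pattern in this hunk: tensors needed by backward must go through `save_for_backward`, while plain Python values such as `block_size` are stashed directly as `ctx` attributes. A minimal sketch of that pattern (the `_Demo` function is hypothetical, not part of this PR):

```python
import torch

class _Demo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, block_size: int):
        ctx.save_for_backward(x)     # tensors: save_for_backward
        ctx.block_size = block_size  # non-tensor state: plain ctx attribute
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (x,) = ctx.saved_tensors
        # One gradient per forward input; non-differentiable args get None.
        return grad_out * 2.0, None

x = torch.randn(4, requires_grad=True)
y = _Demo.apply(x, 32)
y.sum().backward()  # x.grad == 2.0 everywhere
```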
@@ -315,8 +318,52 @@ def forward(
         return out
 
     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        raise NotImplementedError
+    def backward(ctx, grad_out: torch.Tensor):
+        A, B_t, offs = ctx.saved_tensors
+        block_size = ctx.block_size
+        out_dtype = ctx.out_dtype
+
+        # Compute grad_A.
+        # grad_A = grad_output @ B
+        # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        grad_out_scale, grad_out_mx = to_mx(
+            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
+            B_t.transpose(-2, -1).contiguous(),
+            block_size=block_size,
+            elem_dtype=torch.float8_e4m3fn,
+        )
+
+        grad_A = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_mx,
+            grad_out_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+
+        # Compute grad_B = grad_output_t @ A
+        grad_out_t_scale, grad_out_t_mx = _to_mxfp8_per_group_rowwise(
+            grad_out,
+            offs=offs,
+            block_size=block_size,
+        )
+        A_scale, A_mx = _to_mxfp8_per_group_colwise(
+            A,
+            offs=offs,
+            block_size=block_size,
+        )
+        grad_B = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_t_mx,
+            grad_out_t_scale,
+            A_mx,
+            A_scale,
+            offs=offs,
+        )
+        return grad_A, grad_B, None, None, None
 
 
 def _to_mxfp8_3d_expert_weights_dim1(
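The gradient formulas the new `backward` implements (grad_A as a 2d-3d grouped mm, grad_B as a per-group 2d-2d grouped mm) can be sanity-checked against autograd in plain float32, with no quantization. The shapes, offsets, and loop-based reference below are illustrative only, not this PR's test code:

```python
import torch

torch.manual_seed(0)
E, M, N, K = 2, 8, 6, 4
offs = torch.tensor([3, 8])  # cumulative end row of each expert's token group

A = torch.randn(M, K, requires_grad=True)
B = torch.randn(E, N, K, requires_grad=True)  # (B, N, K) in the comment's notation

# Forward: per group, out_g = A_g @ B_e.T, i.e. (M_g, K) @ (K, N) -> (M_g, N).
start, outs = 0, []
for e, end in enumerate(offs.tolist()):
    outs.append(A[start:end] @ B[e].t())
    start = end
torch.cat(outs).sum().backward()

# Closed-form grads, matching the backward above:
# grad_A = grad_out @ B per group; grad_B = grad_out_t @ A per group.
grad_out = torch.ones(M, N)
with torch.no_grad():
    grad_A, grad_B, start = torch.empty_like(A), torch.empty_like(B), 0
    for e, end in enumerate(offs.tolist()):
        grad_A[start:end] = grad_out[start:end] @ B[e]
        grad_B[e] = grad_out[start:end].t() @ A[start:end]
        start = end

torch.testing.assert_close(A.grad, grad_A)
torch.testing.assert_close(B.grad, grad_B)
```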
@@ -350,6 +397,26 @@ def emulated_mxfp8_scaled_grouped_mm(
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
+) -> torch.Tensor:
+    if A_mx.ndim == 2 and B_t_mx.ndim == 3:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    else:
+        raise NotImplementedError
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
 ) -> torch.Tensor:
     # Dequantize input
     # A_mx shape: (M, K)
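What the emulation dequantizes: in MX scaling, each run of `block_size` elements along the scaled dim shares one power-of-two scale. A self-contained round-trip sketch of that scheme follows; it is a simplified stand-in for illustration, not torchao's `to_mx`, which also handles e8m0 scale encoding and edge cases:

```python
import torch

def mx_quant_rowwise(x: torch.Tensor, block_size: int = 32):
    # Per-block power-of-two scale along the last dim; returns (scale, data)
    # to match the (scale, tensor) return order used in the diff above.
    xb = x.reshape(*x.shape[:-1], -1, block_size)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(amax / 448.0)))  # 448 = e4m3 max
    data = (xb / scale).to(torch.float8_e4m3fn)
    return scale.squeeze(-1), data.reshape(x.shape)

def mx_dequant_rowwise(scale: torch.Tensor, data: torch.Tensor, block_size: int = 32):
    # Inverse: upcast each block to bf16 and multiply by its shared scale.
    xb = data.reshape(*data.shape[:-1], -1, block_size).to(torch.bfloat16)
    return (xb * scale.unsqueeze(-1).to(torch.bfloat16)).reshape(data.shape)

x = torch.randn(4, 64)
scale, data = mx_quant_rowwise(x)
err = (x - mx_dequant_rowwise(scale, data).float()).abs().max()
print(err)  # small, bounded by the fp8 e4m3 block quantization error
```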
@@ -395,3 +462,49 @@ def emulated_mxfp8_scaled_grouped_mm(
     # Perform bf16 grouped GEMM.
     out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    A = torch.empty(A_mx.shape, dtype=torch.bfloat16, device=A_mx.device, requires_grad=A_mx.requires_grad)
+    B_t = torch.empty(B_t_mx.shape, dtype=torch.bfloat16, device=B_t_mx.device, requires_grad=B_t_mx.requires_grad)
+
+    # Dequantize input per each scaling group
+    scales_start_idx = 0
+    group_start_idx = 0
+    for group_end_idx in offs.tolist():
+        # -- Dequantize A tensor
+        # A_group shape: (M, group_size)
+        # A_scale shape: (M, group_size//block_size)
+        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group_shape = A_group.shape
+
+        # Get scales for this group.
+        # scales shape: (M, group_size//block_size)
+        group_size = group_end_idx - group_start_idx
+        num_scale_cols = group_size // block_size
+        scales = A_scale[:, scales_start_idx : scales_start_idx + num_scale_cols]
+
+        # Reshape to be able to do per-scaling-group multiplication
+        # A_group shape: (M, group_size//block_size, block_size)
+        # scales shape: (M, group_size//block_size, 1)
+        A_group = A_group.reshape(*A_group.shape[:-1], A_group.shape[-1] // block_size, block_size)
+        scales = scales.unsqueeze(-1)
+
+        # Rescale and cast to bfloat16
+        A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape back to original shape
+        # A_group shape: (M, group_size)
+        A_group = A_group.reshape(A_group_shape)
+        A[:, group_start_idx:group_end_idx] = A_group
+
+        # -- Dequantize B_t tensor
+
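The diff is cut off before the `B_t` dequantization and the cursor updates at the end of the loop, but the structure implies that both the column cursor and the scale-column cursor advance once per group. A hedged sketch of that per-group dequant pattern, as a hypothetical standalone helper rather than this PR's code:

```python
import torch

def dequant_per_group_rowwise(
    data_lp: torch.Tensor,  # (M, total_K) fp8 data, groups laid out along dim -1
    scales: torch.Tensor,   # (M, total_K // block_size) per-block scales
    offs: torch.Tensor,     # cumulative end column of each group
    block_size: int = 32,
) -> torch.Tensor:
    out = torch.empty(data_lp.shape, dtype=torch.bfloat16, device=data_lp.device)
    group_start, scale_start = 0, 0
    for group_end in offs.tolist():
        group_size = group_end - group_start
        n_scales = group_size // block_size
        g = data_lp[:, group_start:group_end].to(torch.bfloat16)
        s = scales[:, scale_start : scale_start + n_scales].to(torch.bfloat16)
        # (M, n_scales, block_size) * (M, n_scales, 1), then flatten back.
        g = g.reshape(-1, n_scales, block_size) * s.unsqueeze(-1)
        out[:, group_start:group_end] = g.reshape(-1, group_size)
        # Advance both cursors so the next group reads fresh columns.
        group_start, scale_start = group_end, scale_start + n_scales
    return out
```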