@@ -17,6 +17,8 @@
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    _to_mxfp8_per_group_rowwise,
+    _to_mxfp8_per_group_colwise,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
@@ -300,6 +302,7 @@ def forward(
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
@@ -317,8 +320,52 @@ def forward(
         return out
 
     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        raise NotImplementedError
+    def backward(ctx, grad_out: torch.Tensor):
+        A, B_t, offs = ctx.saved_tensors
+        block_size = ctx.block_size
+        out_dtype = ctx.out_dtype
+
+        # Compute grad_A.
+        # grad_A = grad_output @ B
+        # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        grad_out_scale, grad_out_mx = to_mx(
+            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
+            B_t.transpose(-2, -1).contiguous(),
+            block_size=block_size,
+            elem_dtype=torch.float8_e4m3fn,
+        )
+
+        grad_A = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_mx,
+            grad_out_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+
+        # Compute grad_B = grad_output_t @ A.
+        grad_out_t_scale, grad_out_t_mx = _to_mxfp8_per_group_rowwise(
+            grad_out,
+            offs=offs,
+            block_size=block_size,
+        )
+        A_scale, A_mx = _to_mxfp8_per_group_colwise(
+            A,
+            offs=offs,
+            block_size=block_size,
+        )
+        grad_B = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_t_mx,
+            grad_out_t_scale,
+            A_mx,
+            A_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+        return grad_A, grad_B, None, None, None
 
 
 def _to_mxfp8_3d_expert_weights_dim1(
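Reviewer note: the new `backward` is the grouped-matmul chain rule applied per expert group, with each operand quantized to mxfp8 before the two GEMMs. As a plain-bf16 reference for the shape algebra (an illustrative loop, not the PR's quantized path; all names and sizes below are made up):

```python
import torch

# Per-group matmul gradients: forward is out[s:e] = A[s:e] @ B_t[g].
M, K, N, G = 128, 64, 32, 4
A = torch.randn(M, K)                   # input tokens
B_t = torch.randn(G, K, N)              # per-expert weights
grad_out = torch.randn(M, N)            # upstream gradient
offs = torch.tensor([32, 64, 96, 128])  # end row of each expert's token group

grad_A = torch.empty(M, K)
grad_B_t = torch.empty(G, K, N)
start = 0
for g, end in enumerate(offs.tolist()):
    # grad_A = grad_out @ B: (rows, N) @ (N, K) -> (rows, K)
    grad_A[start:end] = grad_out[start:end] @ B_t[g].transpose(-2, -1)
    # grad for B_t[g] = A_t @ grad_out: (K, rows) @ (rows, N) -> (K, N)
    grad_B_t[g] = A[start:end].t() @ grad_out[start:end]
    start = end
```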
@@ -352,6 +399,26 @@ def emulated_mxfp8_scaled_grouped_mm(
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
+) -> torch.Tensor:
+    if A_mx.ndim == 2 and B_t_mx.ndim == 3:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    else:
+        raise NotImplementedError
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
 ) -> torch.Tensor:
     # Dequantize input
     # A_mx shape: (M, K)
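For orientation, the 2d × 3d path is the one exercised by the forward pass and by `grad_A` above; a call through the new dispatcher might look like this (a sketch reusing the quantization calls from the backward hunk; tensor names are illustrative):

```python
# to_mx returns (e8m0 scales, fp8 data); a 2d activation plus a 3d
# expert-weight tensor routes to _emulated_mxfp8_scaled_grouped_mm_2d_3d.
grad_out_scale, grad_out_mx = to_mx(
    grad_out, elem_dtype=torch.float8_e4m3fn, block_size=32
)
B_scale, B_mx = _to_mxfp8_3d_expert_weights_dim1(
    B_t.transpose(-2, -1).contiguous(), block_size=32, elem_dtype=torch.float8_e4m3fn
)
grad_A = emulated_mxfp8_scaled_grouped_mm(
    grad_out_mx, grad_out_scale, B_mx, B_scale, offs=offs, out_dtype=torch.bfloat16
)
```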
@@ -397,3 +464,49 @@ def emulated_mxfp8_scaled_grouped_mm(
     # Perform bf16 grouped GEMM.
     out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    A = torch.empty(A_mx.shape, dtype=torch.bfloat16, device=A_mx.device, requires_grad=A_mx.requires_grad)
+    B_t = torch.empty(B_t_mx.shape, dtype=torch.bfloat16, device=B_t_mx.device, requires_grad=B_t_mx.requires_grad)
+
+    # Dequantize input per each scaling group.
+    scales_start_idx = 0
+    group_start_idx = 0
+    for group_end_idx in offs.tolist():
+        # -- Dequantize A tensor.
+        # A_group shape: (M, group_size)
+        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group_shape = A_group.shape
+
+        # Get scales for this group.
+        # scales shape: (M, group_size//block_size)
+        group_size = group_end_idx - group_start_idx
+        num_scale_cols = group_size // block_size
+        scales = A_scale[:, scales_start_idx : scales_start_idx + num_scale_cols]
+
+        # Reshape to be able to do per-scaling-group multiplication.
+        # A_group shape: (M, group_size//block_size, block_size)
+        # scales shape: (M, group_size//block_size, 1)
+        A_group = A_group.reshape(*A_group.shape[:-1], A_group.shape[-1] // block_size, block_size)
+        scales = scales.unsqueeze(-1)
+
+        # Rescale and cast to bfloat16.
+        A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape back to original shape and write the dequantized group
+        # into the output tensor.
+        # A_group shape: (M, group_size)
+        A_group = A_group.reshape(A_group_shape)
+        A[:, group_start_idx:group_end_idx] = A_group
+
+        # -- Dequantize B_t tensor
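The per-group rescale in `_emulated_mxfp8_scaled_grouped_mm_2d_2d` is plain blockwise broadcasting: reshape the last dim into (num_blocks, block_size), multiply by one scale per block, reshape back. A tiny self-contained illustration (made-up values, independent of the PR):

```python
import torch

M, group_size, block_size = 2, 8, 4
data = torch.randn(M, group_size).to(torch.float8_e4m3fn)  # quantized group
scales = torch.tensor([[1.0, 2.0], [0.5, 4.0]])            # (M, group_size // block_size)

# One scale per contiguous block of block_size elements.
blocks = data.reshape(M, group_size // block_size, block_size)
dequant = blocks.to(torch.bfloat16) * scales.unsqueeze(-1).to(torch.bfloat16)
dequant = dequant.reshape(M, group_size)                   # back to (M, group_size)
```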