
Commit 7837095

mxfp8 grouped mm backward pass
1 parent fd92301 commit 7837095

2 files changed: +202 additions, -3 deletions

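For orientation: the forward pass computes a token-grouped GEMM between a 2D activation stack A of shape (M, K) and 3D expert weights B_t of shape (E, K, N), with per-group end offsets offs along the M dimension. The backward pass added here applies the usual grouped-GEMM gradient rules, grad_A = grad_out @ B per group and grad_B[e] = grad_out[group_e]^T @ A[group_e], but with both operands quantized to mxfp8 first. A minimal bf16 reference sketch of that math (the function name and per-group loop below are illustrative, not the mxfp8 path in this commit):

import torch

def grouped_mm_backward_reference(grad_out, A, B_t, offs):
    # grad_out: (M, N), A: (M, K), B_t: (E, K, N); offs holds exclusive end offsets per token group.
    grad_A = torch.empty_like(A)
    grad_B_t = torch.empty_like(B_t)
    start = 0
    for expert_idx, end in enumerate(offs.tolist()):
        g = grad_out[start:end]  # (m_e, N)
        a = A[start:end]         # (m_e, K)
        grad_A[start:end] = g @ B_t[expert_idx].transpose(-2, -1)  # (m_e, N) @ (N, K) -> (m_e, K)
        grad_B_t[expert_idx] = a.transpose(-2, -1) @ g             # (K, m_e) @ (m_e, N) -> (K, N)
        start = end
    return grad_A, grad_B_t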

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 115 additions & 2 deletions
@@ -17,6 +17,8 @@
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    _to_mxfp8_per_group_rowwise,
+    _to_mxfp8_per_group_colwise,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx

@@ -298,6 +300,7 @@ def forward(

         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
         ctx.out_dtype = out_dtype

         # Perform scaled grouped GEMM and return result.
@@ -315,8 +318,52 @@ def forward(
         return out

     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        raise NotImplementedError
+    def backward(ctx, grad_out: torch.Tensor):
+        A, B_t, offs = ctx.saved_tensors
+        block_size = ctx.block_size
+        out_dtype = ctx.out_dtype
+
+        # Compute grad_A.
+        # grad_A = grad_output @ B
+        # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        grad_out_scale, grad_out_mx = to_mx(
+            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
+            B_t.transpose(-2, -1).contiguous(),
+            block_size=block_size,
+            elem_dtype=torch.float8_e4m3fn,
+        )
+
+        grad_A = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_mx,
+            grad_out_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+
+        # Compute grad_B = grad_output_t @ A
+        grad_out_t_scale, grad_out_t_mx = _to_mxfp8_per_group_rowwise(
+            grad_out,
+            offs=offs,
+            block_size=block_size,
+        )
+        A_scale, A_mx = _to_mxfp8_per_group_colwise(
+            A,
+            offs=offs,
+            block_size=block_size,
+        )
+        grad_B = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_t_mx,
+            grad_out_t_scale,
+            A_mx,
+            A_scale,
+            offs=offs,
+        )
+        return grad_A, grad_B, None, None, None


 def _to_mxfp8_3d_expert_weights_dim1(
@@ -350,6 +397,26 @@ def emulated_mxfp8_scaled_grouped_mm(
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
+) -> torch.Tensor:
+    if A_mx.ndim == 2 and B_t_mx.ndim == 3:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    else:
+        raise NotImplementedError
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
 ) -> torch.Tensor:
     # Dequantize input
     # A_mx shape: (M, K)
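Both emulated variants follow the same recipe: dequantize each mxfp8 operand back to bf16 using its block scales, then run a plain bf16 (grouped) matmul. A compact sketch of rowwise block dequantization for a contiguous tensor (the helper name is mine, and it assumes the block scales have already been converted to a floating-point dtype):

import torch

def dequant_rowwise_blocks(data_fp8: torch.Tensor, scale: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # data_fp8: (M, K) in float8_e4m3fn; scale: (M, K // block_size) float block scales.
    M, K = data_fp8.shape
    blocks = data_fp8.to(torch.bfloat16).reshape(M, K // block_size, block_size)
    scales = scale.to(torch.bfloat16).reshape(M, K // block_size, 1)
    return (blocks * scales).reshape(M, K)

The per-group loop added in the next hunk applies the same reshape-multiply-reshape pattern, but slices out one token group at a time using offs.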
@@ -395,3 +462,49 @@ def emulated_mxfp8_scaled_grouped_mm(
     # Perform bf16 grouped GEMM.
     out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    A = torch.empty(A_mx.shape, dtype=torch.bfloat16, device=A_mx.device, requires_grad=A_mx.requires_grad)
+    B_t = torch.empty(B_t_mx.shape, dtype=torch.bfloat16, device=B_t_mx.device, requires_grad=B_t_mx.requires_grad)
+
+    # Dequantize input per each scaling group
+    scales_start_idx = 0
+    group_start_idx = 0
+    for group_end_idx in offs.tolist():
+        # -- Dequantize A tensor
+        # A_group shape: (M, group_size)
+        # A_scale shape: (M, group_size//block_size)
+        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group_shape = A_group.shape
+
+        # Get scales for this group.
+        # scales shape: (M, group_size//block_size)
+        group_size = group_end_idx - group_start_idx
+        num_scale_cols = group_size // block_size
+        scales = A_scale[:, scales_start_idx : scales_start_idx + num_scale_cols]
+
+        # Reshape to be able to do per-scaling-group multiplication
+        # A_group shape: (M, group_size//block_size, block_size)
+        # scales shape: (M, group_size//block_size, 1)
+        A_group = A_group.reshape(*A_group.shape[:-1], A_group.shape[-1] // block_size, block_size)
+        scales = scales.unsqueeze(-1)
+
+        # Rescale and cast to bfloat16
+        A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape back to original shape
+        # A_group shape: (M, group_size)
+        A_group = A_group.reshape(A_group_shape)
+        A[:, group_start_idx:group_end_idx] = A_group
+
+        # -- Dequantize B_t tensor

torchao/prototype/moe_training/utils.py

Lines changed: 87 additions & 1 deletion
@@ -5,8 +5,9 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.mx_formats.mx_tensor import to_mx

-
+# --- float8 rowwise scaling ---
 def _to_2d_jagged_float8_tensor_colwise(
     A_col_major: torch.Tensor,
     offs: torch.Tensor,
@@ -142,6 +143,91 @@ def _to_2d_jagged_float8_tensor_rowwise(

     return x_fp8, x_scales

+# --- mxfp8 scaling ---
+def _to_mxfp8_per_group_rowwise(
+    x: torch.Tensor,
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 0 with per-token-group scaling,
+    where groups are determined based on the offsets.
+
+    Args:
+        x (torch.Tensor): The input tensor to be converted to a jagged mxfp8 tensor.
+
+    Returns:
+        A tuple containing the jagged mxfp8 tensor and the scales used for the conversion.
+    """
+    assert x.ndim == 2, "input tensor must be 2D"
+
+    x_mx = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    x_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of x for this group: all rows, with the next group of columns.
+        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)
+
+        # Perform mxfp8 conversion on logically distinct subtensor.
+        scales, mx_subtensor = to_mx(subtensor, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        x_mx[:, start_idx:end_idx] = mx_subtensor
+        if x_scales is None:
+            x_scales = scales
+        else:
+            x_scales = torch.cat((x_scales, scales))
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return x_mx, x_scales
+
+def _to_mxfp8_per_group_colwise(
+    A_col_major: torch.Tensor,  # (K, N)
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 1 with per-token-group scaling,
+    where groups are determined based on the offsets.
+
+    Args:
+        A_col_major (torch.Tensor): The input tensor to be converted to an mxfp8 tensor.
+
+    Returns:
+        A tuple containing the mxfp8 tensor and the scales used for the conversion.
+    """
+    assert A_col_major.ndim == 2, "A must be 2D"
+
+    A_fp8_col_major = torch.empty_like(A_col_major, dtype=torch.float8_e4m3fn)
+    A_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of A for this group: the next group of rows, with all columns for each.
+        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, N)
+
+        # Convert to mxfp8 along dim1, by transposing, converting, and transposing back.
+        scales, mx_subtensor = to_mx(subtensor.transpose(-2, -1), elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        scales, mx_subtensor = scales.transpose(-2, -1), mx_subtensor.transpose(-2, -1)
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        A_fp8_col_major[start_idx:end_idx, :] = mx_subtensor
+        if A_scales is None:
+            A_scales = scales
+        else:
+            A_scales = torch.cat((A_scales, scales))
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return A_fp8_col_major, A_scales

 def _is_column_major(x: torch.Tensor) -> bool:
     """
