Commit c8e0897

mxfp8 grouped mm backward pass
stack-info: PR: #2632, branch: danielvegamyhre/stack/24
1 parent 5376f65 commit c8e0897
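The backward pass added here computes the standard grouped-matmul gradients, with each GEMM operand quantized to mxfp8 before the grouped multiply. As a minimal bf16 sketch of the per-group math (hypothetical shapes; plain matmuls stand in for the mxfp8-emulated grouped GEMMs used in the diff):

import torch

# Hypothetical sizes: M tokens routed across E experts, contraction dim K, output dim N.
M, K, N, E = 128, 64, 32, 2
A = torch.randn(M, K, dtype=torch.bfloat16)         # input tokens
B_t = torch.randn(E, K, N, dtype=torch.bfloat16)    # per-expert weights
grad_out = torch.randn(M, N, dtype=torch.bfloat16)  # upstream gradient
offs = [64, 128]  # end offset of each expert's token group along M

start = 0
for e, end in enumerate(offs):
    # grad_A = grad_out @ B: (m_e, N) @ (N, K) -> (m_e, K)
    grad_A_group = grad_out[start:end] @ B_t[e].transpose(-2, -1)
    # grad_B = grad_out_t @ A: (N, m_e) @ (m_e, K) -> (N, K)
    grad_B_group = grad_out[start:end].t() @ A[start:end]
    start = end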

2 files changed (+202, -3 lines)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 115 additions & 2 deletions
@@ -17,6 +17,8 @@
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    _to_mxfp8_per_group_rowwise,
+    _to_mxfp8_per_group_colwise,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx

@@ -300,6 +302,7 @@ def forward(

         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
         ctx.out_dtype = out_dtype

         # Perform scaled grouped GEMM and return result.
@@ -317,8 +320,52 @@ def forward(
         return out

     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        raise NotImplementedError
+    def backward(ctx, grad_out: torch.Tensor):
+        A, B_t, offs = ctx.saved_tensors
+        block_size = ctx.block_size
+        out_dtype = ctx.out_dtype
+
+        # Compute grad_A.
+        # grad_A = grad_out @ B
+        # i.e. scaled grouped mm of (M,N) @ (E,N,K) = (M,K)
+        grad_out_scale, grad_out_mx = to_mx(
+            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
+            B_t.transpose(-2, -1).contiguous(),
+            block_size=block_size,
+            elem_dtype=torch.float8_e4m3fn,
+        )
+
+        grad_A = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_mx,
+            grad_out_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+
+        # Compute grad_B = grad_out_t @ A
+        grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise(
+            grad_out,
+            offs=offs,
+            block_size=block_size,
+        )
+        A_mx, A_scale = _to_mxfp8_per_group_colwise(
+            A,
+            offs=offs,
+            block_size=block_size,
+        )
+        grad_B = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_t_mx,
+            grad_out_t_scale,
+            A_mx,
+            A_scale,
+            offs=offs,
+        )
+        return grad_A, grad_B, None, None, None


 def _to_mxfp8_3d_expert_weights_dim1(
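The five values returned by backward line up one-to-one with forward's inputs — judging from this hunk, (A, B_t, offs, block_size, out_dtype) — with None for the three non-tensor arguments. A hedged sanity-check sketch of what the mxfp8 backward should approximate, assuming a build and GPU where torch._grouped_mm (already used in this file) is differentiable in bf16:

import torch

M, K, N, E = 128, 64, 32, 2
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True)
offs = torch.tensor([64, 128], dtype=torch.int32, device="cuda")

# Reference bf16 grouped GEMM and its autograd gradients.
out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.backward(torch.ones_like(out))
# A.grad and B_t.grad are the targets the emulated mxfp8 backward approximates,
# up to mxfp8 quantization error.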
@@ -352,6 +399,26 @@ def emulated_mxfp8_scaled_grouped_mm(
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
+) -> torch.Tensor:
+    if A_mx.ndim == 2 and B_t_mx.ndim == 3:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
+        return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+            A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
+        )
+    else:
+        raise NotImplementedError
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
 ) -> torch.Tensor:
     # Dequantize input
     # A_mx shape: (M, K)
@@ -397,3 +464,49 @@ def emulated_mxfp8_scaled_grouped_mm(
     # Perform bf16 grouped GEMM.
     out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
     return out
+
+
+def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    A = torch.empty(A_mx.shape, dtype=torch.bfloat16, device=A_mx.device, requires_grad=A_mx.requires_grad)
+    B_t = torch.empty(B_t_mx.shape, dtype=torch.bfloat16, device=B_t_mx.device, requires_grad=B_t_mx.requires_grad)
+
+    # Dequantize input per each scaling group.
+    scales_start_idx = 0
+    group_start_idx = 0
+    for group_end_idx in offs.tolist():
+        # -- Dequantize A tensor
+        # A_group shape: (M, group_size)
+        # A_scale shape: (M, group_size // block_size)
+        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group_shape = A_group.shape
+
+        # Get scales for this group.
+        # scales shape: (M, group_size // block_size)
+        group_size = group_end_idx - group_start_idx
+        num_scale_cols = group_size // block_size
+        scales = A_scale[:, scales_start_idx : scales_start_idx + num_scale_cols]
+
+        # Reshape to be able to do per-scaling-group multiplication.
+        # A_group shape: (M, group_size // block_size, block_size)
+        # scales shape: (M, group_size // block_size, 1)
+        A_group = A_group.reshape(*A_group.shape[:-1], A_group.shape[-1] // block_size, block_size)
+        scales = scales.unsqueeze(-1)
+
+        # Rescale and cast to bfloat16.
+        A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)
+
+        # Reshape back to original shape and write into the dequantized buffer.
+        # A_group shape: (M, group_size)
+        A_group = A_group.reshape(A_group_shape)
+        A[:, group_start_idx:group_end_idx] = A_group
+
+
+        # -- Dequantize B_t tensor

torchao/prototype/moe_training/utils.py

Lines changed: 87 additions & 1 deletion
@@ -5,8 +5,9 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.mx_formats.mx_tensor import to_mx

-
+# --- float8 rowwise scaling ---
 def _to_2d_jagged_float8_tensor_colwise(
     A_col_major: torch.Tensor,
     offs: torch.Tensor,
@@ -142,6 +143,91 @@ def _to_2d_jagged_float8_tensor_rowwise(

     return x_fp8, x_scales

+# --- mxfp8 scaling ---
+def _to_mxfp8_per_group_rowwise(
+    x: torch.Tensor,
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 0, with per-token-group
+    scaling, where groups are determined based on the offsets.
+
+    Args:
+        x (torch.Tensor): The input tensor to be converted to a jagged mxfp8 tensor.
+
+    Returns:
+        A tuple containing the jagged mxfp8 tensor and the scales used for the conversion.
+    """
+    assert x.ndim == 2, "input tensor must be 2D"
+
+    x_mx = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    x_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of x for this group: all rows, with the next group of columns.
+        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)
+
+        # Perform mxfp8 conversion on the logically distinct subtensor.
+        scales, mx_subtensor = to_mx(subtensor, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        x_mx[:, start_idx:end_idx] = mx_subtensor
+        if x_scales is None:
+            x_scales = scales
+        else:
+            x_scales = torch.cat((x_scales, scales), dim=1)
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return x_mx, x_scales
+
+def _to_mxfp8_per_group_colwise(
+    A_col_major: torch.Tensor,  # (K, N)
+    offs: torch.Tensor,
+    block_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This is a reference implementation used for testing correctness; it is not performant.
+
+    This function converts the 2D input tensor to an mxfp8 tensor along dim 1, with per-token-group
+    scaling, where groups are determined based on the offsets.
+
+    Args:
+        A_col_major (torch.Tensor): The input tensor to be converted to an mxfp8 tensor.
+
+    Returns:
+        A tuple containing the mxfp8 tensor and the scales used for the conversion.
+    """
+    assert A_col_major.ndim == 2, "A must be 2D"
+
+    A_fp8_col_major = torch.empty_like(A_col_major, dtype=torch.float8_e4m3fn)
+    A_scales = None
+
+    start_idx = 0
+    for end_idx in offs.tolist():
+        # Get the subtensor of A for this group: the next group of rows, with all columns for each.
+        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, N)
+
+        # Convert to mxfp8 along dim 1, by transposing, converting, and transposing back.
+        scales, mx_subtensor = to_mx(subtensor.transpose(-2, -1), elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        scales, mx_subtensor = scales.transpose(-2, -1), mx_subtensor.transpose(-2, -1)
+
+        # Store this portion of the resulting mxfp8 tensor and scales.
+        A_fp8_col_major[start_idx:end_idx, :] = mx_subtensor
+        if A_scales is None:
+            A_scales = scales
+        else:
+            A_scales = torch.cat((A_scales, scales))
+
+        # Update start index for next group.
+        start_idx = end_idx
+
+    return A_fp8_col_major, A_scales

 def _is_column_major(x: torch.Tensor) -> bool:
     """
