5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 |
7 | 7 | import logging
8 | | -from typing import Optional
| 8 | +from typing import Optional, Tuple
9 | 9 |
10 | 10 | import torch
11 | 11 |
18 | 18 | from torchao.prototype.moe_training.utils import (
19 | 19 |     _is_column_major,
20 | 20 | )
| 21 | +from torchao.prototype.mx_formats.mx_tensor import to_mx
21 | 22 |
22 | 23 | logger: logging.Logger = logging.getLogger(__name__)
23 | 24 |
@@ -268,6 +269,79 @@ def backward(ctx, grad_output: torch.Tensor):
268 | 269 |         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
269 | 270 |
270 | 271 |
| 272 | +class _MXFP8GroupedMM(torch.autograd.Function):
| 273 | +    """Differentiable implementation of grouped GEMM with dynamic mxfp8 quantization."""
| 274 | +
| 275 | +    @staticmethod
| 276 | +    def forward(
| 277 | +        ctx,
| 278 | +        A: torch.Tensor,
| 279 | +        B_t: torch.Tensor,
| 280 | +        offs: Optional[torch.Tensor] = None,
| 281 | +        block_size: int = 32,
| 282 | +        out_dtype: Optional[torch.dtype] = torch.bfloat16,
| 283 | +    ) -> torch.Tensor:
| 284 | +        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
| 285 | +        assert A.ndim == 2, "A must be 2D"
| 286 | +        assert B_t.ndim == 3, "B must be 3D"
| 287 | +        assert block_size == 32, "Only block_size=32 is supported"
| 288 | +
| 289 | +        # Cast A to mxfp8 across dim -1.
| 290 | +        # A_mx shape: (M, K)
| 291 | +        # A_scale shape: (M, K//block_size)
| 292 | +        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
| 293 | +
| 294 | +        # Cast B_t per-expert to mxfp8 across dim1.
| 295 | +        # B_t_mx shape: (E, K, N)
| 296 | +        # B_t_scale shape: (E, K//block_size, N)
| 297 | +        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)
| 298 | +
| 299 | +        # Store what we need for backward.
| 300 | +        ctx.save_for_backward(A, B_t, offs)
| 301 | +        ctx.out_dtype = out_dtype
| 302 | +
| 303 | +        # Perform scaled grouped GEMM and return result.
| 304 | +        # output = input @ weight.T
| 305 | +        # output shape: (M, N)
| 306 | +        out = emulated_mxfp8_scaled_grouped_mm(
| 307 | +            A_mx,
| 308 | +            A_scale,
| 309 | +            B_t_mx,
| 310 | +            B_t_scale,
| 311 | +            offs=offs,
| 312 | +            block_size=block_size,
| 313 | +            out_dtype=out_dtype,
| 314 | +        )
| 315 | +        return out
| 316 | +
| 317 | +    @staticmethod
| 318 | +    def backward(ctx, grad_output: torch.Tensor):
| 319 | +        raise NotImplementedError
| 320 | +
| 321 | +
| 322 | +def _to_mxfp8_3d_expert_weights_dim1(
| 323 | +    w_t: torch.Tensor,  # (num_experts, K, N)
| 324 | +    block_size: int = 32,
| 325 | +    elem_dtype: torch.dtype = torch.float8_e4m3fn,
| 326 | +) -> Tuple[torch.Tensor, torch.Tensor]:
| 327 | +    """Convert a 3D tensor of shape (experts, K, N) to MXFP8 format along dim1.
| 328 | +    Args:
| 329 | +        w_t (torch.Tensor): Input tensor to be converted.
| 330 | +        block_size (int): Block size for MXFP8 quantization.
| 331 | +        elem_dtype (torch.dtype): Element dtype for MXFP8 quantization.
| 332 | +    Returns:
| 333 | +        Tuple[torch.Tensor, torch.Tensor]: Converted tensor and scale tensor.
| 334 | +        - scale shape: (experts, K // block_size, N)
| 335 | +        - output shape: (experts, K, N)
| 336 | +    """
| 337 | +    # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
| 338 | +    w_scale, w_mx = to_mx(
| 339 | +        w_t.transpose(-2, -1).contiguous(), elem_dtype=elem_dtype, block_size=block_size
| 340 | +    )
| 341 | +    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
| 342 | +    return w_t_scale, w_t_mx
| 343 | +
| 344 | +
271 | 345 | def emulated_mxfp8_scaled_grouped_mm(
272 | 346 |     A_mx: torch.Tensor,
273 | 347 |     A_scale: torch.Tensor,
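A minimal forward-only usage sketch of the new autograd function, to make the expected shapes concrete. The sizes and the `offs` values below are made up for illustration; only `_MXFP8GroupedMM`, its argument order, and the shape comments come from the diff above (the sketch assumes it runs in the same module), and the backward pass in this commit still raises `NotImplementedError`:

```python
import torch

# Hypothetical sizes: M tokens, hidden dim K, output dim N, E experts.
M, K, N, E = 128, 256, 512, 4

A = torch.randn(M, K, dtype=torch.bfloat16)       # 2D activations, as required by the assert
B_t = torch.randn(E, K, N, dtype=torch.bfloat16)  # 3D per-expert weights, pre-transposed to (E, K, N)
# End offset of each expert's contiguous token group along M (illustrative values).
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32)

# Forward only: block_size and out_dtype fall back to their defaults (32, bfloat16).
out = _MXFP8GroupedMM.apply(A, B_t, offs)
print(out.shape)  # torch.Size([128, 512]), i.e. (M, N) per the shape comments in forward()
```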
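And a small shape check for the dim1 cast helper, illustrating the transpose, cast along dim -1, then untranspose trick it uses. The sizes are again arbitrary, and the expected scale shape is the one stated in the helper's docstring:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

E, K, N, block_size = 4, 256, 512, 32
w_t = torch.randn(E, K, N, dtype=torch.bfloat16)

# Cast along the last dim of the (E, N, K) view so the 32-element scaling blocks
# run along K, then transpose back to the original (E, K, N) layout.
w_scale, w_mx = to_mx(
    w_t.transpose(-2, -1).contiguous(),
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
)
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)

print(w_t_mx.shape)     # (E, K, N)
print(w_t_scale.shape)  # (E, K // block_size, N), per the docstring
```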