From f4c270f6fd6a436e8fd6a78d42bd19a53e694edf Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 5 Aug 2025 15:22:42 -0700 Subject: [PATCH] [moe training] integrate rowwise expert quant kernel stack-info: PR: https://github.com/pytorch/ao/pull/2698, branch: danielvegamyhre/stack/32 --- .../prototype/moe_training/kernels/__init__.py | 3 +++ .../prototype/moe_training/scaled_grouped_mm.py | 16 ++++------------ torchao/prototype/moe_training/utils.py | 2 +- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/torchao/prototype/moe_training/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py index b5446849b6..8fb16579e5 100644 --- a/torchao/prototype/moe_training/kernels/__init__.py +++ b/torchao/prototype/moe_training/kernels/__init__.py @@ -1,3 +1,6 @@ +from torchao.prototype.moe_training.kernels.float8_rowwise import ( + triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs, +) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales, ) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 5604d1ecad..7dc246e251 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -15,6 +15,7 @@ from torchao.prototype.moe_training.kernels import ( triton_fp8_col_major_jagged_colwise_scales, triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.utils import ( _is_column_major, @@ -142,20 +143,11 @@ def forward( # Precompute non-transposed B column-major for backward, to save memory by storing the # low precision B tensor instead of the high precision B tensor. # In the backward this is needed for grad_A: grad_output @ B. - B = B_t.contiguous().transpose(-2, -1) - - # - B shape: (E, N, K) - # - B scales must be computed rowwise keeping the outer/final dim, so: - # - B_scale shape: (E, 1, K) - B_scales = tensor_to_scale( - B, - torch.float8_e4m3fn, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-2, + B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( + B_t, + output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) - B_scaled = B.to(torch.float32) * B_scales - B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) # Store what we need for backward. ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) diff --git a/torchao/prototype/moe_training/utils.py b/torchao/prototype/moe_training/utils.py index cbffcadbd2..dc13dfea33 100644 --- a/torchao/prototype/moe_training/utils.py +++ b/torchao/prototype/moe_training/utils.py @@ -152,7 +152,7 @@ def torch_to_3d_rowwise_float8_transpose_rhs( ) -> Tuple[torch.Tensor, torch.Tensor]: """ This function converts the 3D input tensor to a float8 tensor, with scales computed along logical columns - on a per-expert basis. + on a per-expert basis. Output will be in column-major memory layout. Args: x (torch.Tensor): The input tensor to be converted to a float8 tensor. Shape (E, K, N).