
Commit 9f8a1da

[moe training] integrate rowwise expert quant kernel
stack-info: PR: #2698, branch: danielvegamyhre/stack/32
1 parent 5118bcf commit 9f8a1da

File tree

3 files changed: +9 −14 lines

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -4,3 +4,6 @@
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )
+from torchao.prototype.moe_training.kernels.float8_rowwise import (
+    triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs,
+)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 5 additions & 13 deletions
@@ -14,6 +14,7 @@
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -44,7 +45,7 @@ def _scaled_grouped_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
     # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
@@ -127,20 +128,11 @@ def forward(
         # Precompute non-transposed B column-major for backward, to save memory by storing the
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
-        B = B_t.contiguous().transpose(-2, -1)
-
-        # - B shape: (E, N, K)
-        # - B scales must be computed rowwise keeping the outer/final dim, so:
-        # - B_scale shape: (E, 1, K)
-        B_scales = tensor_to_scale(
-            B,
-            torch.float8_e4m3fn,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-2,
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        B_scaled = B.to(torch.float32) * B_scales
-        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)

         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
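For context, this hunk swaps the eager per-expert quantization (transpose, tensor_to_scale, to_fp8_saturated) for the fused triton_fp8_rowwise_3d_transpose_rhs kernel. Below is a rough pure-PyTorch sketch of the math involved; the function name and the power-of-2 rounding convention are illustrative assumptions reconstructed from the removed eager code, not the kernel's actual implementation in torchao/prototype/moe_training/kernels/float8_rowwise.py.

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def reference_fp8_rowwise_3d_transpose_rhs(
    B_t: torch.Tensor,  # (E, K, N) per-expert RHS weights
    round_scales_to_power_of_2: bool = True,
):
    # Transpose to B with shape (E, N, K); after the contiguous() call this
    # transposed view is column-major over the last two dims.
    B = B_t.contiguous().transpose(-2, -1)

    # Rowwise scales keeping the outer/final dim: amax over dim -2 -> (E, 1, K).
    amax = B.abs().amax(dim=-2, keepdim=True).to(torch.float32)
    scales = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    if round_scales_to_power_of_2:
        # Round each scale down to the nearest power of two (illustrative convention).
        scales = torch.exp2(torch.floor(torch.log2(scales)))

    # Scale, saturate to the fp8 range, and cast.
    B_scaled = B.to(torch.float32) * scales
    B_fp8 = B_scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return B_fp8, scales

The fused kernel avoids materializing the intermediate fp32 transposed copy that this eager path needs, which is the memory saving the surrounding comments refer to.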

torchao/prototype/moe_training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function converts the 3D input tensor to a float8 tensor, with scales computed along logical columns
-    on a per-expert basis.
+    on a per-expert basis. Output will be in column-major memory layout.

     Args:
         x (torch.Tensor): The input tensor to be converted to a float8 tensor. Shape (E, K, N).
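The column-major note matters because the grouped GEMM path expects the RHS operand in that layout. As a quick illustration of what column-major means over the last two dims (the stride check below is a sketch; torchao's own helper is _is_column_major, imported in scaled_grouped_mm.py above):

import torch

def is_column_major_sketch(t: torch.Tensor) -> bool:
    # Column-major over the last two dims: stepping down a column is the
    # contiguous (stride-1) direction.
    return t.stride(-2) == 1 and t.stride(-1) > 1

x = torch.randn(8, 512, 256)            # (E, K, N), row-major
x_t = x.contiguous().transpose(-2, -1)  # (E, N, K), column-major view
print(is_column_major_sketch(x))    # False
print(is_column_major_sketch(x_t))  # True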
