 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -44,7 +45,7 @@ def _scaled_grouped_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
     # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
@@ -127,20 +128,11 @@ def forward(
         # Precompute non-transposed B column-major for backward, to save memory by storing the
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
-        B = B_t.contiguous().transpose(-2, -1)
-
-        # - B shape: (E, N, K)
-        # - B scales must be computed rowwise keeping the outer/final dim, so:
-        #   - B_scale shape: (E, 1, K)
-        B_scales = tensor_to_scale(
-            B,
-            torch.float8_e4m3fn,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-2,
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        B_scaled = B.to(torch.float32) * B_scales
-        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
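
For context on what the fused kernel computes, below is a minimal eager-mode sketch of the path removed above. It assumes the tensor_to_scale / to_fp8_saturated helpers and the ScalingGranularity enum are importable from torchao.float8 (import paths are an assumption here); the Triton kernel is expected to produce the same column-major fp8 tensor and (E, 1, K) rowwise scales in a single pass instead of materializing the intermediate transposed and scaled tensors.

# Minimal eager-mode sketch of what triton_fp8_rowwise_3d_transpose_rhs replaces
# (based on the removed lines above; not the kernel implementation itself).
import torch
from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated

def fp8_rowwise_transpose_rhs_reference(B_t: torch.Tensor):
    # B_t: (E, K, N) high-precision expert weights. Produce the non-transposed
    # B in column-major layout, needed in backward for grad_A = grad_output @ B.
    B = B_t.contiguous().transpose(-2, -1)  # (E, N, K), column-major

    # Rowwise scales keeping the outer/final dim: shape (E, 1, K).
    B_scales = tensor_to_scale(
        B,
        torch.float8_e4m3fn,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-2,
        round_scales_to_power_of_2=True,
    )
    B_scaled = B.to(torch.float32) * B_scales
    B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)
    return B_fp8_col_major, B_scales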
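
On round_scales_to_power_of_2=True: restricting scales to powers of two means scaling and unscaling only shift the floating-point exponent, so no extra mantissa rounding error is introduced beyond the fp8 quantization itself. A common way to implement such rounding (an assumption for illustration, not necessarily the library's exact code) is:

# Hypothetical helper: round a positive scale down to the nearest power of two.
def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # exp2(floor(log2(x))) is the largest power of two <= x (assumes scale > 0, finite).
    return torch.exp2(torch.floor(torch.log2(scale)))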