integrate mxfp8 dim1 cast kernel choice enum into MXLinear #2554

Closed
test/prototype/mx_formats/test_mx_dtensor.py (2 additions, 1 deletion)
@@ -25,6 +25,7 @@
 from tqdm import tqdm
 
 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
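For reference, the change to this tensor-parallel test amounts to swapping the old boolean flag for the new enum field; a minimal before/after sketch using the imports shown above:

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32

# Before this PR: a boolean that could only enable the triton dim1 cast kernel.
# config.use_fp8_dim1_cast_triton_kernel = True

# After this PR: an enum that selects between the triton, cuda, and torch kernels;
# this particular test now exercises the CUDA kernel.
config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA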
test/prototype/mx_formats/test_mx_linear.py (16 additions, 11 deletions)
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,17 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
-):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +111,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32, # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
     )
     quantize_(m_mx, config)
 
@@ -227,8 +229,11 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +251,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +272,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
torchao/prototype/mx_formats/config.py (9 additions, 4 deletions)
@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"
 
 
+class MXFP8CastKernelChoice(Enum):
+    TRITON = "triton"
+    CUDA = "cuda"
+    TORCH = "torch"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +91,10 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
 
-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for dim1 cast
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_cast_kernel_choice: MXFP8CastKernelChoice = MXFP8CastKernelChoice.TRITON
 
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +152,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
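The config.py change above defines the MXFP8CastKernelChoice enum and replaces the use_fp8_dim1_cast_triton_kernel boolean with an mxfp8_cast_kernel_choice field that defaults to TRITON. A minimal usage sketch, assuming quantize_ is importable from torchao.quantization as in the tests above and that a CUDA device is available:

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

# Start from the emulated mxfp8 recipe and select the CUDA dim1 cast kernel
# (TRITON is the new default; TORCH uses the plain MXTensor.to_mx path).
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA

# Swap the high-precision linear layers for MX linear layers in place.
m = nn.Sequential(nn.Linear(256, 128, bias=False, device="cuda", dtype=torch.bfloat16))
quantize_(m, config=config)

# short_str() now reports the selected cast kernel,
# e.g. "..., kernel=emulated, mxfp8_cast_kernel_choice=cuda".
print(config.short_str())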
torchao/prototype/mx_formats/mx_linear.py (29 additions, 7 deletions)
@@ -15,6 +15,7 @@
 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -134,15 +135,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -167,7 +168,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice
 
         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -183,10 +184,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            weight_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+            )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -201,14 +206,22 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )
 
         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            grad_output_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
+                block_size,
+                grad_elem_dtype,
+                grad_output_hp_r.dtype,
+                gemm_kernel_choice,
+            )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
                 grad_output_hp_r.t().contiguous(),
@@ -217,7 +230,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
                 input_hp_r,
                 block_size,
@@ -226,6 +239,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            input_t_mx_dim0_tmp = _cuda_to_mxfp8_dim1_wrapper(
+                input_hp_r,
+                block_size,
+                in_elem_dtype,
+                input_hp_r.dtype,
+                gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
             input_t_mx_dim0_tmp = MXTensor.to_mx(
                 input_hp_r.t().contiguous(),
@@ -280,7 +302,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
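The same three-way dispatch on mxfp8_cast_kernel_choice now appears three times in the backward of the autograd function modified above (for the weight, grad_output, and input casts). The following condensed sketch shows the pattern for a single tensor; the helper function is hypothetical (not part of this PR), and the two wrapper functions are assumed to be in scope exactly as they are in mx_linear.py:

from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def cast_to_mxfp8_dim1_sketch(tensor_hp, block_size, elem_dtype, gemm_kernel_choice, kernel_choice):
    # Hypothetical helper mirroring the branches added in the backward above.
    if kernel_choice == MXFP8CastKernelChoice.TRITON:
        # hand-written triton dim1 cast kernel (pre-existing path)
        return _triton_to_mxfp8_dim1_wrapper(
            tensor_hp, block_size, elem_dtype, tensor_hp.dtype, gemm_kernel_choice
        )
    elif kernel_choice == MXFP8CastKernelChoice.CUDA:
        # new in this PR: hand-written CUDA dim1 cast kernel
        return _cuda_to_mxfp8_dim1_wrapper(
            tensor_hp, block_size, elem_dtype, tensor_hp.dtype, gemm_kernel_choice
        )
    else:
        # TORCH (and None, as parametrized in the tests): transpose and use MXTensor.to_mx
        return MXTensor.to_mx(
            tensor_hp.t().contiguous(),
            elem_dtype,
            block_size,
            gemm_kernel_choice=gemm_kernel_choice,
        )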