Add Float8BlockwiseLinear for training #2618


Merged · 1 commit · Aug 1, 2025
73 changes: 73 additions & 0 deletions test/prototype/blockwise_fp8_training/test_blockwise_linear.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy

import pytest
import torch

from torchao.utils import is_sm_at_least_90

triton = pytest.importorskip("triton", reason="Triton required to run this test")
if not is_sm_at_least_90():
pytest.skip("This test requires SM90 or higher", allow_module_level=True)


from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

torch.random.manual_seed(0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Review comment (Contributor): sm90?
@pytest.mark.parametrize("in_features", [4096])
@pytest.mark.parametrize("out_features", [128256])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("block_size", [128])
def test_blockwise_quant_linear_fwd_bwd(
in_features,
out_features,
batch_size,
block_size,
):
if in_features % block_size != 0 or out_features % block_size != 0:
pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

layer_ref = torch.nn.Linear(
in_features=in_features,
out_features=out_features,
bias=False,
).cuda()

layer_test = Float8BlockwiseLinear.from_float(copy.deepcopy(layer_ref))

# Create input tensor
x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
x_ref = x_test.clone().detach().requires_grad_(True)

# Forward pass
y_test = layer_test(x_test)
y_ref = layer_ref(x_ref)

# Compare outputs
sqnr = compute_error(y_ref, y_test)
assert not y_test.isnan().any(), "Output must not contain NaNs"
assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
assert not sqnr.isinf().any(), "SQNR must not be inf"

# Backward pass
y_test.sum().backward()
y_ref.sum().backward()

# Compare input grads
sqnr = compute_error(x_ref.grad, x_test.grad)
assert not x_test.grad.isnan().any(), "Input grad must not contain NaNs"
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"

# Compare weight grads
sqnr = compute_error(layer_ref.weight.grad, layer_test.weight.grad)
assert not layer_test.weight.grad.isnan().any(), "Weight grad must not contain NaNs"
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"
185 changes: 185 additions & 0 deletions torchao/prototype/blockwise_fp8_training/linear.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.blockwise_fp8_training.kernels import (
blockwise_fp8_gemm_1x128_128x1,
blockwise_fp8_gemm_1x128_128x128,
fp8_blockwise_act_quant_lhs,
fp8_blockwise_act_quant_rhs,
fp8_blockwise_act_quant_transposed_lhs,
fp8_blockwise_weight_quant_rhs,
fp8_blockwise_weight_quant_transposed_rhs,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.utils import is_sm_at_least_90


class fp8_blockwise_mm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, block_size):
assert block_size == 128, "Only support block_size=128"

# Temporarily reshape x to 2D tensor
x_orig_shape = x.shape
x = x.reshape(-1, x_orig_shape[-1])

# Cast inputs to fp8 blockwise using (1, block_size) scaling granularity in row major format.
x_fp8, x_scale = fp8_blockwise_act_quant_lhs(x, block_size)

# Cast weight to fp8 blockwise using (block_size, block_size) scaling granularity, with transposed dims in column major format.
weight_t_fp8, weight_t_scale = fp8_blockwise_weight_quant_transposed_rhs(
weight,
block_size=block_size,
)

# out = input @ weight.T
out = blockwise_fp8_gemm_1x128_128x128(
x_fp8,
1.0 / x_scale,
weight_t_fp8,
1.0 / weight_t_scale,
)
out = out.reshape(*x_orig_shape[:-1], out.shape[-1])
ctx.save_for_backward(x, weight)
ctx.block_size = block_size
return out

@staticmethod
def backward(ctx, grad_output):
x, weight = ctx.saved_tensors
block_size = ctx.block_size

# Reshape input to 2D
x_orig_shape = x.shape
x = x.reshape(-1, x_orig_shape[-1])

# Reshape grad_output to 2D
grad_output_orig_shape = grad_output.shape
grad_output = grad_output.reshape(-1, grad_output_orig_shape[-1]).contiguous()
assert grad_output.shape[1] % 128 == 0, "unsupported"

# Cast grad_output to fp8 blockwise 1x128 since it is the grad of the output activation.
grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant_lhs(
grad_output,
block_size,
)

# Cast weight to fp8 blockwise using (128, 128) scaling granularity, in column major format.
weight_fp8, weight_scale = fp8_blockwise_weight_quant_rhs(
weight,
block_size=block_size,
)

# grad_x = grad_output @ weight
grad_x = blockwise_fp8_gemm_1x128_128x128(
grad_output_fp8,
1.0 / grad_output_scale,
weight_fp8,
1.0 / weight_scale,
)

# Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is
# the grad of the output activation.
# Write directly with transposed dims in row major format, as needed for dW calc.
grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant_transposed_lhs(
grad_output,
block_size,
)

# Cast x to fp8 blockwise with (block_size x 1) scaling groups, in column major format.
# RHS should have groupwise scales calculated colwise, so scaling groups do not cross the
# contracting (K) dim.
x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size)

# grad_weight = grad_output.T @ x
grad_weight = blockwise_fp8_gemm_1x128_128x1(
grad_output_t_fp8,
1.0 / grad_output_t_scale,
x_fp8,
1.0 / x_scale,
)

# Reshape grad_x to expected potentially 3D+ shape
grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1])
return grad_x, grad_weight, None, None


class Float8BlockwiseLinear(nn.Linear):
"""
Linear layer that performs its forward and backward matmuls with blockwise fp8 quantization.

Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Must be False; bias is not supported.
block_size (int): Block size for quantization. Defaults to 128.
dtype (torch.dtype): Data type for the weights. Defaults to torch.bfloat16.
"""

supported_dtypes = [
torch.bfloat16,
]

def __init__(
self,
*args,
block_size: int = 128,
dtype=torch.bfloat16,
**kwargs,
):
super().__init__(*args, **kwargs)

assert dtype in self.supported_dtypes, (
f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
)
assert is_sm_at_least_90(), "Only support SM90"
self.block_size = block_size
self.dtype = dtype

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.

Args:
x (torch.Tensor): input tensor.

Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return fp8_blockwise_mm.apply(x, self.weight, self.block_size)

@classmethod
def from_float(
cls,
mod,
):
assert mod.bias is None, "unsupported"
assert mod.in_features % 128 == 0, "unsupported"
assert mod.out_features % 128 == 0, "unsupported"
with torch.device("meta"):
new_mod = cls(
mod.in_features,
mod.out_features,
bias=False,
)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


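# Config for swapping nn.Linear modules to Float8BlockwiseLinear via torchao's
# quantize_ API; the handler registered below performs the conversion.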
class Float8BlockwiseLinearConfig(AOBaseConfig):
pass


@register_quantize_module_handler(Float8BlockwiseLinearConfig)
def _float8_blockwise_transform(module, config):
return Float8BlockwiseLinear.from_float(module)