Commit eca0126

Add Float8BlockwiseLinear for training
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
1 parent b2e5d4d commit eca0126

File tree

2 files changed: +257 additions, -0 deletions
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [4096])
@pytest.mark.parametrize("out_features", [128256])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("block_size", [128])
def test_blockwise_quant_linear_fwd_bwd(
    in_features,
    out_features,
    batch_size,
    block_size,
):
    if in_features % block_size != 0 or out_features % block_size != 0:
        pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

    torch.random.manual_seed(0)
    layer_test = Float8BlockwiseLinear(
        in_features=in_features,
        out_features=out_features,
        block_size=block_size,
    ).cuda()

    torch.random.manual_seed(0)
    layer_ref = torch.nn.Linear(
        in_features=in_features,
        out_features=out_features,
    ).cuda()

    # Create input tensor
    x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
    x_ref = x_test.clone().detach().requires_grad_(True)

    # Forward pass
    y_test = layer_test(x_test)
    y_ref = layer_ref(x_ref)

    # Compare outputs
    sqnr = compute_error(y_ref, y_test)
    print(f"Output SQNR: {sqnr}")
    assert not y_test.isnan().any(), "Output must not contain NaNs"
    assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
    assert not sqnr.isinf().any(), "SQNR must not be inf"

    # Backward pass
    y_test.sum().backward()
    y_ref.sum().backward()

    # Compare input grads
    sqnr = compute_error(x_ref.grad, x_test.grad)
    print(f"Input grad SQNR: {sqnr}")
    assert not x_test.grad.isnan().any(), "Input grad must not contain NaNs"
    assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"

    # Compare weight grads
    sqnr = compute_error(layer_ref.weight.grad, layer_test.weight.grad)
    print(f"Weight grad SQNR: {sqnr}")
    assert not layer_test.weight.grad.isnan().any(), "Weight grad must not contain NaNs"
    assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"
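
For context on the thresholds in the assertions above: compute_error reports an SQNR-style error metric in decibels, where higher values mean the fp8 blockwise output is closer to the reference. A minimal sketch of an equivalent metric, assuming the standard SQNR definition (sqnr_db is a hypothetical helper, not torchao's exact implementation):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels; the test above treats
    # roughly 25-30 dB as acceptable fp8 error.
    signal = torch.linalg.vector_norm(ref.float())
    noise = torch.linalg.vector_norm((ref - test).float())
    return 20 * torch.log10(signal / noise)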
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x1,
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_act_quant_rhs,
    fp8_blockwise_act_quant_transposed_lhs,
    fp8_blockwise_weight_quant_rhs,
    fp8_blockwise_weight_quant_transposed_rhs,
)
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import is_sm_at_least_90


class fp8_blockwise_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, block_size):
        assert block_size == 128, "Only block_size=128 is supported"

        # Temporarily reshape x to a 2D tensor
        x_orig_shape = x.shape
        x = x.reshape(-1, x_orig_shape[-1])

        # Cast input to fp8 blockwise using (1, block_size) scaling granularity, in row major format.
        x_fp8, x_scale = fp8_blockwise_act_quant_lhs(x, block_size)

        # Cast weight to fp8 blockwise using (block_size, block_size) scaling granularity, with
        # transposed dims in column major format.
        weight_t_fp8, weight_t_scale = fp8_blockwise_weight_quant_transposed_rhs(
            weight,
            block_size=block_size,
        )

        # out = input @ weight.T
        out = blockwise_fp8_gemm_1x128_128x128(
            x_fp8,
            1.0 / x_scale,
            weight_t_fp8,
            1.0 / weight_t_scale,
        )
        out = out.reshape(*x_orig_shape[:-1], out.shape[-1])
        ctx.save_for_backward(x, weight)
        ctx.block_size = block_size
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        block_size = ctx.block_size

        # Reshape input to 2D
        x_orig_shape = x.shape
        x = x.reshape(-1, x_orig_shape[-1])

        # Reshape grad_output to 2D
        grad_output_orig_shape = grad_output.shape
        grad_output = grad_output.reshape(-1, grad_output_orig_shape[-1]).contiguous()
        assert grad_output.shape[1] % 128 == 0, "unsupported"

        # Cast grad_output to fp8 blockwise 1x128, since it is the grad of the output activation.
        grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant_lhs(
            grad_output,
            block_size,
        )

        # Cast weight to fp8 blockwise 128x128, in column major format.
        weight_fp8, weight_scale = fp8_blockwise_weight_quant_rhs(
            weight,
            block_size=block_size,
        )

        # grad_x = grad_output @ weight
        grad_x = blockwise_fp8_gemm_1x128_128x128(
            grad_output_fp8,
            1.0 / grad_output_scale,
            weight_fp8,
            1.0 / weight_scale,
        )

        # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is
        # the grad of the output activation.
        # Write directly with transposed dims in row major format, as needed for the dW calc.
        grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant_transposed_lhs(
            grad_output,
            block_size,
        )

        # Cast x to fp8 blockwise with (block_size x 1) scaling groups, in column major format.
        # The RHS should have groupwise scales computed column-wise, so the scaling groups do not
        # cross the contracting (K) dim.
        x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size)

        # grad_weight = grad_output.T @ x
        grad_weight = blockwise_fp8_gemm_1x128_128x1(
            grad_output_t_fp8,
            1.0 / grad_output_t_scale,
            x_fp8,
            1.0 / x_scale,
        )

        # Reshape grad_x back to the original, potentially 3D+, input shape
        grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1])
        return grad_x, grad_weight, None, None


class Float8BlockwiseLinear(nn.Linear):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        block_size (int): Block size for quantization. Defaults to 128.
        dtype (torch.dtype): Data type for the weights. Defaults to torch.bfloat16.
    """

    supported_dtypes = [
        torch.bfloat16,
    ]

    def __init__(
        self,
        *args,
        block_size: int = 128,
        dtype=torch.bfloat16,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        assert dtype in self.supported_dtypes, (
            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
        )
        assert is_sm_at_least_90(), "Only SM90+ is supported"
        self.block_size = block_size
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after the linear computation.
        """
        return fp8_blockwise_mm.apply(x, self.weight, self.block_size)

    @classmethod
    def from_float(
        cls,
        mod,
    ):
        assert mod.bias is None, "unsupported"
        assert mod.in_features % 128 == 0, "unsupported"
        assert mod.out_features % 128 == 0, "unsupported"
        with torch.device("meta"):
            new_mod = cls(
                mod.in_features,
                mod.out_features,
                bias=False,
            )
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


class Float8BlockwiseLinearConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(Float8BlockwiseLinearConfig)
def _float8_blockwise_transform(module, config):
    return Float8BlockwiseLinear.from_float(module)
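
To make the scaling granularity concrete: the comments above describe casting the LHS activation with (1, block_size) scaling groups, i.e. one scale per contiguous group of 128 elements along the contracting dim. Below is a minimal pure-PyTorch sketch of that idea, assuming an amax-based multiply-by-scale convention (consistent with linear.py passing 1.0 / scale to the gemm); the helper name and details are illustrative, not the Triton kernels' implementation:

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quant_1x128_rowwise(x: torch.Tensor, block_size: int = 128):
    # x: (M, K) with K divisible by block_size; each row-wise group of
    # block_size elements along K shares a single scale.
    M, K = x.shape
    x_groups = x.reshape(M, K // block_size, block_size)
    amax = x_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = FP8_E4M3_MAX / amax                      # multiply-by-scale convention
    x_fp8 = (x_groups * scale).to(torch.float8_e4m3fn).reshape(M, K)
    return x_fp8, scale.squeeze(-1)                  # scales: (M, K // block_size)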
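
A minimal usage sketch for the new module, assuming a CUDA device with SM90+ and dimensions divisible by 128 (shapes are illustrative):

import torch
from torch import nn
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

# Convert an existing bias-free nn.Linear whose dims are divisible by 128.
ref = nn.Linear(4096, 4096, bias=False).cuda()
fp8_linear = Float8BlockwiseLinear.from_float(ref)

# Forward and backward dispatch through the fp8_blockwise_mm autograd function,
# so grad_x and grad_weight both come from the blockwise fp8 gemms.
x = torch.randn(8, 256, 4096, device="cuda", requires_grad=True)
y = fp8_linear(x)
y.sum().backward()
print(x.grad.shape, fp8_linear.weight.grad.shape)

The registered handler suggests the same swap could also be driven through torchao's quantize_ API with Float8BlockwiseLinearConfig, though the config is currently an empty placeholder.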
