
Commit 5a4f0fa

make fp8 blockwise linear differentiable; use new kernels
stack-info: PR: #2602, branch: danielvegamyhre/stack/16
1 parent fa64d54 commit 5a4f0fa

File tree: 6 files changed (+237, -36 lines changed)
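For orientation, the differentiable path added in this commit computes three matmuls in fp8 with blockwise scales: the output in the forward pass, and the input and weight gradients in the backward pass. A minimal plain-PyTorch bf16 reference of the same products (illustrative only, not code from this commit) is:

    import torch

    # Plain bf16 reference for the products the fp8 blockwise kernels compute,
    # using the nn.Linear convention W: [out_features, in_features].
    x = torch.randn(4, 1024, dtype=torch.bfloat16)
    w = torch.randn(1024, 1024, dtype=torch.bfloat16)

    y = x @ w.t()            # forward: output
    dy = torch.ones_like(y)  # upstream gradient
    dx = dy @ w              # backward: input gradient
    dw = dy.t() @ x          # backward: weight gradient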
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.blockwise_fp8.blockwise_linear import Float8BlockwiseLinear
+
+triton = pytest.importorskip("triton", reason="Triton required to run this test")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("in_features", [1024])
+@pytest.mark.parametrize("out_features", [1024])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("block_size", [128])
+def test_blockwise_quant_linear_fwd_bwd(
+    in_features,
+    out_features,
+    batch_size,
+    block_size,
+):
+    if in_features % block_size != 0 or out_features % block_size != 0:
+        pytest.skip(f"Dimensions must be divisible by block_size={block_size}")
+
+    torch.random.manual_seed(0)
+    layer_test = Float8BlockwiseLinear(
+        in_features=in_features,
+        out_features=out_features,
+        block_size=block_size,
+    ).cuda()
+
+    torch.random.manual_seed(0)
+    layer_ref = torch.nn.Linear(
+        in_features=in_features,
+        out_features=out_features,
+    ).cuda()
+
+    # Create input tensor
+    x_test = torch.randn(batch_size, in_features).cuda()
+    x_ref = x_test.clone().detach().requires_grad_(True)
+
+    # Forward pass
+    y_test = layer_test(x_test)
+    y_ref = layer_ref(x_ref)
+
+    # Compare outputs
+    sqnr = compute_error(y_ref, y_test)
+    assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
+
+    # Backward pass
+    y_test.sum().backward()
+    y_ref.sum().backward()
+
+    # Compare input grads
+    sqnr = compute_error(x_ref.grad, x_test.grad)
+    assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
+
+    # Compare weight grads
+    sqnr = compute_error(layer_ref.weight, layer_test.weight)
+    assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
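For context on the 25 dB SQNR floor used above, a quick back-of-the-envelope check (assuming compute_error returns the usual SQNR, 20 * log10(||ref|| / ||ref - test||)) shows it tolerates roughly a few percent relative error:

    # Assumed SQNR definition: sqnr_db = 20 * log10(||ref|| / ||ref - test||).
    # A 25 dB floor therefore allows a relative error of roughly 10 ** (-25 / 20).
    rel_err_at_25db = 10 ** (-25 / 20)
    print(f"max relative error at 25 dB: ~{rel_err_at_25db:.3f}")  # ~0.056, i.e. ~5.6%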

torchao/prototype/blockwise_fp8/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from .blockwise_linear import BlockwiseQuantLinear
+from .blockwise_linear import Float8BlockwiseLinear
 from .kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
@@ -8,7 +8,7 @@

 __all__ = [
     "blockwise_fp8_gemm",
-    "BlockwiseQuantLinear",
+    "Float8BlockwiseLinear",
     "fp8_blockwise_act_quant",
     "fp8_blockwise_weight_quant",
     "fp8_blockwise_weight_dequant",

torchao/prototype/blockwise_fp8/blockwise_linear.py

Lines changed: 135 additions & 34 deletions

@@ -7,13 +7,107 @@
 import torch
 from torch import nn

+from torchao.core.config import AOBaseConfig
+from torchao.prototype.blockwise_fp8.deep_gemm_utils import (
+    scaled_mm_deep_gemm_128_1_128_1,
+    scaled_mm_deep_gemm_128_1_128_128,
+)
 from torchao.prototype.blockwise_fp8.kernels import (
-    blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
+    triton_quantize_fp8_block,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
 )


-class BlockwiseQuantLinear(nn.Module):
+class fp8_blockwise_mm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, block_size):
+        assert block_size == 128, "Only support block_size=128"
+
+        # Temporarily reshape x to 2D tensor
+        x_orig_shape = x.shape
+        x = x.reshape(-1, x_orig_shape[-1])
+
+        # Triton kernel from DeepGEMM currently has the fastest activation quantization (1 x block_size)
+        x_fp8, x_scale = fp8_blockwise_act_quant(x, block_size)
+
+        # fbgemm currently has the fastest weight quantization (block_size x block_size)
+        weight_t_fp8, weight_t_scale = triton_quantize_fp8_block(
+            weight,
+            block_m=block_size,
+            block_k=block_size,
+            k_major=True,  # For [M,K] -> [K,M] in column-major
+        )
+
+        # DeepGEMM for blockwise GEMM where activation has (1 x block_size) scaling granularity
+        # and weight has (block_size x block_size) scaling granularity.
+        out = scaled_mm_deep_gemm_128_1_128_128(
+            x_fp8,
+            x_scale,
+            weight_t_fp8,
+            weight_t_scale,
+        )
+        ctx.save_for_backward(x, weight)
+        ctx.block_size = block_size
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, weight = ctx.saved_tensors
+        block_size = ctx.block_size
+
+        # left operand must be row-major
+        grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant(
+            grad_output,
+            block_size,
+        )
+
+        # right operand must be column-major
+        weight_t_fp8, weight_t_scale = triton_quantize_fp8_block(
+            weight,
+            block_m=block_size,
+            block_k=block_size,
+            k_major=False,  # For [M,K] -> [K,M] in row-major
+        )
+        weight_t_fp8 = weight_t_fp8.t().contiguous().t()  # To col-major
+
+        # DeepGEMM for blockwise GEMM where left operand has (1 x block_size) scaling granularity
+        # and right operand has (block_size x block_size) scaling granularity.
+        # grad_x = grad_output @ weight.T
+        grad_x = scaled_mm_deep_gemm_128_1_128_128(
+            grad_output_fp8,
+            weight_t_fp8,
+            1.0 / grad_output_scale,
+            1.0 / weight_t_scale,
+        )
+
+        # left operand must be row-major
+        grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant(
+            grad_output.t().contiguous(),
+            block_size,
+        )
+
+        # right operand must be column-major
+        x_fp8, x_scale = fp8_blockwise_act_quant(
+            x,
+            block_size,
+        )
+        x_fp8 = x_fp8.t().contiguous().t()  # To col-major
+
+        # DeepGEMM for blockwise GEMM where both operands have (1 x block_size) scaling granularity.
+        # grad_weight = grad_output.T @ x
+        grad_weight = scaled_mm_deep_gemm_128_1_128_1(
+            grad_output_t_fp8,
+            x_fp8,
+            1.0 / grad_output_t_scale,
+            1.0 / x_scale,
+        )
+        return grad_x, grad_weight, None, None
+
+
+class Float8BlockwiseLinear(nn.Linear):
     """
     Custom linear layer with support for quantized weights and optional bias.

@@ -25,53 +119,60 @@ class BlockwiseQuantLinear(nn.Module):
         dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn.
     """

-    dtype = torch.bfloat16
+    supported_dtypes = [
+        torch.bfloat16,
+    ]

     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        bias: bool = False,
+        *args,
         block_size: int = 128,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        dtype=torch.bfloat16,
+        **kwargs,
     ):
-        super().__init__()
-        supported_dtypes = [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]
-        assert dtype in supported_dtypes, (
-            f"Unsupported dtype: {dtype}. Supported dtypes: {supported_dtypes}"
-        )
-        scale_in_features = (in_features + block_size - 1) // block_size
-        scale_out_features = (out_features + block_size - 1) // block_size
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
-        self.weight.scale = self.scale = nn.Parameter(
-            torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+        super().__init__(*args, **kwargs)
+
+        assert dtype in self.supported_dtypes, (
+            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
         )
         self.block_size = block_size
-        self.dtype
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            self.register_parameter("bias", None)
+        self.dtype = dtype

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the custom linear layer.

         Args:
-            x (torch.Tensor): Input tensor.
+            x (torch.Tensor): input tensor.

         Returns:
            torch.Tensor: Transformed tensor after linear computation.
         """
-        x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype)
-        y = blockwise_fp8_gemm(
-            x, scale, self.weight, self.weight.scale, self.block_size
-        )
+        return fp8_blockwise_mm.apply(x, self.weight, self.block_size)
+
+    @classmethod
+    def from_float(
+        cls,
+        mod,
+    ):
+        assert mod.bias is None, "unsupported"
+        assert mod.in_features % 128 == 0, "unsupported"
+        assert mod.out_features % 128 == 0, "unsupported"
+        with torch.device("meta"):
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+            )
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+        return new_mod
+
+
+class Float8BlockwiseLinearConfig(AOBaseConfig):
+    pass
+

-        if self.bias is not None:
-            y += self.bias
-        return y
+@register_quantize_module_handler(Float8BlockwiseLinearConfig)
+def _deep_gemm_float8_inference_linear_transform(module, config):
+    return Float8BlockwiseLinear.from_float(module)
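A hedged usage sketch of the new config path, assuming torchao's quantize_ entry point dispatches handlers registered via register_quantize_module_handler (the toy model below is illustrative, not from this commit):

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.blockwise_fp8.blockwise_linear import (
        Float8BlockwiseLinear,
        Float8BlockwiseLinearConfig,
    )

    # Swap eligible nn.Linear modules (bias-free, dims divisible by 128) for
    # Float8BlockwiseLinear via the handler registered in this commit.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))
    quantize_(model, Float8BlockwiseLinearConfig())
    assert isinstance(model[0], Float8BlockwiseLinear)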
torchao/prototype/blockwise_fp8/deep_gemm_utils.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+import sys
+
+import torch
+
+try:
+    import deep_gemm
+except ImportError:
+    print("Please install deepgemm to use this feature")
+    sys.exit(0)
+
+
+def scaled_mm_deep_gemm_128_1_128_128(a, b, a_scale, b_scale):
+    M, K = a.shape
+    N, K = b.shape
+    out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out=out)
+    return out
+
+
+def scaled_mm_deep_gemm_128_1_128_1(a, b, a_scale, b_scale):
+    M, K = a.shape
+    N, K = b.shape
+    # Note: the results from `wgrad_gemm_fp8_fp8_fp32_nt` are **accumulated**
+    # into this tensor. For now, we initialize with `zeros` to get correct
+    # numerics in toy examples. For a real use case, this will need to pass
+    # in the gradient tensor directly.
+    out = torch.zeros((M, N), dtype=torch.float, device=a.device)
+    deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((a, a_scale), (b, b_scale), out=out)
+    return out
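For intuition about the scaling granularity these wrappers wrap, here is a hedged bf16/fp32 reference of the (1 x 128, 128 x 128) case. It assumes a_scale is laid out as [M, K // 128] (one scale per 1 x 128 activation block) and b_scale as [N // 128, K // 128] (one scale per 128 x 128 weight block); it sketches the math, not the DeepGEMM kernel:

    import torch

    def ref_scaled_mm_128_1_128_128(a_fp8, b_fp8, a_scale, b_scale, block=128):
        # Dequantize: each 1 x 128 slice of `a` shares one scale, and each
        # 128 x 128 tile of `b` shares one scale (assumed layouts).
        a = a_fp8.to(torch.float32) * a_scale.repeat_interleave(block, dim=1)
        b = b_fp8.to(torch.float32) * b_scale.repeat_interleave(
            block, dim=0
        ).repeat_interleave(block, dim=1)
        # "nt" layout: a is [M, K], b is [N, K], result is [M, N].
        return (a @ b.t()).to(torch.bfloat16)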

torchao/prototype/blockwise_fp8/kernels.py

Lines changed: 6 additions & 0 deletions

@@ -12,6 +12,12 @@
 import triton.language as tl
 from triton import Config

+# try:
+#     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
+# except ImportError:
+#     print("Please install fbgemm-gpu to use this feature")
+#     sys.exit(1)
+
 # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

 fp8_gemm_configs = [
