
Commit b2e5d4d

Add Triton kernels for fp8 blockwise quantization and GEMMs
stack-info: PR: #2617, branch: danielvegamyhre/stack/17
1 parent 0e00df3 commit b2e5d4d
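
For orientation, here is a minimal usage sketch of the forward-style path exercised by the tests below. It assumes a CUDA device, Triton >= 3.3.0, and illustrative block-aligned shapes; the call pattern, including passing reciprocal scales to the GEMM wrapper, mirrors the test code in this commit and is not an independent API reference.

import torch
from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_weight_quant_transposed_rhs,
)

# bf16 inputs for output = input @ weight.T (shapes chosen for illustration,
# multiples of the 128 block size)
A = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")   # activations (M, K)
B = torch.randn(2048, 1024, dtype=torch.bfloat16, device="cuda")  # weight (N, K)

# Quantize: 1x128 blockwise scales for the LHS activations,
# 128x128 blockwise scales for the transposed RHS weight.
A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=torch.float8_e4m3fn)

# The GEMM wrapper is called with reciprocal scales, as in the tests below.
C = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s)

The 1x128_128x1 variant is used analogously for the grad_weight computation, as shown in test_blockwise_fp8_gemm_1x128_128x1 below.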

File tree

7 files changed: +1108 −0 lines changed

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x1,
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_act_quant_rhs,
    fp8_blockwise_act_quant_transposed_lhs,
    fp8_blockwise_weight_quant_rhs,
    fp8_blockwise_weight_quant_transposed_rhs,
    torch_blockwise_scale_act_quant_lhs,
    torch_blockwise_scale_act_quant_rhs,
    torch_blockwise_scale_weight_quant,
)
from torchao.testing.utils import skip_if_rocm

BLOCKWISE_SIZE_MNK = [
    (128, 128, 128),
    (2, 512, 128),
    (2, 5120, 1280),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
    # Simulate output = input @ weight.T
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype)
    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s)
    assert not C_q.isnan().any(), "C_q must not contain NaNs"

    sqnr = compute_error(C, C_q)
    min_sqnr = 28.0
    print(f"blockwise_fp8_gemm_1x128_128x128 ({M},{N},{K}) SQNR: {sqnr}")
    assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype):
    # Simulate grad_weight = grad_output_t @ input
    A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
    C = A.T @ B
    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x1(A_t_q, 1.0 / A_t_s, B_q, 1.0 / B_s)

    assert not C_q.isnan().any(), "C_q must not contain NaNs"
    assert C.dtype == torch.bfloat16
    assert C_q.dtype == torch.bfloat16

    sqnr = compute_error(C, C_q)
    min_sqnr = 28.0
    print(f"blockwise_fp8_gemm_1x128_128x1 ({M},{N},{K}) SQNR: {sqnr}")
    assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_lhs(block_size):
    device = "cuda"
    M, K = 4096, 1024
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[0, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(x, tile_size=block_size)

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Compare scales
    assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), (
        f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
    device = "cuda"
    M, K = 4096, 1024
    x = torch.randn(M, K, device=device)

    # Set one block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size)

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, atol=1e-3), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Compare scales
    assert torch.allclose(triton_scale, ref_scale, atol=1e-3), (
        f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
    device = "cuda"
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[0, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(
        x.t().contiguous(), tile_size=block_size
    )

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Compare scales
    assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), (
        f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
    device = "cuda"
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs(
        x,
        block_size=block_size,
    )
    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=block_size)

    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Compare scales
    assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), (
        f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int):
    device = "cuda"
    M = 512
    K = 2048
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs(
        x,
        block_size=block_size,
    )
    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(
        x.t().contiguous(), tile_size=block_size
    )

    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Compare scales
    assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), (
        f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}"
    )

torchao/prototype/blockwise_fp8_training/__init__.py

Whitespace-only changes.
