
Commit 14420ef

Add Triton kernels for fp8 blockwise quantization and GEMMs
stack-info: PR: #2617, branch: danielvegamyhre/stack/17
1 parent 0e00df3 commit 14420ef

File tree

9 files changed: +1157 -3 lines changed

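For orientation, here is a minimal usage sketch of the new training kernels, based purely on the call patterns in the new test file added below; the imports, argument order, and reciprocal-scale convention are taken from those tests, and a CUDA device with compute capability >= 9.0 plus Triton >= 3.3.0 are assumed.

import torch

from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_weight_quant_transposed_rhs,
)

M, N, K = 128, 128, 128
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activation
B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weight

# Quantize to fp8 with blockwise scales: 1x128 blocks for the activation (LHS),
# 128x128 blocks for the transposed weight (RHS), as in the tests below.
A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=torch.float8_e4m3fn)

# The GEMM takes reciprocal scales and approximates A @ B.T in bf16.
out = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s)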

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from triton.testing import do_bench

 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
Lines changed: 325 additions & 0 deletions (new file)
@@ -0,0 +1,325 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x1,
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_act_quant_rhs,
    fp8_blockwise_act_quant_transposed_lhs,
    fp8_blockwise_weight_quant_rhs,
    fp8_blockwise_weight_quant_transposed_rhs,
    torch_blockwise_scale_act_quant_lhs,
    torch_blockwise_scale_act_quant_rhs,
    torch_blockwise_scale_weight_quant,
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import is_sm_at_least_90

BLOCKWISE_SIZE_MNK = [
    (128, 128, 128),
    (2, 512, 128),
    (2, 5120, 1280),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
    # Simulate output = input @ weight.T
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype)
    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s)
    assert not C_q.isnan().any(), "C_q must not contain NaNs"

    sqnr = compute_error(C, C_q)
    min_sqnr = 28.0
    assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype):
    # Simulate grad_weight = grad_output_t @ input
    A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
    C = A.T @ B
    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x1(A_t_q, 1.0 / A_t_s, B_q, 1.0 / B_s)

    assert not C_q.isnan().any(), "C_q must not contain NaNs"
    assert C.dtype == torch.bfloat16
    assert C_q.dtype == torch.bfloat16

    sqnr = compute_error(C, C_q)
    min_sqnr = 28.0
    assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_lhs(block_size):
    device = "cuda"
    M, K = 4096, 1024
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[0, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(x, tile_size=block_size)

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    torch.testing.assert_close(
        triton_fp32,
        ref_fp32,
        atol=0,
        rtol=0,
        msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}",
    )

    # Compare scales
    torch.testing.assert_close(
        triton_scale,
        ref_scale,
        atol=0,
        rtol=0,
        msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
    device = "cuda"
    M, K = 4096, 1024
    x = torch.randn(M, K, device=device)

    # Set one block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size)

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    torch.testing.assert_close(
        triton_fp32,
        ref_fp32,
        atol=0,
        rtol=0,
        msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}",
    )

    # Compare scales
    torch.testing.assert_close(
        triton_scale,
        ref_scale,
        atol=0,
        rtol=0,
        msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
    device = "cuda"
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[0, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs(
        x,
        block_size=block_size,
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(
        x.t().contiguous(), tile_size=block_size
    )

    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    torch.testing.assert_close(
        triton_fp32,
        ref_fp32,
        atol=0,
        rtol=0,
        msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}",
    )

    # Compare scales
    torch.testing.assert_close(
        triton_scale,
        ref_scale,
        atol=0,
        rtol=0,
        msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)])
def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
    device = "cuda"
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs(
        x,
        block_size=block_size,
    )
    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=block_size)

    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    torch.testing.assert_close(
        triton_fp32,
        ref_fp32,
        atol=0,
        rtol=0,
        msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}",
    )

    # Compare scales
    torch.testing.assert_close(
        triton_scale,
        ref_scale,
        atol=0,
        rtol=0,
        msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0")
@pytest.mark.parametrize("block_size", [128, 256])
def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int):
    device = "cuda"
    M = 512
    K = 2048
    x = torch.randn(M, K, device=device)

    # Set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    x[:block_size, :block_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs(
        x,
        block_size=block_size,
    )
    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(
        x.t().contiguous(), tile_size=block_size
    )

    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    torch.testing.assert_close(
        triton_fp32,
        ref_fp32,
        atol=0,
        rtol=0,
        msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}",
    )

    # Compare scales
    torch.testing.assert_close(
        triton_scale,
        ref_scale,
        atol=0,
        rtol=0,
        msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}",
    )
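The tests above gate accuracy on compute_error(C, C_q) >= 28.0, an SQNR threshold in decibels. A reference formulation of SQNR (an assumption about what compute_error computes, shown only to make the 28 dB threshold concrete) is:

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB:
    # 20 * log10(||ref|| / ||ref - approx||). Higher is better; 28 dB means
    # the error norm is roughly 4% of the signal norm.
    return 20 * torch.log10(ref.norm() / (ref - approx).norm())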

test/prototype/test_blockwise_triton.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

 triton = pytest.importorskip("triton", reason="Triton required to run this test")

-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,
File renamed without changes.

torchao/prototype/blockwise_fp8/blockwise_linear.py renamed to torchao/prototype/blockwise_fp8_inference/blockwise_linear.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 from torch import nn

-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
 )
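For downstream code still importing from the old prototype path, the rename above implies a one-line import update; a minimal sketch (only the module path changes; the function names are unchanged per the diffs in this commit):

# Before:
# from torchao.prototype.blockwise_fp8.blockwise_quantization import (
#     blockwise_fp8_gemm,
#     fp8_blockwise_act_quant,
# )

# After:
from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
)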

torchao/prototype/blockwise_fp8_training/__init__.py

Whitespace-only changes.
