Commit 05e1a19

Add Float8BlockwiseLinear with Triton kernels for quantization and GEMMs
stack-info: PR: #2592, branch: danielvegamyhre/stack/15
1 parent 0e00df3 commit 05e1a19

File tree

10 files changed (+1141, -402 lines)
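
The commit title mentions a Float8BlockwiseLinear module built on these kernels; that module itself is not shown in this extract. As a rough orientation, here is a minimal sketch of how the quantization kernels and the blockwise-scaled GEMM compose, mirroring the GEMM test added in this commit (shapes are illustrative, and the assumption that the default block size is 128 comes from the kernel names):

import torch

from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Activations are quantized with (1 x block_size) scaling groups, weights with
# (block_size x block_size) blocks, then the GEMM consumes both FP8 operands
# plus their scales and returns an approximation of A @ B.T.
A = torch.randn(16, 1024, device="cuda")    # activations, shape (M, K)
B = torch.randn(2048, 1024, device="cuda")  # weights, shape (N, K)
A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)
C = blockwise_fp8_gemm_1x128_128x128(A_q, A_s, B_q, B_s)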

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,

Lines changed: 169 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    A_shape: tuple[int]
    block_m: int
    block_k: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_us: float
    fbgemm_us: float
    deepgemm_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (8192, 8192),
        (16384, 16384),
        (32768, 32768),
    ]
    block_m_opts = [1, 128]
    block_k_opts = [
        128,
    ]
    configs = []
    for A_shape, block_m, block_k in itertools.product(
        A_shapes,
        block_m_opts,
        block_k_opts,
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                block_m=block_m,
                block_k=block_k,
            )
        )
    return configs


def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    A = torch.randn(
        *config.A_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    # Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size)
    # and weight quantization (block_size x block_size)
    if config.block_m == 1:
        torch_func = torch.compile(torch_blockwise_scale_act_quant)
        deepgemm_func = fp8_blockwise_act_quant
    else:
        torch_func = torch.compile(torch_blockwise_scale_weight_quant)
        deepgemm_func = fp8_blockwise_weight_quant

    # Validate output shapes and strides
    torch_out, torch_scale = torch_func(A, tile_size=config.block_k)
    deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k)
    fbgemm_out, fbgemm_scale = triton_quantize_fp8_block(
        A, block_m=config.block_m, block_k=config.block_k, k_major=True
    )
    assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape
    assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride()
    assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape
    assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride()

    # Do benchmarking
    torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k)
    deepgemm_us = benchmark_microseconds(
        fp8_blockwise_act_quant, A, block_size=config.block_k
    )
    fbgemm_us = benchmark_microseconds(
        triton_quantize_fp8_block,
        A,
        block_m=config.block_m,
        block_k=config.block_k,
        k_major=True,
    )

    return ExperimentResult(
        torch_us=round(torch_us, 3),
        fbgemm_us=round(fbgemm_us, 3),
        deepgemm_us=round(deepgemm_us, 3),
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "block_shape",
        "torch_us",
        "fbgemm_us",
        "deepgemm_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        block_shape = f"({experiment.config.block_m},{experiment.config.block_k})"
        rows.append(
            [
                A_shape,
                block_shape,
                experiment.result.torch_us,
                experiment.result.fbgemm_us,
                experiment.result.deepgemm_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main(args: argparse.Namespace):
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--compile", action="store_true")
    args = arg_parser.parse_args()
    main(args)

benchmarks/float8/utils.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 import torch.utils.benchmark as benchmark
 from torch.profiler import ProfilerActivity, profile
+from triton.testing import do_bench
 
 
 def profiler_output_to_filtered_time_by_kernel_name(
@@ -428,3 +429,12 @@ def do_benchmarks(
     tops_sec = float(tops) / time_sec
     pct_top_peak = tops_sec / peak_tops
     return time_sec, tops_sec, pct_top_peak
+
+
+def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs):
+    return (
+        do_bench(
+            lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median"
+        )
+        * 1e3
+    )
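
benchmark_microseconds wraps triton.testing.do_bench, takes the median over the timed reps, and multiplies by 1e3 to convert do_bench's milliseconds into microseconds. A short usage sketch matching how the new benchmark script calls it (the bare "from utils import ..." works when run from the directory containing this utils.py; shape and block size are illustrative):

import torch
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import fp8_blockwise_act_quant

A = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
# Median kernel time for (1 x 128) activation quantization, in microseconds.
act_quant_us = benchmark_microseconds(fp8_blockwise_act_quant, A, block_size=128)
print(f"fp8_blockwise_act_quant: {act_quant_us:.3f} us")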

Lines changed: 180 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm_1x128_1x128,
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
    is_sm_at_least_89,
)

BLOCKWISE_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_quant_dequant(_, N, K, dtype):
    x = torch.randn(N, K).cuda()
    qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
    x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
    sqnr = compute_error(x, x_reconstructed)
    assert sqnr >= 25.0, f"SQNR {sqnr:.2f} must be >= 25.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x128(A_q, A_s, B_q, B_s)
    sqnr = compute_error(C, C_q)
    assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_fp8_gemm_1x128_1x128(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_act_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_1x128(A_q, A_s, B_q, B_s)
    sqnr = compute_error(C, C_q)
    assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
def test_triton_quantize_fp8_act_quant(tile_size: int):
    device = "cuda"
    M, K = 256, 256
    x = torch.randn(M, K, device=device)

    # Get the quantized tensor and scales using triton implementation
    # Use block_m=1 to match the narrow tiles (1 x tile_size) in the reference implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=1, block_k=tile_size
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant(x, tile_size=tile_size)

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    # Note: We use a relatively high tolerance because the implementations might have
    # slight differences in how they handle edge cases, rounding, etc.
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # Note: The scales might be stored differently (reciprocal vs. direct), so we need to
    # be careful about how we compare them

    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_act_quant, scales are stored directly
    # So we need to take the reciprocal of one of them for comparison

    # Reshape triton_scale to match ref_scale shape for comparison
    triton_scale_reshaped = triton_scale.reshape(M, -1)

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(
        1.0 / triton_scale_reshaped, ref_scale, rtol=1e-2, atol=1e-2
    ), (
        f"Scales differ: max diff = {(1.0 / triton_scale_reshaped - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
def test_triton_quantize_fp8_weight_quant(tile_size: int):
    device = "cuda"
    # Make sure dimensions are multiples of tile_size for clean comparison
    M = tile_size * 2
    K = tile_size * 2
    x = torch.randn(M, K, device=device)

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=tile_size, block_k=tile_size
    )

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=tile_size)

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_weight_quant, scales are stored directly

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(1.0 / triton_scale, ref_scale, rtol=1e-2, atol=1e-2), (
        f"Scales differ: max diff = {(1.0 / triton_scale - ref_scale).abs().max().item()}"
    )
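
The scale comparisons in the last two tests hinge on a convention difference called out in the comments: triton_quantize_fp8_block (fbgemm) returns reciprocal scales, while torch_blockwise_scale_act_quant / torch_blockwise_scale_weight_quant return scales directly. A small helper sketch for reconciling the two before comparing; the helper name is mine, not part of the commit:

import torch

def to_direct_scale(fbgemm_scale: torch.Tensor) -> torch.Tensor:
    # triton_quantize_fp8_block stores 1/scale per block; invert it to match
    # the direct-scale convention of the torch reference quantizers.
    return 1.0 / fbgemm_scale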
