|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import pytest |
| 8 | +import torch |
| 9 | + |
| 10 | +triton = pytest.importorskip("triton", reason="Triton required to run this test") |
| 11 | + |
| 12 | +from packaging import version |
| 13 | +from torchao.float8.float8_utils import compute_error |
| 14 | +from torchao.prototype.blockwise_fp8_training.kernels import ( |
| 15 | + blockwise_fp8_gemm_1x128_128x1, |
| 16 | + blockwise_fp8_gemm_1x128_128x128, |
| 17 | + fp8_blockwise_act_quant_lhs, |
| 18 | + fp8_blockwise_act_quant_rhs, |
| 19 | + fp8_blockwise_act_quant_transposed_lhs, |
| 20 | + fp8_blockwise_weight_quant_rhs, |
| 21 | + fp8_blockwise_weight_quant_transposed_rhs, |
| 22 | + torch_blockwise_scale_act_quant_lhs, |
| 23 | + torch_blockwise_scale_act_quant_rhs, |
| 24 | + torch_blockwise_scale_weight_quant, |
| 25 | +) |
| 26 | +from torchao.testing.utils import skip_if_rocm |
| 27 | + |
| 28 | +BLOCKWISE_SIZE_MNK = [ |
| 29 | + (128, 128, 128), |
| 30 | + (2, 512, 128), |
| 31 | + (2, 5120, 1280), |
| 32 | + (3, 2048, 2048), |
| 33 | + (4, 3584, 640), |
| 34 | + (13, 8704, 8576), |
| 35 | + (26, 18944, 1664), |
| 36 | + (67, 6656, 1408), |
| 37 | +] |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") |
| 41 | +@pytest.mark.skipif( |
| 42 | + version.parse(triton.__version__) < version.parse("3.3.0"), |
| 43 | + reason="Triton version < 3.3.0, test skipped", |
| 44 | +) |
| 45 | +@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) |
| 46 | +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) |
| 47 | +def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): |
| 48 | + # Simulate output = input @ weight.T |
| 49 | + A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") |
| 50 | + B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") |
| 51 | + C = A @ B.T |
| 52 | + A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype) |
| 53 | + B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype) |
| 54 | + C_q = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s) |
| 55 | + assert not C_q.isnan().any(), "C_q must not contain NaNs" |
| 56 | + |
| 57 | + sqnr = compute_error(C, C_q) |
| 58 | + min_sqnr = 28.0 |
| 59 | + print(f"blockwise_fp8_gemm_1x128_128x128 ({M},{N},{K}) SQNR: {sqnr}") |
| 60 | + assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" |
| 61 | + |
| 62 | + |
| 63 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") |
| 64 | +@pytest.mark.skipif( |
| 65 | + version.parse(triton.__version__) < version.parse("3.3.0"), |
| 66 | + reason="Triton version < 3.3.0, test skipped", |
| 67 | +) |
| 68 | +@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) |
| 69 | +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) |
| 70 | +def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): |
| 71 | + # Simulate grad_weight = grad_output_t @ input |
| 72 | + A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda") |
| 73 | + B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda") |
| 74 | + C = A.T @ B |
| 75 | + A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype) |
| 76 | + B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype) |
| 77 | + C_q = blockwise_fp8_gemm_1x128_128x1(A_t_q, 1.0 / A_t_s, B_q, 1.0 / B_s) |
| 78 | + |
| 79 | + assert not C_q.isnan().any(), "C_q must not contain NaNs" |
| 80 | + assert C.dtype == torch.bfloat16 |
| 81 | + assert C_q.dtype == torch.bfloat16 |
| 82 | + |
| 83 | + sqnr = compute_error(C, C_q) |
| 84 | + min_sqnr = 28.0 |
| 85 | + print(f"blockwise_fp8_gemm_1x128_128x1 ({M},{N},{K}) SQNR: {sqnr}") |
| 86 | + assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" |
| 87 | + |
| 88 | + |
| 89 | +@skip_if_rocm("ROCm not supported") |
| 90 | +@pytest.mark.parametrize("block_size", [128, 256]) |
| 91 | +def test_triton_quantize_fp8_act_quant_lhs(block_size): |
| 92 | + device = "cuda" |
| 93 | + M, K = 4096, 1024 |
| 94 | + x = torch.randn(M, K, device=device) |
| 95 | + |
| 96 | + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the |
| 97 | + # quantized tensor will have NaNs due to division by 0 |
| 98 | + x[0, :block_size] = 0.0 |
| 99 | + |
| 100 | + # Get the quantized tensor and scales using triton implementation |
| 101 | + triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs( |
| 102 | + x, |
| 103 | + block_size=block_size, |
| 104 | + ) |
| 105 | + |
| 106 | + # Get the quantized tensor and scales using reference implementation |
| 107 | + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(x, tile_size=block_size) |
| 108 | + |
| 109 | + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 110 | + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 111 | + |
| 112 | + # Convert both to float32 for comparison |
| 113 | + triton_fp32 = triton_fp8.to(torch.float32) |
| 114 | + ref_fp32 = ref_fp8.to(torch.float32) |
| 115 | + |
| 116 | + # Check that the quantized tensors are close |
| 117 | + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), ( |
| 118 | + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" |
| 119 | + ) |
| 120 | + |
| 121 | + # Compare scales |
| 122 | + assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), ( |
| 123 | + f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}" |
| 124 | + ) |
| 125 | + |
| 126 | + |
| 127 | +@skip_if_rocm("ROCm not supported") |
| 128 | +@pytest.mark.parametrize("block_size", [128, 256]) |
| 129 | +def test_triton_quantize_fp8_act_quant_rhs(block_size: int): |
| 130 | + device = "cuda" |
| 131 | + M, K = 4096, 1024 |
| 132 | + x = torch.randn(M, K, device=device) |
| 133 | + |
| 134 | + # Set one block to 0s, so if nan guards/EPS are not applied, the |
| 135 | + # quantized tensor will have NaNs due to division by 0 |
| 136 | + x[:block_size, :block_size] = 0.0 |
| 137 | + |
| 138 | + # Get the quantized tensor and scales using triton implementation |
| 139 | + triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs( |
| 140 | + x, |
| 141 | + block_size=block_size, |
| 142 | + ) |
| 143 | + |
| 144 | + # Get the quantized tensor and scales using reference implementation |
| 145 | + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size) |
| 146 | + |
| 147 | + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 148 | + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 149 | + |
| 150 | + # Convert both to float32 for comparison |
| 151 | + triton_fp32 = triton_fp8.to(torch.float32) |
| 152 | + ref_fp32 = ref_fp8.to(torch.float32) |
| 153 | + |
| 154 | + # Check that the quantized tensors are close |
| 155 | + assert torch.allclose(triton_fp32, ref_fp32, atol=1e-3), ( |
| 156 | + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" |
| 157 | + ) |
| 158 | + |
| 159 | + # Compare scales |
| 160 | + assert torch.allclose(triton_scale, ref_scale, atol=1e-3), ( |
| 161 | + f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}" |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +@skip_if_rocm("ROCm not supported") |
| 166 | +@pytest.mark.parametrize("block_size", [128, 256]) |
| 167 | +@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) |
| 168 | +def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): |
| 169 | + device = "cuda" |
| 170 | + x = torch.randn(M, K, device=device) |
| 171 | + |
| 172 | + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the |
| 173 | + # quantized tensor will have NaNs due to division by 0 |
| 174 | + x[0, :block_size] = 0.0 |
| 175 | + |
| 176 | + # Get the quantized tensor and scales using triton implementation |
| 177 | + triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs( |
| 178 | + x, |
| 179 | + block_size=block_size, |
| 180 | + ) |
| 181 | + |
| 182 | + # Get the quantized tensor and scales using reference implementation |
| 183 | + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs( |
| 184 | + x.t().contiguous(), tile_size=block_size |
| 185 | + ) |
| 186 | + |
| 187 | + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 188 | + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 189 | + |
| 190 | + # Convert both to float32 for comparison |
| 191 | + triton_fp32 = triton_fp8.to(torch.float32) |
| 192 | + ref_fp32 = ref_fp8.to(torch.float32) |
| 193 | + |
| 194 | + # Check that the quantized tensors are close |
| 195 | + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), ( |
| 196 | + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" |
| 197 | + ) |
| 198 | + |
| 199 | + # Compare scales |
| 200 | + assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), ( |
| 201 | + f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}" |
| 202 | + ) |
| 203 | + |
| 204 | + |
| 205 | +@skip_if_rocm("ROCm not supported") |
| 206 | +@pytest.mark.parametrize("block_size", [128, 256]) |
| 207 | +@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) |
| 208 | +def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): |
| 209 | + device = "cuda" |
| 210 | + x = torch.randn(M, K, device=device) |
| 211 | + |
| 212 | + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the |
| 213 | + # quantized tensor will have NaNs due to division by 0 |
| 214 | + x[:block_size, :block_size] = 0.0 |
| 215 | + |
| 216 | + # Get the quantized tensor and scales using triton implementation |
| 217 | + triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs( |
| 218 | + x, |
| 219 | + block_size=block_size, |
| 220 | + ) |
| 221 | + # Get the quantized tensor and scales using reference implementation |
| 222 | + ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=block_size) |
| 223 | + |
| 224 | + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 225 | + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 226 | + |
| 227 | + # Convert both to float32 for comparison |
| 228 | + triton_fp32 = triton_fp8.to(torch.float32) |
| 229 | + ref_fp32 = ref_fp8.to(torch.float32) |
| 230 | + |
| 231 | + # Check that the quantized tensors are close |
| 232 | + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), ( |
| 233 | + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" |
| 234 | + ) |
| 235 | + |
| 236 | + # Compare scales |
| 237 | + assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), ( |
| 238 | + f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}" |
| 239 | + ) |
| 240 | + |
| 241 | + |
| 242 | +@skip_if_rocm("ROCm not supported") |
| 243 | +@pytest.mark.parametrize("block_size", [128, 256]) |
| 244 | +def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): |
| 245 | + device = "cuda" |
| 246 | + M = 512 |
| 247 | + K = 2048 |
| 248 | + x = torch.randn(M, K, device=device) |
| 249 | + |
| 250 | + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the |
| 251 | + # quantized tensor will have NaNs due to division by 0 |
| 252 | + x[:block_size, :block_size] = 0.0 |
| 253 | + |
| 254 | + # Get the quantized tensor and scales using triton implementation |
| 255 | + triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs( |
| 256 | + x, |
| 257 | + block_size=block_size, |
| 258 | + ) |
| 259 | + # Get the quantized tensor and scales using reference implementation |
| 260 | + ref_fp8, ref_scale = torch_blockwise_scale_weight_quant( |
| 261 | + x.t().contiguous(), tile_size=block_size |
| 262 | + ) |
| 263 | + |
| 264 | + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 265 | + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" |
| 266 | + |
| 267 | + # Convert both to float32 for comparison |
| 268 | + triton_fp32 = triton_fp8.to(torch.float32) |
| 269 | + ref_fp32 = ref_fp8.to(torch.float32) |
| 270 | + |
| 271 | + # Check that the quantized tensors are close |
| 272 | + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), ( |
| 273 | + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" |
| 274 | + ) |
| 275 | + |
| 276 | + # Compare scales |
| 277 | + assert torch.allclose(triton_scale, ref_scale, rtol=1e-3, atol=1e-3), ( |
| 278 | + f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}" |
| 279 | + ) |
0 commit comments