From 1506c0d82fc6e70452e5c6fc4591176f5c16dc24 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 31 Jul 2025 17:26:38 -0700
Subject: [PATCH] Check numerical equivalence / closeness between different
 kernel preferences

Summary:
This PR checks that the different kernel preferences for Float8Tensor
(AUTO, TORCH and FBGEMM) produce similar numerics.

The Triton implementation and the torchao implementation currently differ
slightly; we need to decide whether to align them.

1. difference in the quantize op

The main difference is that the Triton implementation uses:
```
a_scale = MAX_FP8 / max_abs
then do
a_scale = 1.0 / a_scale
a_fp8 = a * a_scale
```
while torchao does:
```
a_scale = max_abs / MAX_FP8
a_fp8 = a / a_scale
```
The hp_value_lb and hp_value_ub settings are also slightly different.

Triton choose-scale-and-quantize code:
https://github.com/pytorch/FBGEMM/blob/a4286c01ef01dad435b2ec8798605127d3032cd8/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L2382-L2392

torchao choose-scale-and-quantize code:
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2183
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2283
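
For illustration, here is a minimal standalone sketch of the two scale
conventions (not the actual FBGEMM or torchao code; MAX_FP8, the clamp
bounds and the tensor shape are assumptions, and hp_value_lb / hp_value_ub
handling is omitted):
```
# Illustrative sketch only: comparing the two per-tensor float8 scale
# conventions described above.
import torch

MAX_FP8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

a = torch.randn(128, 256, dtype=torch.bfloat16)
max_abs = a.abs().max().float()

# Triton-style: compute the inverse scale, multiply, then invert it for dequant
inv_scale = MAX_FP8 / max_abs
a_fp8_triton = (a.float() * inv_scale).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)
scale_triton = 1.0 / inv_scale

# torchao-style: compute the scale directly and divide
scale_torch = max_abs / MAX_FP8
a_fp8_torch = (a.float() / scale_torch).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)

# The scales and quantized values agree only up to floating point rounding,
# so the two paths can differ in the last bit for some elements.
print(torch.equal(scale_triton, scale_torch))
print((a_fp8_triton.float() - a_fp8_torch.float()).abs().max())
```
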
2. (potential) difference in the matrix multiplication ops

TORCH and AUTO/FBGEMM use different quantized mm ops.

Added a reverse option to bring the SQNR values closer:
```
granularity: PerTensor() sizes: ((128,), 256, 128) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor() sizes: ((128,), 256, 128) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerTensor() sizes: ((32, 128), 64, 256) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor() sizes: ((32, 128), 64, 256) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow() sizes: ((128,), 256, 128) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((128,), 256, 128) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow() sizes: ((32, 128), 64, 256) kp: KernelPreference.AUTO tensor(64.5000, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((32, 128), 64, 256) kp: KernelPreference.FBGEMM tensor(68., device='cuda:0', dtype=torch.bfloat16)
```

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: https://github.com/pytorch/ao/pull/2651, branch: jerryzh168/stack/15
---
 .../workflows/float8/test_float8_tensor.py | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 5372bb280d..814efce03c 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -268,6 +268,61 @@ def test_slice(self, granularity):
         sqnr = compute_error(res, res_ref)
         self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
 
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
+        """Test that different kernel preferences have the same numerics for
+        float8 dynamic activation and float8 weight config
+        """
+        M, N, K = sizes
+        dtype = torch.bfloat16
+        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+        # Create a linear layer with bfloat16 dtype
+        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+
+        # reference kernel preference and results
+        # we are using KernelPreference.TORCH as the reference
+        kp_ref = KernelPreference.TORCH
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, kernel_preference=kp_ref
+        )
+        quantized_model = copy.deepcopy(model)
+        quantize_(quantized_model, config)
+        res_ref = quantized_model(input_tensor)
+
+        other_kernel_preferences = [
+            KernelPreference.AUTO,
+        ]
+        if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
+            other_kernel_preferences.append(KernelPreference.FBGEMM)
+
+        quantized_outputs = {}
+        for kp in other_kernel_preferences:
+            config = Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, kernel_preference=kp
+            )
+            quantized_model = copy.deepcopy(model)
+            quantize_(quantized_model, config)
+            quantized_outputs[kp] = quantized_model(input_tensor)
+
+        from torchao.quantization.utils import compute_error
+
+        # comparing numerics between different kernel preferences, using TORCH as the standard
+        kp_and_res = list(quantized_outputs.items())
+        for i in range(len(kp_and_res)):
+            kp, res = kp_and_res[i]
+            self.assertTrue(
+                compute_error(res, res_ref) > 28,
+                f"mismatch between {kp=} and {kp_ref}, {sizes=}, {granularity=}",
+            )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
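
For reference, the SQNR check used by the test (compute_error) measures how
close each result is to the TORCH baseline. The following standalone sketch
approximates the metric for illustration; it is not the exact torchao
implementation:
```
# Illustrative SQNR (signal-to-quantization-noise ratio) metric in dB,
# similar in spirit to torchao.quantization.utils.compute_error.
import torch

def sqnr_db(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
    # Higher is better; identical tensors give inf dB.
    ref = ref.float()
    res = res.float()
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - res))

# Example: a small perturbation yields a large but finite SQNR;
# the test above requires the value to exceed 28 dB.
x = torch.randn(32, 64)
print(sqnr_db(x, x + 1e-3 * torch.randn_like(x)))
```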