
Commit e818cb0

Check numerical equivalence / closeness between different kernel preferences
Summary:
This PR checks that the different kernel preferences for Float8Tensor (AUTO, TORCH and FBGEMM) produce similar numerics. The triton implementation and the torchao implementation are actually a bit different right now; we need to decide whether to fix that.

1. Difference in the quantize op

The main difference is that the triton implementation uses:
```
a_scale = MAX_FP8 / max_abs
then do a_scale = 1.0 / a_scale
a_fp8 = a * a_scale
```
while torch does:
```
a_scale = max_abs / MAX_FP8
a_fp8 = a / a_scale
```
The hp_value_lb and hp_value_ub settings are also slightly different.

triton choose scale and quantize code: https://github.com/pytorch/FBGEMM/blob/a4286c01ef01dad435b2ec8798605127d3032cd8/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L2382-L2392
torchao choose scale and quantize code:
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2183
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2283

2. (Potential) difference in the matrix multiplication ops

TORCH and AUTO/FBGEMM use different quantized mm ops.

Added a `reverse` option to bring the SQNR closer:
```
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(64.5000, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(68., device='cuda:0', dtype=torch.bfloat16)
```

Test Plan:
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2651, branch: jerryzh168/stack/15
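For a concrete picture of the two conventions, here is a minimal standalone sketch (not the actual FBGEMM triton or torchao kernels); `MAX_FP8` stands for `torch.finfo(torch.float8_e4m3fn).max`, and the `1e-12` floor mirrors the clamp used in this PR. Since `x / s` and `x * (1.0 / s)` round differently in float32, a few elements can land on adjacent fp8 codes even when the two scales agree.

```python
import torch

MAX_FP8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_torchao_style(a: torch.Tensor):
    # torchao convention: scale = max_abs / MAX_FP8, quantize by dividing
    max_abs = a.abs().amax().clamp(min=1e-12)
    scale = max_abs / MAX_FP8
    a_fp8 = (a / scale).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)
    return a_fp8, scale

def quantize_triton_style(a: torch.Tensor):
    # triton/FBGEMM convention: compute the inverse scale first, quantize by
    # multiplying, and keep the reciprocal as the dequant scale
    max_abs = a.abs().amax().clamp(min=1e-12)
    inv_scale = MAX_FP8 / max_abs
    a_fp8 = (a * inv_scale).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)
    return a_fp8, 1.0 / inv_scale

x = torch.randn(64, 64)
q_torchao, s_torchao = quantize_torchao_style(x)
q_triton, s_triton = quantize_triton_style(x)
# s_torchao and s_triton are nearly identical, but the different rounding
# paths can produce a handful of mismatched fp8 codes.
print((q_torchao.to(torch.float32) != q_triton.to(torch.float32)).sum())
```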
1 parent 24cd4fa commit e818cb0

File tree

4 files changed: +70, -6 lines


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 53 additions & 0 deletions
@@ -268,6 +268,59 @@ def test_slice(self, granularity):
         sqnr = compute_error(res, res_ref)
         self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")

+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
+        """Test different kernel preferences have the same numerics for float8 dynamic activation
+        and float8 weight config
+        """
+        M, N, K = sizes
+        dtype = torch.bfloat16
+        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+        # Create a linear layer with bfloat16 dtype
+        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+
+        kernel_preferences = [
+            KernelPreference.TORCH,
+            KernelPreference.AUTO,
+            KernelPreference.FBGEMM,
+        ]
+        quantized_outputs = {}
+        for kp in kernel_preferences:
+            config = Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, kernel_preference=kp
+            )
+            quantized_model = copy.deepcopy(model)
+            quantize_(quantized_model, config)
+            quantized_outputs[kp] = quantized_model(input_tensor)
+
+        from torchao.quantization.utils import compute_error
+
+        # comparing numerics between different kernel preferences, using TORCH as the standard
+        kp_and_res = list(quantized_outputs.items())
+        for i in range(1, len(kp_and_res)):
+            kp, res = kp_and_res[i]
+            print(
+                "granularity:",
+                granularity,
+                " sizes:",
+                sizes,
+                " kp:",
+                kp,
+                compute_error(res, kp_and_res[0][1]),
+            )
+            self.assertTrue(
+                compute_error(res, kp_and_res[0][1]) > 28,
+                f"mismatch between {kp=} and {kp_and_res[0]=}, {sizes=}, {granularity=}",
+            )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
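The test's threshold of 28 is an SQNR value in decibels. `compute_error` is imported from `torchao.quantization.utils`; assuming it is the usual SQNR formula, it behaves roughly like the sketch below, which also explains the `tensor(inf, ...)` values in the summary for bitwise-identical outputs.

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB (assumed definition of
    # compute_error): identical tensors give inf, larger is better.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - actual.float())
    return 20 * torch.log10(signal / noise)

ref = torch.randn(32, 64)
print(sqnr_db(ref, ref))                                  # inf
print(sqnr_db(ref, ref + 1e-3 * torch.randn_like(ref)))   # finite, well above 28
```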

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
@@ -1640,7 +1640,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
-    activation_value_lb: Optional[float] = None
+    activation_value_lb: Optional[float] = 1e-12
     activation_value_ub: Optional[float] = None
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
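`activation_value_lb` now defaults to `1e-12` instead of `None`; it feeds `hp_value_lb` downstream and clamps the activation amax before the scale is computed. A standalone illustration (not torchao code) of why a small positive floor matters for both scale conventions:

```python
import torch

MAX_FP8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# An all-zero activation block has max_abs == 0. Without a floor, the
# torchao-style scale is 0 (so x / scale becomes 0/0 = nan for zeros) and the
# triton-style inverse scale is inf. Clamping to 1e-12 keeps both finite.
max_abs = torch.tensor([0.0])
clamped = torch.clamp(max_abs, min=1e-12)

print(max_abs / MAX_FP8, MAX_FP8 / max_abs)   # tensor([0.]) tensor([inf])
print(clamped / MAX_FP8, MAX_FP8 / clamped)   # both finite
```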

torchao/quantization/quant_primitives.py

Lines changed: 12 additions & 3 deletions
@@ -2185,8 +2185,9 @@ def _choose_scale_float8(
     block_size: List[int],
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
-    hp_value_lb: Optional[float] = None,
+    hp_value_lb: Optional[float] = 1e-12,
     hp_value_ub: Optional[float] = None,
+    reverse: bool = False,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -2214,7 +2215,11 @@ def _choose_scale_float8(
     max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True)
     if hp_value_lb is not None or hp_value_ub is not None:
         max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
-    scale = max_abs / quant_max
+    if reverse:
+        scale = quant_max / max_abs.to(torch.float32)
+        scale[scale == float("inf")] = 1.0
+    else:
+        scale = max_abs / quant_max
     # Reshape scale back to match the expected output shape
     # The scale tensor should have the same shape as the input divided by block_size
     output_shape = [
@@ -2284,6 +2289,7 @@ def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    reverse: bool = False,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2293,7 +2299,10 @@ def _quantize_affine_float8(
     # Expand scale to match tensor dimensions for block-wise quantization
     scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)

-    tensor_scaled = tensor_fp32 / scale_expanded
+    if reverse:
+        tensor_scaled = tensor_fp32 * scale_expanded
+    else:
+        tensor_scaled = tensor_fp32 / scale_expanded
     max_value = torch.finfo(float8_dtype).max
     tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
     fp8_tensor = tensor_clamped.to(float8_dtype)
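To make the `reverse` branch concrete, here is a small standalone sketch (illustrative values, not the real call path) of the per-row scale when one row is all zeros and `hp_value_lb` is `None`; replacing the `inf` sentinel with `1.0` keeps the subsequent multiply well-defined:

```python
import torch

quant_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# Per-row amax with one all-zero row, no hp_value_lb clamp applied.
max_abs = torch.tensor([[0.5], [0.0]])

# reverse=True: inverse scale, with inf replaced by 1.0
scale = quant_max / max_abs.to(torch.float32)
scale[scale == float("inf")] = 1.0
print(scale)                 # tensor([[896.], [1.]])

# reverse=False for comparison
print(max_abs / quant_max)   # tensor([[1.1161e-03], [0.0000e+00]])
```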

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 4 additions & 2 deletions
@@ -169,7 +169,7 @@ def to_float8(
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     granularity: FP8Granularity = PerRow(),
     mm_config: Optional[Float8MMConfig] = None,
-    hp_value_lb: Optional[float] = None,
+    hp_value_lb: Optional[float] = 1e-12,
     hp_value_ub: Optional[float] = None,
     kernel_preference: KernelPreference = KernelPreference.AUTO,
     act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
@@ -209,8 +209,10 @@ def to_float8(
         block_size=block_size,
         hp_value_lb=hp_value_lb,
         hp_value_ub=hp_value_ub,
+        reverse=True,
     )
-    data = _quantize_affine_float8(hp_tensor, scale, float8_dtype)
+    data = _quantize_affine_float8(hp_tensor, scale, float8_dtype, reverse=True)
+    scale = 1.0 / scale

     hp_dtype = hp_tensor.dtype
     return Float8Tensor(
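Putting the pieces together, a minimal sketch of the bookkeeping in `to_float8` after this change: quantize by multiplying with the inverse scale (`reverse=True`), then invert the scale so the tensor still stores the usual `max_abs / quant_max` convention. The dequantization step shown (fp8 data times the stored scale) is an assumption for illustration, not the actual `Float8Tensor` implementation.

```python
import torch

quant_max = torch.finfo(torch.float8_e4m3fn).max
hp = torch.randn(8, 8)

# What _choose_scale_float8(..., reverse=True) conceptually returns (per-tensor case)
inv_scale = quant_max / hp.abs().amax().clamp(min=1e-12)

# _quantize_affine_float8(..., reverse=True): multiply instead of divide
data = (hp * inv_scale).clamp(-quant_max, quant_max).to(torch.float8_e4m3fn)

# scale = 1.0 / scale afterwards, so downstream code still sees max_abs / quant_max
stored_scale = 1.0 / inv_scale

# Assumed dequant convention, used only for a round-trip sanity check
dequant = data.to(torch.float32) * stored_scale
print((hp - dequant).abs().max())  # small fp8 rounding error
```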
