From a840ef535dbf91bdfca97742d469a541eb33c1e1 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 16 Jun 2025 13:57:11 +0000 Subject: [PATCH 01/26] quantize_affine_float8/dequantize_affine_float8 not decomposed on inductor --- test/float8/test_compile.py | 59 ++++++++++++++++++++++++ torchao/quantization/quant_primitives.py | 24 +++++++++- torchao/utils.py | 11 +++-- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ac5d1f8d96..fb4d6ea316 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,6 +37,10 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.quantization.quant_primitives import ( + dequantize_affine_float8, + quantize_affine_float8, +) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -392,5 +396,60 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) +@pytest.mark.parametrize( + "float8_dtype", + [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], +) +@pytest.mark.parametrize( + "hp_dtype", + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is 2.5 or lower" +) +def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): + input = torch.randn(10, 10) + with torch.no_grad(): + torch._dynamo.reset() + expected_scale = torch.tensor(2.0) + expected_quantized = quantize_affine_float8( + input, + expected_scale, + float8_dtype, + ) + expected_dequantized = dequantize_affine_float8( + expected_quantized, + expected_scale, + output_dtype=hp_dtype, + ) + test_q, (code_q,) = torch._inductor.utils.run_and_get_code( + torch.compile(quantize_affine_float8), + input, + expected_scale, + float8_dtype, + ) + torch.testing.FileCheck().check( + "torch.ops.torchao.quantize_affine_float8.default" + ).run(code_q) + test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code( + torch.compile(dequantize_affine_float8), + test_q, + expected_scale, + hp_dtype, + ) + torch.testing.FileCheck().check( + "torch.ops.torchao.dequantize_affine_float8.default" + ).run(code_dq) + torch.testing.assert_close(expected_quantized, test_q) + torch.testing.assert_close(expected_dequantized, test_dq) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index df136bc06e..72b3935157 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2270,10 +2270,11 @@ def _expand_scale_to_tensor_shape( return expanded_scale +@_register_custom_op(quant_lib, False) def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype = torch.float8_e4m3fn, + float8_dtype: torch.dtype, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. 
@@ -2290,10 +2291,20 @@ def _quantize_affine_float8( return fp8_tensor +@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta") +def _quantize_affine_float8_meta( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty_like(tensor, dtype=float8_dtype) + + +@_register_custom_op(quant_lib, False) def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype = torch.float32, + output_dtype: torch.dtype, ) -> torch.Tensor: """ Dequantizes the float8 tensor to high precision tensor. @@ -2305,3 +2316,12 @@ def _dequantize_affine_float8( hp_tensor = fp8_tensor * scale_expanded return hp_tensor.to(output_dtype) + + +@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta") +def _dequantize_affine_float8_meta( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty_like(tensor, dtype=output_dtype) diff --git a/torchao/utils.py b/torchao/utils.py index 416d23d785..99a0a729f5 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int: return n + k - (n % k) -def _register_custom_op(lib): +def _register_custom_op(lib, implicit=True): """This decorator is used to preserve some high level operators for torch.export.export while still allow them to be decomposed for inductor path @@ -206,6 +206,10 @@ def _the_op_that_needs_to_be_preserved(...) """ from torch._inductor.decomposition import register_decomposition + dispatch_key = ( + "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd" + ) + def decorator(fn): if TORCH_VERSION_AT_LEAST_2_5: from torch._library.infer_schema import infer_schema @@ -221,11 +225,12 @@ def decorator(fn): op_name = fn.__name__[1:] schema = op_name + infer_schema(fn, mutates_args={}) lib.define(schema) - lib.impl(op_name, fn, "CompositeImplicitAutograd") + lib.impl(op_name, fn, dispatch_key) lib_namespace = lib.ns op = getattr(getattr(torch.ops, lib_namespace), op_name) - register_decomposition([op])(fn) + if implicit: + register_decomposition([op])(fn) return op else: return fn From 02d045b267c20cca276770682d8769ec9d8c258b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 16 Jun 2025 14:31:21 +0000 Subject: [PATCH 02/26] remove redundant unittest.skipIf --- test/float8/test_compile.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index fb4d6ea316..43f0d8e2f2 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -411,9 +411,6 @@ def test_dynamic_scale_numeric_parity( torch.bfloat16, ], ) -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is 2.5 or lower" -) def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): input = torch.randn(10, 10) with torch.no_grad(): From 9860c56e87ef83986f9ca25b87c32e7a3023f186 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 15:44:02 +0000 Subject: [PATCH 03/26] fix rebase issue --- test/float8/test_compile.py | 10 ++++------ torchao/quantization/quant_primitives.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 43f0d8e2f2..64feaf7b5d 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,10 +37,6 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from 
torchao.quantization.quant_primitives import ( - dequantize_affine_float8, - quantize_affine_float8, -) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -412,6 +408,8 @@ def test_dynamic_scale_numeric_parity( ], ) def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): + quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8 + dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8 input = torch.randn(10, 10) with torch.no_grad(): torch._dynamo.reset() @@ -419,7 +417,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): expected_quantized = quantize_affine_float8( input, expected_scale, - float8_dtype, + float8_dtype=float8_dtype, ) expected_dequantized = dequantize_affine_float8( expected_quantized, @@ -430,7 +428,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): torch.compile(quantize_affine_float8), input, expected_scale, - float8_dtype, + float8_dtype=float8_dtype, ) torch.testing.FileCheck().check( "torch.ops.torchao.quantize_affine_float8.default" diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 72b3935157..56e8422197 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2274,7 +2274,7 @@ def _expand_scale_to_tensor_shape( def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. @@ -2295,7 +2295,7 @@ def _quantize_affine_float8( def _quantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: return torch.empty_like(tensor, dtype=float8_dtype) @@ -2304,7 +2304,7 @@ def _quantize_affine_float8_meta( def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Dequantizes the float8 tensor to high precision tensor. @@ -2322,6 +2322,6 @@ def _dequantize_affine_float8( def _dequantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: return torch.empty_like(tensor, dtype=output_dtype) From ca662f343e9164afa761e352e2a8e72400501c0a Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 15:47:32 +0000 Subject: [PATCH 04/26] change dispatch key to a flag decomposed --- torchao/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/utils.py b/torchao/utils.py index 99a0a729f5..4814c7ec63 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int: return n + k - (n % k) -def _register_custom_op(lib, implicit=True): +def _register_custom_op(lib, decomposed=True): """This decorator is used to preserve some high level operators for torch.export.export while still allow them to be decomposed for inductor path @@ -207,7 +207,7 @@ def _the_op_that_needs_to_be_preserved(...) 
from torch._inductor.decomposition import register_decomposition dispatch_key = ( - "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd" + "CompositeImplicitAutograd" if decomposed else "CompositeExplicitAutograd" ) def decorator(fn): @@ -229,7 +229,7 @@ def decorator(fn): lib_namespace = lib.ns op = getattr(getattr(torch.ops, lib_namespace), op_name) - if implicit: + if decomposed: register_decomposition([op])(fn) return op else: From f51a5bea091e7c6a7c288aa17d8a2484e3f68a65 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 14:59:33 +0000 Subject: [PATCH 05/26] support scaled_mm on inductor --- .../pt2e/test_x86inductor_fusion.py | 68 +++++ .../quantization/pt2e/inductor_passes/x86.py | 237 ++++++++++++++++++ 2 files changed, 305 insertions(+) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index ffaa4573d8..579a9583e5 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2427,6 +2427,74 @@ def matcher_check_fn(): self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) + @skipIfNoONEDNN + # @parametrize("has_bias", [True, False]) + # @parametrize("dtype", [torch.float32, torch.bfloat16]) + # @parametrize("input_dim_exceeds_two", [True, False]) + @parametrize("has_bias", [True, ]) + @parametrize("dtype", [torch.float32, ]) + @parametrize("input_dim_exceeds_two", [False]) + def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): + class FP8QDQLinear(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_features, in_features)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if has_bias: + self.bias = torch.randn((out_features,)).to(dtype) + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8( + tensor=self.weight.data, + scale=torch.tensor(self.weight_scale), + output_dtype=torch.float + ) + if dtype != torch.float: + weight = weight.to(dtype) + + q_input = torch.ops.torchao.quantize_affine_float8( + tensor=input, + scale=torch.tensor(self.scale), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8( + tensor=q_input, + scale=torch.tensor(self.scale), + output_dtype=torch.float + ) + if dtype != torch.float: + dq_input = dq_input.to(dtype) + + out = torch.nn.functional.linear(dq_input, weight, self.bias) + return out + + class Mod(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.l0 = FP8QDQLinear(in_features, out_features) + + def forward(self, x): + y = self.l0(x) + return y + + M1, M2, N, K = 2, 3, 13, 16 + M = M1 * M2 + mod = Mod(N, K) + if input_dim_exceeds_two: + v = torch.randn(M1, M2, N) + else: + v = torch.randn(M, N) + v = v.to(dtype) + + def matcher_check_fn(): + self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], 1) + + self._test_common(mod, (v,), matcher_check_fn) + + @dynamo_config.patch( { "dynamic_shapes": True, diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 4ccb2a1f31..a937e38a63 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2740,6 +2740,241 @@ def _register_qlinear_binary_fusion(): ) +def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): + # + - - - - | - - - 
- - - | - - - - + + # | dq_per_tensor dq_per_tensor | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(quant_per_tensor) | + # | | | + # | OPT(reshape) | + assert dtype in [torch.float32, torch.bfloat16] + dequant_wgt_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_dtype"), + ) + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + dequant_wgt_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + dequantize_per_tensor_activation_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_dq_dtype"), + ) + + dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern + + +def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): + def _inner(match): + # input_contiguous = True + # # Check dequant pattern has only 1 user. 
+ # ( + # linear_node, + # _, + # ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + # input_index = 1 if linear_node.target is aten.addmm.default else 0 + # assert dtype in [torch.float32, torch.bfloat16] + # ( + # dequant_node, + # _, + # _, + # _, + # ) = _get_linear_dq_node( + # linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + # ) + # assert dequant_node.target is quantized_decomposed.dequantize_per_tensor.tensor + + # # only support float8_e4m3 input + # if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: + # return False + + # if len(list(dequant_node.users)) != 1: + # # Ensure the dequant pattern only has 1 user + # # since we will delete the dequant pattern here + # return False + + return True + + return _inner + + +def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), + pass_number=0, + ) + def scaled_mm_fusion(match: Match, *args, **kwargs): + input_contiguous = True + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_tensor = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_tensor = weight_to_bf16_node.args[0] + assert ( + dequant_per_tensor.target + is torch.ops.torchao.dequantize_affine_float8.default + ) + + # Activation QParams + qx, x_scale = ( + kwargs["x"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale = ( + kwargs["q_weight"], + kwargs["w_scale"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + scaled_mm_input_node = qx + if input_dim_exceeds_two: + new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1]) + new_act_reshape_node = graph.call_function( + torch.ops.aten.reshape.default, args=new_reshape_args + ) + scaled_mm_input_node = new_act_reshape_node + # Insert weight prepack node and the qlinear node + permute_weight_inputs = ( + qw, + t_node.args[1], + ) + permute_weight_op = torch.ops.aten.permute.default + permute_weight_node = graph.call_function( + permute_weight_op, args=permute_weight_inputs + ) + output_scale = torch.tensor(1.0) + new_args: tuple[Any, ...] 
= ( + scaled_mm_input_node, + permute_weight_node, + x_scale, + w_scale, + bias, + output_scale, # output_scale + dtype, # output_dtype + False, # use_fast_accum + ) + new_linear_node = graph.call_function( + torch.ops.aten._scaled_mm.default, args=new_args + ) + + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + graph.erase_node(linear_node) + if input_dim_exceeds_two: + graph.erase_node(act_reshape_node) + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_tensor) + + counters["inductor"]["scaled_mm_matcher_count"] += 1 + counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes) + + +def _register_scaled_mm(): + fp8_linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [False, True] + ) + for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases: + patterns = _generate_dequant_fp8_linear_node_pattern( + dtype, input_dim_exceeds_two + ) + for pattern in patterns: + _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two) + + @functools.lru_cache(None) def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 @@ -2763,6 +2998,8 @@ def _register_quantization_weight_pack_pass(): _register_qlinear_unary_fusion() _register_qlinear_binary_fusion() + _register_scaled_mm() + def quant_lift_up(module_graph: torch.fx.graph.Graph): """ From 719793c64420edbfb5c0a5eb7d3a2a0221a7c1f2 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 16:48:47 +0000 Subject: [PATCH 06/26] fix rebase issue --- .../pt2e/test_x86inductor_fusion.py | 14 ++--- .../quantization/pt2e/inductor_passes/x86.py | 58 +++++++++---------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 579a9583e5..43868bfed0 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2426,14 +2426,10 @@ def matcher_check_fn(): if test_for_pointwise_binary: self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) - @skipIfNoONEDNN - # @parametrize("has_bias", [True, False]) - # @parametrize("dtype", [torch.float32, torch.bfloat16]) - # @parametrize("input_dim_exceeds_two", [True, False]) - @parametrize("has_bias", [True, ]) - @parametrize("dtype", [torch.float32, ]) - @parametrize("input_dim_exceeds_two", [False]) + @parametrize("has_bias", [True, False]) + @parametrize("dtype", [torch.float32, torch.bfloat16]) + @parametrize("input_dim_exceeds_two", [True, False]) def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): class FP8QDQLinear(torch.nn.Module): def __init__(self, in_features, out_features): @@ -2450,7 +2446,7 @@ def forward(self, input): weight = torch.ops.torchao.dequantize_affine_float8( tensor=self.weight.data, scale=torch.tensor(self.weight_scale), - output_dtype=torch.float + output_dtype=torch.float, ) if dtype != torch.float: weight = weight.to(dtype) @@ -2463,7 +2459,7 @@ def forward(self, input): dq_input = torch.ops.torchao.dequantize_affine_float8( tensor=q_input, scale=torch.tensor(self.scale), - output_dtype=torch.float + output_dtype=torch.float, ) if dtype != 
torch.float: dq_input = dq_input.to(dtype) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index a937e38a63..1dfb2ce3d2 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2758,7 +2758,7 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): torch.ops.torchao.dequantize_affine_float8.default, KeywordArg("q_weight"), KeywordArg("w_scale"), - KeywordArg("w_dtype"), + output_dtype=KeywordArg("w_dtype"), ) t_pattern = CallFunction( aten.permute.default, @@ -2773,7 +2773,7 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): torch.ops.torchao.dequantize_affine_float8.default, KeywordArg("x"), KeywordArg("x_scale"), - KeywordArg("x_dq_dtype"), + output_dtype=KeywordArg("x_dq_dtype"), ) dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( @@ -2816,33 +2816,33 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): def _inner(match): - # input_contiguous = True - # # Check dequant pattern has only 1 user. - # ( - # linear_node, - # _, - # ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - - # input_index = 1 if linear_node.target is aten.addmm.default else 0 - # assert dtype in [torch.float32, torch.bfloat16] - # ( - # dequant_node, - # _, - # _, - # _, - # ) = _get_linear_dq_node( - # linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - # ) - # assert dequant_node.target is quantized_decomposed.dequantize_per_tensor.tensor - - # # only support float8_e4m3 input - # if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: - # return False - - # if len(list(dequant_node.users)) != 1: - # # Ensure the dequant pattern only has 1 user - # # since we will delete the dequant pattern here - # return False + input_contiguous = True + # Check dequant pattern has only 1 user. 
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default + + # only support float8_e4m3 input + if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: + return False + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False return True From 48a3d999eba50277b735a00848f62a008f9f2376 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 10:06:21 +0000 Subject: [PATCH 07/26] support dequant promtion for fp8 --- .../quantization/pt2e/inductor_passes/x86.py | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 1dfb2ce3d2..e5cfd95f71 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -116,18 +116,26 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): return unary_fusion(computation_call) -def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): - dequantize_per_tensor_activation_pattern = CallFunction( - quantized_decomposed.dequantize_per_tensor.tensor - if is_tensor_overload - else quantized_decomposed.dequantize_per_tensor.default, - KeywordArg("x"), - KeywordArg("x_scale"), - KeywordArg("x_zp"), - KeywordArg("x_quant_min"), - KeywordArg("x_quant_max"), - KeywordArg("x_dq_dtype"), - ) +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False, is_fp8=False): + if is_fp8: + dequantize_per_tensor_activation_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("x"), + KeywordArg("x_scale"), + output_dtype=KeywordArg("x_dq_dtype"), + ) + else: + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) return dequantize_per_tensor_activation_pattern @@ -491,6 +499,7 @@ def _inner(match): if dequant_pattern_end_node.target not in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, prims.convert_element_type.default, aten.reshape.default, ]: @@ -520,6 +529,7 @@ def _inner(match): in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ] and len(list(dequant_pattern_end_node.users)) > 1 ): @@ -530,7 +540,7 @@ def _inner(match): return _inner -def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_promotion_pattern(dtype), @@ -586,6 +596,7 @@ def clone_to_new_node(graph, source_node, user_node): assert 
dequant_pattern_end_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, prims.convert_element_type.default, aten.reshape.default, ] @@ -598,6 +609,7 @@ def _find_first_node_in_dequant_pattern(_node): if _node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ]: # For a dequant pattern, we expect the start node is a dequantize_per_tensor node return _node @@ -614,6 +626,7 @@ def _find_first_node_in_dequant_pattern(_node): assert dequant_pattern_start_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ] # Clone the dequant pattern for each user node @@ -1332,9 +1345,9 @@ def _generate_linear_dynamic_fp16_pattern( def _register_dequant_promotion(): dequant_pattern_cases = itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) - for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in dequant_pattern_cases: # 4 dequantization patterns will be matched based on the dtype and input dimension size. # Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 @@ -1355,11 +1368,15 @@ def _register_dequant_promotion(): # OPT(to_fp32) OPT(to_fp32) # + - - | - - - | - - + # quant quant + if is_fp8 and not is_tensor_overload: + continue + _register_dequant_promotion_pass( _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern( - is_tensor_overload=is_tensor_overload + is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, @@ -1369,6 +1386,7 @@ def _register_dequant_promotion(): ), pass_number=0, dtype=dtype, + is_fp8=is_fp8, ) # pass_number=0 to run before weight prepack @@ -2853,7 +2871,7 @@ def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), - pass_number=0, + pass_number=1, ) def scaled_mm_fusion(match: Match, *args, **kwargs): input_contiguous = True From 1921b2f8d7a3340caf6345a2a89d56968abc7f73 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 14:19:34 +0000 Subject: [PATCH 08/26] add ut --- .../quantization/pt2e/test_x86inductor_fusion.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 43868bfed0..fa981dc4d6 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2430,7 +2430,8 @@ def matcher_check_fn(): @parametrize("has_bias", [True, False]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("input_dim_exceeds_two", [True, False]) - def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): + @parametrize("check_reuse_input", [True, False]) + def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): class FP8QDQLinear(torch.nn.Module): def __init__(self, in_features, out_features): 
super().__init__() @@ -2468,17 +2469,23 @@ def forward(self, input): return out class Mod(torch.nn.Module): - def __init__(self, in_features, out_features): + def __init__(self, in_features, out_features, check_reuse_input): super().__init__() self.l0 = FP8QDQLinear(in_features, out_features) + self.check_reuse_input = check_reuse_input + if self.check_reuse_input: + self.l1 = FP8QDQLinear(in_features, out_features) def forward(self, x): y = self.l0(x) + if self.check_reuse_input: + z = self.l1(x) + y += z return y M1, M2, N, K = 2, 3, 13, 16 M = M1 * M2 - mod = Mod(N, K) + mod = Mod(N, K, check_reuse_input) if input_dim_exceeds_two: v = torch.randn(M1, M2, N) else: @@ -2486,7 +2493,8 @@ def forward(self, x): v = v.to(dtype) def matcher_check_fn(): - self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], 1) + counter = 2 if check_reuse_input else 1 + self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], counter) self._test_common(mod, (v,), matcher_check_fn) From 0335415f8449bfa59613a1832184fa74999cf153 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 14:35:38 +0000 Subject: [PATCH 09/26] remove redundant codes --- torchao/quantization/pt2e/inductor_passes/x86.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index e5cfd95f71..0af7d21a99 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -540,7 +540,7 @@ def _inner(match): return _inner -def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_promotion_pattern(dtype), @@ -1386,7 +1386,6 @@ def _register_dequant_promotion(): ), pass_number=0, dtype=dtype, - is_fp8=is_fp8, ) # pass_number=0 to run before weight prepack From a5bb4d0cece16b2beaf36323e7b5440be9d9a198 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 16:24:35 +0000 Subject: [PATCH 10/26] fix lint --- torchao/quantization/pt2e/inductor_passes/x86.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 0af7d21a99..dd9f0e6c21 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -116,7 +116,9 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): return unary_fusion(computation_call) -def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False, is_fp8=False): +def get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=False, is_fp8=False +): if is_fp8: dequantize_per_tensor_activation_pattern = CallFunction( torch.ops.torchao.dequantize_affine_float8.default, @@ -1347,7 +1349,12 @@ def _register_dequant_promotion(): dequant_pattern_cases = itertools.product( [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) - for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in dequant_pattern_cases: + for ( + dtype, + input_dim_exceeds_two, + is_tensor_overload, + is_fp8, + ) in dequant_pattern_cases: # 4 dequantization patterns will be matched based on the dtype and input dimension size. 
# Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 From 0c7f8eacaf98d181d3cb60c775af9d503dae736b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 16:30:37 +0000 Subject: [PATCH 11/26] resolve conflict --- test/float8/test_compile.py | 54 ------------------------------------- 1 file changed, 54 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c9280cd70c..aaf9d3d3f5 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -392,59 +392,5 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) -@pytest.mark.parametrize( - "float8_dtype", - [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], -) -@pytest.mark.parametrize( - "hp_dtype", - [ - torch.float32, - torch.float16, - torch.bfloat16, - ], -) -def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): - quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8 - dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8 - input = torch.randn(10, 10) - with torch.no_grad(): - torch._dynamo.reset() - expected_scale = torch.tensor(2.0) - expected_quantized = quantize_affine_float8( - input, - expected_scale, - float8_dtype=float8_dtype, - ) - expected_dequantized = dequantize_affine_float8( - expected_quantized, - expected_scale, - output_dtype=hp_dtype, - ) - test_q, (code_q,) = torch._inductor.utils.run_and_get_code( - torch.compile(quantize_affine_float8), - input, - expected_scale, - float8_dtype=float8_dtype, - ) - torch.testing.FileCheck().check( - "torch.ops.torchao.quantize_affine_float8.default" - ).run(code_q) - test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code( - torch.compile(dequantize_affine_float8), - test_q, - expected_scale, - hp_dtype, - ) - torch.testing.FileCheck().check( - "torch.ops.torchao.dequantize_affine_float8.default" - ).run(code_dq) - torch.testing.assert_close(expected_quantized, test_q) - torch.testing.assert_close(expected_dequantized, test_dq) - - if __name__ == "__main__": pytest.main([__file__]) From 0175b1795897e66fa57ccf0228d8a0dca19ee7b0 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 1 Jul 2025 13:54:18 +0000 Subject: [PATCH 12/26] change to use qlinear --- .../pt2e/test_x86inductor_fusion.py | 12 +- .../quantization/pt2e/inductor_passes/x86.py | 298 +++--------------- 2 files changed, 49 insertions(+), 261 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index fa981dc4d6..0522e9e426 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2431,7 +2431,7 @@ def matcher_check_fn(): @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("input_dim_exceeds_two", [True, False]) @parametrize("check_reuse_input", [True, False]) - def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): + def test_fp8_qlinear(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): class FP8QDQLinear(torch.nn.Module): def __init__(self, in_features, out_features): super().__init__() @@ -2446,7 +2446,7 @@ def __init__(self, in_features, out_features): def forward(self, input): weight = torch.ops.torchao.dequantize_affine_float8( tensor=self.weight.data, - scale=torch.tensor(self.weight_scale), + scale=torch.tensor([self.weight_scale]), output_dtype=torch.float, ) if dtype != torch.float: @@ -2454,12 +2454,12 @@ def forward(self, 
input): q_input = torch.ops.torchao.quantize_affine_float8( tensor=input, - scale=torch.tensor(self.scale), + scale=torch.tensor([self.scale]), float8_dtype=self.qtype, ) dq_input = torch.ops.torchao.dequantize_affine_float8( tensor=q_input, - scale=torch.tensor(self.scale), + scale=torch.tensor([self.scale]), output_dtype=torch.float, ) if dtype != torch.float: @@ -2480,7 +2480,7 @@ def forward(self, x): y = self.l0(x) if self.check_reuse_input: z = self.l1(x) - y += z + y = torch.cat([y, z]) return y M1, M2, N, K = 2, 3, 13, 16 @@ -2494,7 +2494,7 @@ def forward(self, x): def matcher_check_fn(): counter = 2 if check_reuse_input else 1 - self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], counter) + self.assertEqual(counters["inductor"]["qlinear_weight_prepack_matcher_count"], counter) self._test_common(mod, (v,), matcher_check_fn) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index dd9f0e6c21..4576eb99cf 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -951,6 +951,7 @@ def _inner(match): assert dequant_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ] if len(list(dequant_node.users)) != 1: @@ -1007,6 +1008,7 @@ def _register_qlinear_weight_prepack_pass( dtype=torch.float32, input_dim_exceeds_two=False, input_contiguous=True, + is_fp8=False, ): @register_freezing_graph_pattern( pattern, @@ -1022,7 +1024,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - mm/addmm <- t <- dequant_per_channel <- int8_weight + mm/addmm <- t <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -1054,28 +1056,30 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): t_node = linear_node.args[weight_index] if dtype == torch.float32: - dequant_per_channel = t_node.args[0] + dequant = t_node.args[0] else: weight_to_bf16_node = t_node.args[0] - dequant_per_channel = weight_to_bf16_node.args[0] + dequant = weight_to_bf16_node.args[0] assert ( - dequant_per_channel.target - is quantized_decomposed.dequantize_per_channel.default + dequant.target in [ + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8.default, + ] ) # Activation QParams - qx, x_zp, x_scale = ( + qx, x_scale = ( kwargs["x"], - kwargs["x_zp"], kwargs["x_scale"], ) # Weight QParams - qw, w_scale, w_zp = ( + qw, w_scale = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"], ) + x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None + w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None # Params bias = kwargs["b"] if "b" in kwargs else None @@ -1112,7 +1116,8 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): "", # post op algorithm ) Node = torch.fx.node.Node - if isinstance(x_scale, Node) and isinstance(x_zp, Node): + # fp8 not need zp + if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): new_linear_node = graph.call_function( torch.ops.onednn.qlinear_pointwise.tensor, args=new_args ) @@ -1158,7 +1163,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(t_node) if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] - graph.erase_node(dequant_per_channel) + graph.erase_node(dequant) counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 
1 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( @@ -1171,6 +1176,7 @@ def _generate_dequant_linear_node_pattern( dtype=torch.float32, input_dim_exceeds_two=False, is_tensor_overload=False, + is_fp8=False, ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1180,7 +1186,7 @@ def _generate_dequant_linear_node_pattern( KeywordArg("b"), _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1197,7 +1203,7 @@ def _generate_dequant_linear_node_pattern( aten.mm.default, _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1217,6 +1223,7 @@ def _generate_dequant_bmm_node_pattern( dtype=torch.float32, with_bias=False, is_tensor_overload=False, + is_fp8=False, ): # When activation of linear dim exceed 2 and not contiguous t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1227,7 +1234,7 @@ def _generate_dequant_bmm_node_pattern( CallFunction( aten.expand.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1259,20 +1266,32 @@ def _generate_qlinear_weight_prepack_patterns( input_contiguous=True, with_bias=False, is_tensor_overload=False, + is_fp8=False, ): + if is_fp8: + dequant_wgt_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), + ) + else: + dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: return _generate_dequant_bmm_node_pattern( - dequantize_per_channel_weight_pattern, + dequant_wgt_pattern, dtype, with_bias, is_tensor_overload, + is_fp8, ) else: return _generate_dequant_linear_node_pattern( - dequantize_per_channel_weight_pattern, + dequant_wgt_pattern, dtype, input_dim_exceeds_two, is_tensor_overload, + is_fp8, ) @@ -1442,15 +1461,18 @@ def _register_qlinear_weight_prepack(): # | OPT(add) | linear_weight_prepack_cases = itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) # Step 1: register patterns from mm and addmm - for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in linear_weight_prepack_cases: + if is_fp8 and not is_tensor_overload: + continue weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 
@@ -1459,6 +1481,7 @@ def _register_qlinear_weight_prepack(): pass_number=1, dtype=dtype, input_dim_exceeds_two=input_dim_exceeds_two, + is_fp8=is_fp8, ) # Step 2: register patterns from bmm @@ -1476,6 +1499,7 @@ def _register_qlinear_weight_prepack(): input_contiguous=False, with_bias=with_bias, is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ) _register_qlinear_weight_prepack_pass( bmm_pattern, @@ -1485,6 +1509,7 @@ def _register_qlinear_weight_prepack(): dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, + is_fp8=is_fp8, ) @@ -2764,241 +2789,6 @@ def _register_qlinear_binary_fusion(): ) -def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): - # + - - - - | - - - - - - | - - - - + - # | dq_per_tensor dq_per_tensor | - # | | | | - # | OPT(to_bf16) OPT(to_bf16) | - # | | | | - # | OPT(reshape) permute | - # | \ / | - # | addmm/mm | - # | | | - # | OPT(quant_per_tensor) | - # | | | - # | OPT(reshape) | - assert dtype in [torch.float32, torch.bfloat16] - dequant_wgt_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), - ) - t_pattern = CallFunction( - aten.permute.default, - _may_generate_pattern_with_dtype_convert( - dequant_wgt_pattern, - KeywordArg("autocast_wgt_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("permute_axes"), - ) - dequantize_per_tensor_activation_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8.default, - KeywordArg("x"), - KeywordArg("x_scale"), - output_dtype=KeywordArg("x_dq_dtype"), - ) - - dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( - CallFunction( - aten.addmm.default, - KeywordArg("b"), - _may_generate_pattern_with_reshape( - _may_generate_pattern_with_dtype_convert( - dequantize_per_tensor_activation_pattern, - KeywordArg("autocast_act_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("act_reshape_size"), - input_dim_exceeds_two, - ), - t_pattern, - ), - KeywordArg("output_reshape_size"), - input_dim_exceeds_two, - ) - dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape( - CallFunction( - aten.mm.default, - _may_generate_pattern_with_reshape( - _may_generate_pattern_with_dtype_convert( - dequantize_per_tensor_activation_pattern, - KeywordArg("autocast_act_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("act_reshape_size"), - input_dim_exceeds_two, - ), - t_pattern, - ), - KeywordArg("output_reshape_size"), - input_dim_exceeds_two, - ) - return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern - - -def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): - def _inner(match): - input_contiguous = True - # Check dequant pattern has only 1 user. 
- ( - linear_node, - _, - ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - - input_index = 1 if linear_node.target is aten.addmm.default else 0 - assert dtype in [torch.float32, torch.bfloat16] - ( - dequant_node, - _, - _, - _, - ) = _get_linear_dq_node( - linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - ) - assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default - - # only support float8_e4m3 input - if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: - return False - - if len(list(dequant_node.users)) != 1: - # Ensure the dequant pattern only has 1 user - # since we will delete the dequant pattern here - return False - - return True - - return _inner - - -def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): - @register_freezing_graph_pattern( - pattern, - extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), - pass_number=1, - ) - def scaled_mm_fusion(match: Match, *args, **kwargs): - input_contiguous = True - assert dtype in [torch.float32, torch.bfloat16] - ( - linear_node, - output_reshape_node, - ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - input_index = 1 if linear_node.target is aten.addmm.default else 0 - weight_index = input_index + 1 - - ( - dequant_node, - act_reshape_node, - activation_to_bf16_node, - act_expand_node, - ) = _get_linear_dq_node( - linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - ) - - if input_dim_exceeds_two and not input_contiguous: - wgt_expand_node = linear_node.args[weight_index] - assert wgt_expand_node.target is aten.expand.default - t_node = wgt_expand_node.args[0] - else: - t_node = linear_node.args[weight_index] - - if dtype == torch.float32: - dequant_per_tensor = t_node.args[0] - else: - weight_to_bf16_node = t_node.args[0] - dequant_per_tensor = weight_to_bf16_node.args[0] - assert ( - dequant_per_tensor.target - is torch.ops.torchao.dequantize_affine_float8.default - ) - - # Activation QParams - qx, x_scale = ( - kwargs["x"], - kwargs["x_scale"], - ) - - # Weight QParams - qw, w_scale = ( - kwargs["q_weight"], - kwargs["w_scale"], - ) - - # Params - bias = kwargs["b"] if "b" in kwargs else None - - x_shape = qx.meta.get("tensor_meta").shape - if has_free_symbols(x_shape): - # For dynamic shape case, we can't get activation shape ahead of runtime. - x_shape = None - graph = match.graph - with graph.inserting_before(linear_node): - scaled_mm_input_node = qx - if input_dim_exceeds_two: - new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1]) - new_act_reshape_node = graph.call_function( - torch.ops.aten.reshape.default, args=new_reshape_args - ) - scaled_mm_input_node = new_act_reshape_node - # Insert weight prepack node and the qlinear node - permute_weight_inputs = ( - qw, - t_node.args[1], - ) - permute_weight_op = torch.ops.aten.permute.default - permute_weight_node = graph.call_function( - permute_weight_op, args=permute_weight_inputs - ) - output_scale = torch.tensor(1.0) - new_args: tuple[Any, ...] 
= ( - scaled_mm_input_node, - permute_weight_node, - x_scale, - w_scale, - bias, - output_scale, # output_scale - dtype, # output_dtype - False, # use_fast_accum - ) - new_linear_node = graph.call_function( - torch.ops.aten._scaled_mm.default, args=new_args - ) - - linear_node.replace_all_uses_with(new_linear_node) - new_linear_node.meta.update(linear_node.meta) - - graph.erase_node(linear_node) - if input_dim_exceeds_two: - graph.erase_node(act_reshape_node) - if dtype == torch.bfloat16: - graph.erase_node(activation_to_bf16_node) - # Erase the dequant pattern - graph.erase_node(dequant_node) - # Erase the dequant per channel pattern - graph.erase_node(t_node) - if dtype == torch.bfloat16: - graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] - graph.erase_node(dequant_per_tensor) - - counters["inductor"]["scaled_mm_matcher_count"] += 1 - counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes) - - -def _register_scaled_mm(): - fp8_linear_weight_prepack_cases = itertools.product( - [torch.float32, torch.bfloat16], [False, True] - ) - for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases: - patterns = _generate_dequant_fp8_linear_node_pattern( - dtype, input_dim_exceeds_two - ) - for pattern in patterns: - _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two) - - @functools.lru_cache(None) def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 @@ -3022,8 +2812,6 @@ def _register_quantization_weight_pack_pass(): _register_qlinear_unary_fusion() _register_qlinear_binary_fusion() - _register_scaled_mm() - def quant_lift_up(module_graph: torch.fx.graph.Graph): """ From 564d4b75372d589d910269c1cb8a7eae7273b9c7 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 7 Jul 2025 09:30:27 +0000 Subject: [PATCH 13/26] add ut --- .../pt2e/test_x86inductor_fusion.py | 507 +++++++++++++----- .../quantization/pt2e/inductor_passes/x86.py | 240 +++++---- 2 files changed, 498 insertions(+), 249 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 0522e9e426..60ab0f7747 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -109,24 +109,104 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer +class FP8QDQLinear(torch.nn.Module): + def __init__(self, in_features, out_features, has_bias): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_features, in_features)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if has_bias: + self.bias = torch.randn((out_features,)) + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8( + tensor=self.weight.data, + scale=torch.tensor([self.weight_scale]), + output_dtype=torch.float, + ) + + q_input = torch.ops.torchao.quantize_affine_float8( + tensor=input, + scale=torch.tensor([self.scale]), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8( + tensor=q_input, + scale=torch.tensor([self.scale]), + output_dtype=torch.float, + ) + + out = torch.nn.functional.linear(dq_input, weight, self.bias) + return out + +def fp8_convert_(model): + + def generate_model_info(model): + from collections import namedtuple + mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"]) + parent_child_mod_dict = {} + + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + 
parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) + return parent_child_mod_dict + + parent_child_mod_dict = generate_model_info(model) + for name, mod in model.named_modules(): + mod_type_str = mod.__class__.__name__ + if mod_type_str not in [ + "Linear", + ]: + continue + param = mod.weight + xmax = torch.max(param) + weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max + mod.weight_scale = weight_scale + q_param = torch.clamp( + (param / weight_scale), + torch.finfo(torch.float8_e4m3fn).min, + torch.finfo(torch.float8_e4m3fn).max, + ).to(torch.float8_e4m3fn) + mod.weight.data = q_param + if mod_type_str in ["Linear"]: + patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False) + patched_mod.bias = mod.bias + patched_mod.weight_scale = weight_scale.item() + patched_mod.weight.data = q_param + + parent = parent_child_mod_dict[mod].parent + name = parent_child_mod_dict[mod].name + setattr(parent, name, patched_mod) + + def _generate_qdq_quantized_model( - mod, inputs, is_qat=False, is_dynamic=False, quantizer=None + mod, inputs, is_qat=False, is_dynamic=False, quantizer=None, is_fp8=False, ): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training(mod, inputs, strict=True).module() - quantizer = ( - quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) - ) - prepare_model = ( - prepare_qat_pt2e(export_model, quantizer) - if is_qat - else prepare_pt2e(export_model, quantizer) - ) - prepare_model(*inputs) - torchao.quantization.pt2e.move_exported_model_to_eval(prepare_model) - convert_model = convert_pt2e(prepare_model) - return convert_model + if is_fp8: + assert(not is_qat) + fp8_convert_(mod) + return mod + else: + export_model = export_for_training(mod, inputs, strict=True).module() + quantizer = ( + quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) + ) + prepare_model = ( + prepare_qat_pt2e(export_model, quantizer) + if is_qat + else prepare_pt2e(export_model, quantizer) + ) + prepare_model(*inputs) + torchao.quantization.pt2e.move_exported_model_to_eval(prepare_model) + convert_model = convert_pt2e(prepare_model) + return convert_model def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): @@ -195,6 +275,7 @@ def _test_common( is_dynamic=False, quantizer=None, compile_options={}, # noqa: B006 + is_fp8=False, ): if not hasattr(self, "device"): has_xpu = any( @@ -225,7 +306,7 @@ def _test_common( maybe_autocast = contextlib.nullcontext() if check_quantization: convert_model = _generate_qdq_quantized_model( - mod, inputs, is_qat, is_dynamic, quantizer + mod, inputs, is_qat, is_dynamic, quantizer, is_fp8 ) with torch.no_grad(), maybe_autocast: _ = torch.compile(convert_model)(*inputs) @@ -250,11 +331,12 @@ def _test_code_common( check_dynamic=None, num_include_ops=None, quantizer=None, + is_fp8=False, ): with torch.no_grad(): clone_inputs = self._clone_inputs(inputs) if check_quantization: - mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer) + mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer, is_fp8=is_fp8) expected = mod(*inputs) actual, (source_code,) = run_and_get_code( torch.compile(mod, fullgraph=True, dynamic=check_dynamic), @@ -1342,12 +1424,13 @@ def _qlinear_test_helper( self, inputs, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, do_permute=False, matcher_check_fn=None, bias=True, 
is_dynamic=False, is_qat=False, + is_fp8=False ): class M(torch.nn.Module): def __init__(self, use_bias, do_permute=False): @@ -1382,10 +1465,11 @@ def _default_matcher_check_fn(): if matcher_check_fn is not None else _default_matcher_check_fn ), - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, is_qat=is_qat, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1394,8 +1478,9 @@ def test_qlinear_cpu(self): r""" This testcase will quantize a single Linear Moduel. """ - for bias in [True, False]: - self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) + for is_fp8 in [True, False]: + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1433,14 +1518,15 @@ def test_dynamic_qlinear_input_dim_exceeds_2(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16(self): + def test_qlinear_mixed_bf16(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. """ - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias - ) + for is_fp8 in [True, False]: + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, bias=bias, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1448,20 +1534,22 @@ def test_qlinear_input_dim_exceeds_2(self): r""" This testcase will quantize a single Linear Moduel. """ - for bias in [True, False]: - self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) + for is_fp8 in [True, False]: + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias, is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): + def test_qlinear_mixed_bf16_input_dim_exceeds_2(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. 
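# Note on the mixed_bf16 variants: _test_common drives the compiled model under
# CPU bf16 autocast when check_autocast=torch.bfloat16 is passed. A rough
# standalone equivalent for the fp8 path (FP8QDQLinear is the helper defined at
# the top of this file; assumes the torchao fp8 ops are registered):
def _mixed_bf16_fp8_sketch():
    mod = FP8QDQLinear(4, 4, has_bias=False).eval()
    x = torch.randn(2, 4)
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return torch.compile(mod)(x)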
""" - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias - ) + for is_fp8 in [True, False]: + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1471,54 +1559,58 @@ def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for bias in [True, False]: + for is_fp8 in [True, ]: + for bias in [False, ]: - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 - ) - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 13 if bias else 12, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 13 if bias else 12, + ) - self._qlinear_test_helper( - (torch.randn((2, 4, 3, 4)),), - do_permute=True, - matcher_check_fn=matcher_check_fn, - bias=bias, - ) + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=is_fp8, + ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): + def test_qlinear_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): r""" This testcase will quantize a single Linear Module for int8_bf16. * Input dim exceeds 2 * Input not contiguous """ - for bias in [True, False]: + for is_fp8 in [True, False]: + for bias in [True, False]: - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 - ) - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 17 if bias else 16, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 17 if bias else 16, + ) - self._qlinear_test_helper( - (torch.randn((2, 4, 3, 4)),), - int8_mixed_bf16=True, - do_permute=True, - matcher_check_fn=matcher_check_fn, - bias=bias, - ) + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + mixed_bf16=True, + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=is_fp8, + ) def _qlinear_unary_test_helper( - self, inputs, unary_op=torch.nn.ReLU(), device="cpu", int8_mixed_bf16=False + self, inputs, unary_op=torch.nn.ReLU(), device="cpu", mixed_bf16=False, is_fp8=False ): class M(torch.nn.Module): def __init__(self, use_bias): @@ -1555,8 +1647,9 @@ def matcher_check_fn(): mod, inputs, matcher_check_fn, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1565,54 +1658,60 @@ def test_qlinear_relu_cpu(self): r""" This testcase will quantize a Linear->ReLU pattern. """ - self._qlinear_unary_test_helper((torch.randn((2, 4)),)) + for is_fp8 in [True, False]: + self._qlinear_unary_test_helper((torch.randn((2, 4)),), is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_relu_int8_mixed_bf16(self): + def test_qlinear_relu_mixed_bf16(self): r""" - This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. 
+ This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ - self._qlinear_unary_test_helper((torch.randn((2, 4)),), int8_mixed_bf16=True) + for is_fp8 in [True, False]: + self._qlinear_unary_test_helper((torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_relu_input_dim_exceeds_2(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_relu_input_dim_exceeds_2(self, is_fp8): r""" This testcase will quantize a Linear->ReLU pattern. """ - self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),)) + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self, is_fp8): r""" - This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ - self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), int8_mixed_bf16=True) + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_gelu_cpu(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_gelu_cpu(self, is_fp8): r""" This testcase will quantize a Linear->GELU pattern. """ for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: - self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu) + self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu, is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_gelu_int8_mixed_bf16(self): + def test_qlinear_gelu_mixed_bf16(self): r""" - This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization. + This testcase will quantize a Linear->GELU pattern with mixed_bf16 quantization. 
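# The GELU cases cover both approximations because the erf ("none") and "tanh"
# forms decompose differently and are matched by separate unary post-op
# patterns. Eager reference for what the fused fp8 qlinear+gelu should
# approximately reproduce, up to quantization error (shapes are placeholders):
import torch

def _gelu_reference(x, weight, bias=None, approximate="none"):
    y = torch.nn.functional.linear(x, weight, bias)
    return torch.nn.functional.gelu(y, approximate=approximate)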
""" - for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: - self._qlinear_unary_test_helper( - (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True - ) + for is_fp8 in [True, False]: + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), gelu, mixed_bf16=True, is_fp8=is_fp8 + ) def _qlinear_add_test_helper( self, @@ -1804,13 +1903,170 @@ def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) + + def _fp8_qlinear_add_test_helper( + self, + device="cpu", + use_relu=False, + mixed_bf16=False, + ): + r""" + This testcase will quantize two consecutive Linear->Add(->relu) patterns as: + X + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + / \ + linear(X) linear(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + ): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.linear3 = torch.nn.Linear(4, 4) + self.linear4 = torch.nn.Linear(4, 4) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x) + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.linear3(tmp) + tmp2 = self.linear4(tmp) + res = self.add_fn2(tmp1, tmp2) + if self.use_relu: + res = self.relu2(res) + return res + + add_fn_list = [ + lambda x, y: x + y, + lambda x, y: y + x, + lambda x, y: x.add_(y), + lambda x, y: y.add_(x), + ] + is_fp8=True + shape_list = [(4, 4), (4, 4, 4)] + cases = itertools.product(add_fn_list, shape_list) + for add_fn, shape in cases: + mod = M(add_fn, use_relu).eval().to(device=device) + v = torch.randn( + shape, dtype=torch.float32, requires_grad=False, device=device + ).add(1) + + def matcher_check_fn(): + # 1. Dequant-linear pattern matched in quantization weight prepack * 4 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 + ) + # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] + nodes_per_match = 6 if mixed_bf16 else 4 + if len(shape) == 3: + # pattern = [dequant_per_tensor, (convert_dtype), (view), \ + # dequant_per_channel, (convert_dtype), (view), permute, addmm] + nodes_per_match += 2 + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 4 * nodes_per_match, + ) + # 2. 
Qlinear Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_count"], + 0 if TEST_ACL else 2, + ) + # Two linear-binary patterns are matched + # matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] + # matched patter2 = [qlinear, add, (convert dtype), (relu)] + # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary + expected_matcher_nodes = ( + 5 + 2 * use_relu + ) + self.assertEqual( + counters["inductor"]["qlinear_binary_matcher_nodes"], + 0 if TEST_ACL else expected_matcher_nodes, + ) + self.assertEqual( + counters["inductor"]["qlinear_binary_lower_count"], + 0 if TEST_ACL else 2, + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, + is_fp8=is_fp8, + ) + + if TEST_ACL: + continue + + if torch._inductor.config.cpp_wrapper: + # For CPP wrapper + self._test_code_common( + mod, + (v,), + [ + "aoti_torch_cpu__qlinear_pointwise_tensor", + "aoti_torch_cpu__qlinear_pointwise_binary_tensor", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + is_fp8=True, + ) + else: + # For python wrapper + self._test_code_common( + mod, + (v,), + [ + "torch.ops.onednn.qlinear_pointwise.tensor", + "torch.ops.onednn.qlinear_pointwise.binary", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + is_fp8=True, + ) + + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @parametrize("use_relu", [True, False]) + @parametrize("mixed_bf16", [True, False]) + def test_fp8_qlinear_add_cpu(self, use_relu, mixed_bf16): + self._fp8_qlinear_add_test_helper(use_relu=use_relu, mixed_bf16=mixed_bf16) + def _qlinear_dequant_promotion_test_helper( self, inputs, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, is_dynamic=False, matcher_check_fn=None, + is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -1850,14 +2106,16 @@ def default_matcher_check_fn(): if matcher_check_fn is not None else default_matcher_check_fn ), - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_dequant_promotion_cpu(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_dequant_promotion_cpu(self, is_fp8): r""" This testcase test if dequant node before linear is promoted correctly: X @@ -1870,14 +2128,15 @@ def test_qlinear_dequant_promotion_cpu(self): | Y """ - self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),)) + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),), is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_dequant_promotion_int8_mixed_bf16(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_dequant_promotion_mixed_bf16(self, is_fp8): r""" - Test with int8_mixed_bf16 quantization. + Test with mixed_bf16 quantization. 
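# The dequant-promotion cases use one activation whose dequantized value feeds
# two Linears, so the shared dequant node has to be duplicated before each
# branch can fuse. A rough eager equivalent of the graph drawn below (sizes are
# placeholders for the sketch):
import torch

class _DequantPromotionSketch(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        t = self.linear1(x)
        return self.linear2(t) + self.linear3(t)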
This testcase test if dequant node before linear is promoted correctly: X | @@ -1890,12 +2149,13 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16(self): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 4)),), int8_mixed_bf16=True + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8 ) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self, is_fp8): r""" This testcase test if dequant node before linear is promoted correctly: X @@ -1908,14 +2168,15 @@ def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): | Y """ - self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),)) + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),), is_fp8=is_fp8) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self, is_fp8): r""" - Test with int8_mixed_bf16 quantization. + Test with mixed_bf16 quantization. This testcase test if dequant node before linear is promoted correctly: X | @@ -1928,12 +2189,13 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 3, 4)),), int8_mixed_bf16=True + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8 ) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_dequant_promotion_dynamic_cpu(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_dequant_promotion_dynamic_cpu(self, is_fp8): r""" This testcase test if dequant node before linear is promoted correctly: X @@ -1946,7 +2208,6 @@ def test_qlinear_dequant_promotion_dynamic_cpu(self): | Y """ - def matcher_check_fn(): # 1. Dequant pattern matcher for dequant promotion * 1 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) @@ -1959,11 +2220,13 @@ def matcher_check_fn(): (torch.randn((2, 4)),), matcher_check_fn=matcher_check_fn, is_dynamic=True, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @skipIfNoONEDNN - def test_qlinear_mul_cpu(self): + @parametrize("is_fp8", [True, False]) + def test_qlinear_mul_cpu(self, is_fp8): r""" This testcase will quantize a Linear->Mul pattern. 
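# Rough shape of the Linear->Mul pattern exercised by this test; the real module
# and sizes are defined in the test body below, this is only a sketch:
import torch

class _LinearMulSketch(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x1, x2):
        return torch.mul(self.linear(x1), x2)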
""" @@ -1992,6 +2255,7 @@ def matcher_check_fn(): (x1, x2), matcher_check_fn, check_quantization=True, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -2432,49 +2696,13 @@ def matcher_check_fn(): @parametrize("input_dim_exceeds_two", [True, False]) @parametrize("check_reuse_input", [True, False]) def test_fp8_qlinear(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): - class FP8QDQLinear(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.qtype = torch.float8_e4m3fn - self.weight = torch.randn((out_features, in_features)).to(self.qtype) - self.weight_scale = 2.0 - self.scale = 2.0 - self.bias = None - if has_bias: - self.bias = torch.randn((out_features,)).to(dtype) - - def forward(self, input): - weight = torch.ops.torchao.dequantize_affine_float8( - tensor=self.weight.data, - scale=torch.tensor([self.weight_scale]), - output_dtype=torch.float, - ) - if dtype != torch.float: - weight = weight.to(dtype) - - q_input = torch.ops.torchao.quantize_affine_float8( - tensor=input, - scale=torch.tensor([self.scale]), - float8_dtype=self.qtype, - ) - dq_input = torch.ops.torchao.dequantize_affine_float8( - tensor=q_input, - scale=torch.tensor([self.scale]), - output_dtype=torch.float, - ) - if dtype != torch.float: - dq_input = dq_input.to(dtype) - - out = torch.nn.functional.linear(dq_input, weight, self.bias) - return out - class Mod(torch.nn.Module): def __init__(self, in_features, out_features, check_reuse_input): super().__init__() - self.l0 = FP8QDQLinear(in_features, out_features) + self.l0 = FP8QDQLinear(in_features, out_features, has_bias) self.check_reuse_input = check_reuse_input if self.check_reuse_input: - self.l1 = FP8QDQLinear(in_features, out_features) + self.l1 = FP8QDQLinear(in_features, out_features, has_bias) def forward(self, x): y = self.l0(x) @@ -2490,13 +2718,12 @@ def forward(self, x): v = torch.randn(M1, M2, N) else: v = torch.randn(M, N) - v = v.to(dtype) def matcher_check_fn(): counter = 2 if check_reuse_input else 1 self.assertEqual(counters["inductor"]["qlinear_weight_prepack_matcher_count"], counter) - self._test_common(mod, (v,), matcher_check_fn) + self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) @dynamo_config.patch( diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 4576eb99cf..e411cff19e 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -27,6 +27,7 @@ _PER_TENSOR_QUANTIZE_OPS = [ quantized_decomposed.quantize_per_tensor.default, quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.torchao.quantize_affine_float8.default, ] _VIEW_OPS = [ @@ -62,7 +63,7 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.float8_e4m3fn] return output_dtype @@ -327,20 +328,29 @@ def generate_pattern_with_unary(computation_call, unary_post_op): return computation_call -def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): - quantized_op_output_pattern_pt2e = CallFunction( - quantized_decomposed.quantize_per_tensor.default, - _may_generate_pattern_with_dtype_convert( - computation_call, - Arg(), - with_dtype_convert, - ), - KeywordArg("o_inv_scale"), 
- KeywordArg("o_zp"), - KeywordArg("o_qmin"), - KeywordArg("o_qmax"), - KeywordArg("o_dtype"), +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False, is_fp8=False): + may_generate_pattern_with_dtype_convert = _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, ) + if is_fp8: + quantized_op_output_pattern_pt2e = CallFunction( + torch.ops.torchao.quantize_affine_float8.default, + may_generate_pattern_with_dtype_convert, + KeywordArg("o_inv_scale"), + float8_dtype=KeywordArg("o_dtype"), + ) + else: + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + may_generate_pattern_with_dtype_convert, + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) return quantized_op_output_pattern_pt2e @@ -446,8 +456,10 @@ def fn(match): if extra_input_from_dequant and ( (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( - extra_input_of_binary_node.target - != quantized_decomposed.dequantize_per_tensor.default + extra_input_of_binary_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + torch.ops.torchao.dequantize_affine_float8.default, + ] ) ): return False @@ -732,8 +744,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] assert ( - dequant_per_channel.target # type: ignore[union-attr] - is quantized_decomposed.dequantize_per_channel.default + dequant_per_channel.target in [ # type: ignore[union-attr] + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8.default, + ] ) # Activation QParams @@ -1283,7 +1297,7 @@ def _generate_qlinear_weight_prepack_patterns( dtype, with_bias, is_tensor_overload, - is_fp8, + is_fp8=is_fp8, ) else: return _generate_dequant_linear_node_pattern( @@ -1291,7 +1305,7 @@ def _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload, - is_fp8, + is_fp8=is_fp8, ) @@ -1490,8 +1504,8 @@ def _register_qlinear_weight_prepack(): # https://github.com/pytorch/pytorch/blob/ # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 # in this case, we can convert it back to qlinear - for dtype, with_bias, is_tensor_overload in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + for dtype, with_bias, is_tensor_overload, is_fp8 in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ): bmm_pattern = _generate_qlinear_weight_prepack_patterns( dtype=dtype, @@ -2471,105 +2485,110 @@ def _register_qlinear_unary_fusion(): _gelu_fusion_2 as _gelu_fusion_tanh, ) - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: - is_bf16 = original_pattern_output_dtype == torch.bfloat16 - for x_scale_zp_are_tensors in (False, True): - qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) - computation_op = ( - torch.ops.onednn.qlinear_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qlinear_pointwise.default - ) - # Priority 1 to match: QLinear Unary pattern with int8 output - linear_unary_replace_patterns = { - PostOpAttr( - "none", None, "none", [], "" - ): generate_pattern_with_output_quant( - qlinear_pattern, - ), - PostOpAttr( - "none", None, "relu", [], "" - ): generate_pattern_with_output_quant( - generate_pattern_with_unary(qlinear_pattern, aten.relu.default), 
- ), - PostOpAttr( - "none", None, "gelu", [], "none" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 + for is_fp8 in [True, False]: + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, ), - 2, - is_bf16, + with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), - with_dtype_convert=is_bf16, - ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, ), - 4, - is_bf16, + with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), - with_dtype_convert=is_bf16, - ), - } + } - for unary_attr, patterns in linear_unary_replace_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 3, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) - # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output - linear_unary_replace_float_out_patterns = { - PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - qlinear_pattern, aten.relu.default - ), - PostOpAttr( - "none", None, "gelu", [], "none" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, ), - 2, + Arg(), is_bf16, ), - Arg(), - is_bf16, - ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): 
_may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, ), - 4, + Arg(), is_bf16, ), - Arg(), - is_bf16, - ), - } + } - for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 4, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) def _register_qlinear_binary_fusion(): @@ -2635,14 +2654,16 @@ def _register_qlinear_binary_fusion(): # totally 3 patterns (2 are identical) swap_binary_inputs_list = [False, True] int8_mixed_bf16_list = [False, True] + is_fp8_list = [False, True] combinations = itertools.product( unary_postop_list, int8_mixed_bf16_list, swap_binary_inputs_list, convert_dtype_after_binary_list, + is_fp8_list, ) qlinear_binary_replace_patterns = {} - for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary, is_fp8 in combinations: if not int8_mixed_bf16 and cvt_dtype_binary: # No convert node after binary node if dtypes are all fp32 continue @@ -2663,6 +2684,7 @@ def _register_qlinear_binary_fusion(): ), unary_postop_dict[unary_op], ), + is_fp8=is_fp8, ) } ) From 994867428559f6666a56db0416e4d46bff011e4a Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 7 Jul 2025 09:33:55 +0000 Subject: [PATCH 14/26] fix lint --- .../pt2e/test_x86inductor_fusion.py | 74 +++++++++++++------ .../quantization/pt2e/inductor_passes/x86.py | 73 ++++++++++++------ 2 files changed, 101 insertions(+), 46 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 60ab0f7747..be88f9cf93 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -141,10 +141,11 @@ def forward(self, input): out = torch.nn.functional.linear(dq_input, weight, self.bias) return out -def fp8_convert_(model): +def fp8_convert_(model): def generate_model_info(model): from collections import namedtuple + mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"]) parent_child_mod_dict = {} @@ -185,12 +186,17 @@ def create_mod_info_recursion(parent): def _generate_qdq_quantized_model( - mod, inputs, is_qat=False, is_dynamic=False, quantizer=None, is_fp8=False, + mod, + inputs, + is_qat=False, + is_dynamic=False, + quantizer=None, + is_fp8=False, ): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: if is_fp8: - assert(not is_qat) + assert not is_qat fp8_convert_(mod) return mod else: @@ -336,7 +342,9 @@ def _test_code_common( with torch.no_grad(): clone_inputs = self._clone_inputs(inputs) if check_quantization: - mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer, is_fp8=is_fp8) + mod = _generate_qdq_quantized_model( + mod, inputs, quantizer=quantizer, is_fp8=is_fp8 + ) expected = mod(*inputs) actual, (source_code,) = run_and_get_code( torch.compile(mod, fullgraph=True, dynamic=check_dynamic), @@ -1430,7 +1438,7 @@ def _qlinear_test_helper( bias=True, is_dynamic=False, is_qat=False, - is_fp8=False + is_fp8=False, ): class M(torch.nn.Module): def __init__(self, use_bias, do_permute=False): @@ -1480,7 +1488,9 @@ def test_qlinear_cpu(self): """ for 
is_fp8 in [True, False]: for bias in [True, False]: - self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=is_fp8) + self._qlinear_test_helper( + (torch.randn((2, 4)),), bias=bias, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1536,7 +1546,9 @@ def test_qlinear_input_dim_exceeds_2(self): """ for is_fp8 in [True, False]: for bias in [True, False]: - self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias, is_fp8=is_fp8) + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), bias=bias, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -1559,8 +1571,12 @@ def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for is_fp8 in [True, ]: - for bias in [False, ]: + for is_fp8 in [ + True, + ]: + for bias in [ + False, + ]: def matcher_check_fn(): self.assertEqual( @@ -1610,7 +1626,12 @@ def matcher_check_fn(): ) def _qlinear_unary_test_helper( - self, inputs, unary_op=torch.nn.ReLU(), device="cpu", mixed_bf16=False, is_fp8=False + self, + inputs, + unary_op=torch.nn.ReLU(), + device="cpu", + mixed_bf16=False, + is_fp8=False, ): class M(torch.nn.Module): def __init__(self, use_bias): @@ -1669,7 +1690,9 @@ def test_qlinear_relu_mixed_bf16(self): This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ for is_fp8 in [True, False]: - self._qlinear_unary_test_helper((torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8) + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1688,7 +1711,9 @@ def test_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self, is_fp8): r""" This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. 
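# A compact way to confirm the fused oneDNN kernel is emitted for these cases
# with the default python wrapper (mirrors what _test_code_common checks; the
# cpp_wrapper build looks for the aoti_torch_cpu__qlinear_pointwise symbols
# instead). Assumes torchao's fp8 ops and oneDNN support are available:
def _contains_fused_qlinear(mod, x):
    from torch._inductor.utils import run_and_get_code
    with torch.no_grad():
        _, (code,) = run_and_get_code(torch.compile(mod, fullgraph=True), x)
    return "torch.ops.onednn.qlinear_pointwise" in code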
""" - self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8) + self._qlinear_unary_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1903,7 +1928,6 @@ def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) - def _fp8_qlinear_add_test_helper( self, device="cpu", @@ -1965,7 +1989,7 @@ def forward(self, x): lambda x, y: x.add_(y), lambda x, y: y.add_(x), ] - is_fp8=True + is_fp8 = True shape_list = [(4, 4), (4, 4, 4)] cases = itertools.product(add_fn_list, shape_list) for add_fn, shape in cases: @@ -1998,9 +2022,7 @@ def matcher_check_fn(): # matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] # matched patter2 = [qlinear, add, (convert dtype), (relu)] # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary - expected_matcher_nodes = ( - 5 + 2 * use_relu - ) + expected_matcher_nodes = 5 + 2 * use_relu self.assertEqual( counters["inductor"]["qlinear_binary_matcher_nodes"], 0 if TEST_ACL else expected_matcher_nodes, @@ -2051,7 +2073,6 @@ def matcher_check_fn(): is_fp8=True, ) - @skipIfNoDynamoSupport @skipIfNoONEDNN @parametrize("use_relu", [True, False]) @@ -2128,7 +2149,9 @@ def test_qlinear_dequant_promotion_cpu(self, is_fp8): | Y """ - self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),), is_fp8=is_fp8) + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 4)),), is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -2168,7 +2191,9 @@ def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self, is_fp8): | Y """ - self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),), is_fp8=is_fp8) + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 3, 4)),), is_fp8=is_fp8 + ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -2208,6 +2233,7 @@ def test_qlinear_dequant_promotion_dynamic_cpu(self, is_fp8): | Y """ + def matcher_check_fn(): # 1. 
Dequant pattern matcher for dequant promotion * 1 self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) @@ -2695,7 +2721,9 @@ def matcher_check_fn(): @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("input_dim_exceeds_two", [True, False]) @parametrize("check_reuse_input", [True, False]) - def test_fp8_qlinear(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): + def test_fp8_qlinear( + self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input + ): class Mod(torch.nn.Module): def __init__(self, in_features, out_features, check_reuse_input): super().__init__() @@ -2721,7 +2749,9 @@ def forward(self, x): def matcher_check_fn(): counter = 2 if check_reuse_input else 1 - self.assertEqual(counters["inductor"]["qlinear_weight_prepack_matcher_count"], counter) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], counter + ) self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index e411cff19e..1f66138189 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -63,7 +63,13 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.float8_e4m3fn] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + ] return output_dtype @@ -328,7 +334,9 @@ def generate_pattern_with_unary(computation_call, unary_post_op): return computation_call -def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False, is_fp8=False): +def generate_pattern_with_output_quant( + computation_call, with_dtype_convert=False, is_fp8=False +): may_generate_pattern_with_dtype_convert = _may_generate_pattern_with_dtype_convert( computation_call, Arg(), @@ -456,7 +464,8 @@ def fn(match): if extra_input_from_dequant and ( (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( - extra_input_of_binary_node.target not in [ + extra_input_of_binary_node.target + not in [ quantized_decomposed.dequantize_per_tensor.default, torch.ops.torchao.dequantize_affine_float8.default, ] @@ -743,12 +752,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert ( - dequant_per_channel.target in [ # type: ignore[union-attr] - quantized_decomposed.dequantize_per_channel.default, - torch.ops.torchao.dequantize_affine_float8.default, - ] - ) + assert dequant_per_channel.target in [ # type: ignore[union-attr] + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8.default, + ] # Activation QParams qx, x_zp, x_scale = ( @@ -1074,12 +1081,10 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): else: weight_to_bf16_node = t_node.args[0] dequant = weight_to_bf16_node.args[0] - assert ( - dequant.target in [ - quantized_decomposed.dequantize_per_channel.default, - torch.ops.torchao.dequantize_affine_float8.default, - ] - ) + assert dequant.target in [ + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8.default, + ] # Activation QParams qx, x_scale = ( @@ -1200,7 +1205,9 @@ def 
_generate_dequant_linear_node_pattern( KeywordArg("b"), _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1217,7 +1224,9 @@ def _generate_dequant_linear_node_pattern( aten.mm.default, _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1248,7 +1257,9 @@ def _generate_dequant_bmm_node_pattern( CallFunction( aten.expand.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload, is_fp8), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1479,7 +1490,12 @@ def _register_qlinear_weight_prepack(): ) # Step 1: register patterns from mm and addmm - for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in linear_weight_prepack_cases: + for ( + dtype, + input_dim_exceeds_two, + is_tensor_overload, + is_fp8, + ) in linear_weight_prepack_cases: if is_fp8 and not is_tensor_overload: continue weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( @@ -2549,9 +2565,9 @@ def _register_qlinear_unary_fusion(): # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output linear_unary_replace_float_out_patterns = { - PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - qlinear_pattern, aten.relu.default - ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_unary(qlinear_pattern, aten.relu.default), PostOpAttr( "none", None, "gelu", [], "none" ): _may_generate_pattern_with_dtype_convert( @@ -2582,7 +2598,10 @@ def _register_qlinear_unary_fusion(): ), } - for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + for ( + unary_attr, + patterns, + ) in linear_unary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 4, # pass_number @@ -2663,7 +2682,13 @@ def _register_qlinear_binary_fusion(): is_fp8_list, ) qlinear_binary_replace_patterns = {} - for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary, is_fp8 in combinations: + for ( + unary_op, + int8_mixed_bf16, + swap_inputs, + cvt_dtype_binary, + is_fp8, + ) in combinations: if not int8_mixed_bf16 and cvt_dtype_binary: # No convert node after binary node if dtypes are all fp32 continue From 558d2164eefcf226ed12abaeebd0e29377545043 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Fri, 11 Jul 2025 17:36:23 +0000 Subject: [PATCH 15/26] support fp8 quant_lift_up --- .../pt2e/test_x86inductor_fusion.py | 97 ++++++++++++------- .../quantization/pt2e/inductor_passes/x86.py | 14 ++- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index be88f9cf93..34fccd6364 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -121,18 +121,18 @@ def __init__(self, in_features, out_features, has_bias): self.bias = torch.randn((out_features,)) def forward(self, input): - weight = torch.ops.torchao.dequantize_affine_float8( + weight = 
torch.ops.torchao.dequantize_affine_float8.default( tensor=self.weight.data, scale=torch.tensor([self.weight_scale]), output_dtype=torch.float, ) - q_input = torch.ops.torchao.quantize_affine_float8( + q_input = torch.ops.torchao.quantize_affine_float8.default( tensor=input, scale=torch.tensor([self.scale]), float8_dtype=self.qtype, ) - dq_input = torch.ops.torchao.dequantize_affine_float8( + dq_input = torch.ops.torchao.dequantize_affine_float8.default( tensor=q_input, scale=torch.tensor([self.scale]), output_dtype=torch.float, @@ -141,6 +141,20 @@ def forward(self, input): out = torch.nn.functional.linear(dq_input, weight, self.bias) return out +def qdq(input, scale): + dtype = input.dtype + q_input = torch.ops.torchao.quantize_affine_float8.default( + input, + torch.tensor([scale]), + torch.float8_e4m3fn, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8.default( + q_input, + torch.tensor([scale]), + dtype, + ) + return dq_input + def fp8_convert_(model): def generate_model_info(model): @@ -2856,6 +2870,7 @@ def __init__( transpose_for_score=False, num_attention_heads=None, attention_head_size=None, + annotate_matmul=False, ) -> None: super().__init__() self.input_dim = input_dim @@ -2864,6 +2879,12 @@ def __init__( self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) self.softmax = torch.nn.Softmax(dim=-1) self.transpose_for_score = transpose_for_score + self.annotate_matmul = annotate_matmul + if self.annotate_matmul: + self.q_out_scale = 0.5 + self.k_out_scale = 0.6 + self.v_out_scale = 0.7 + self.attn_weights_scale = 0.8 if self.transpose_for_score: assert num_attention_heads is not None assert attention_head_size is not None @@ -2886,43 +2907,53 @@ def forward(self, x): q = self.transpose_for_scores(q) k = self.transpose_for_scores(k) v = self.transpose_for_scores(v) - scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + k = k.transpose(-1, -2) + if self.annotate_matmul: + q = qdq(q, self.q_out_scale) + k = qdq(k, self.k_out_scale) + scores = torch.matmul(q, k) / (self.input_dim**0.5) attention = self.softmax(scores) + if self.annotate_matmul: + attention = qdq(attention, self.attn_weights_scale) + v = qdq(v, self.v_out_scale) weighted = torch.matmul(attention, v) return weighted - for annotate_matmul in [False, True]: - mod = SelfAttnLikeModule( - input_dim=64 * 16, - transpose_for_score=True, - num_attention_heads=16, - attention_head_size=64, - ).eval() - v = torch.randn(2, 384, 1024) + for is_fp8 in [True, False]: + for annotate_matmul in [True, False]: + mod = SelfAttnLikeModule( + input_dim=64 * 16, + transpose_for_score=True, + num_attention_heads=16, + attention_head_size=64, + annotate_matmul=annotate_matmul and is_fp8, + ).eval() + v = torch.randn(2, 384, 1024) - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 - ) - self.assertEqual( - counters["inductor"]["qlinear_unary_matcher_count"], - 3 if annotate_matmul and not TEST_ACL else 0, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 + ) + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 3 if annotate_matmul and not TEST_ACL else 0, + ) - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - if annotate_matmul: - quantizer.set_function_type_qconfig( - torch.matmul, quantizer.get_global_quantization_config() - ) + quantizer = X86InductorQuantizer() + 
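# The qdq() helper above quantizes q/k/v right before the matmuls; quant_lift_up
# (extended for fp8 in this series) can hoist those quant nodes above
# transpose/permute because per-tensor quantization commutes with pure view ops.
# A quick numeric check of that property (assumes the torchao fp8 ops are
# registered):
def _check_fp8_quant_commutes_with_transpose():
    import torch
    x = torch.randn(2, 8, 16)
    scale = torch.tensor([0.5])
    q_of_t = torch.ops.torchao.quantize_affine_float8.default(
        x.transpose(-1, -2), scale, torch.float8_e4m3fn
    )
    t_of_q = torch.ops.torchao.quantize_affine_float8.default(
        x, scale, torch.float8_e4m3fn
    ).transpose(-1, -2)
    torch.testing.assert_close(q_of_t.to(torch.float32), t_of_q.to(torch.float32))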
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) - self._test_common( - mod, - (v,), - matcher_check_fn, - check_quantization=True, - quantizer=quantizer, - ) + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + quantizer=quantizer, + is_fp8=is_fp8, + ) instantiate_parametrized_tests(TestPatternMatcher) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 1f66138189..2c6135f187 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -30,12 +30,18 @@ torch.ops.torchao.quantize_affine_float8.default, ] -_VIEW_OPS = [ +_VIEW_FUNCTION_OPS = [ aten.transpose.int, aten.permute.default, aten.view.default, ] +_VIEW_METHOD_OPS = [ + 'transpose', + 'permute', + 'view', +] + """ The quantization.py file primarily incorporates passes related to quantization fusion in inductor, includes: @@ -2896,7 +2902,8 @@ def quant_lift_up(module_graph: torch.fx.graph.Graph): """ def is_view_op(node): - return node.op == "call_function" and node.target in _VIEW_OPS + return (node.op == "call_function" and node.target in _VIEW_FUNCTION_OPS) or \ + (node.op == "call_method" and node.target in _VIEW_METHOD_OPS) for node in module_graph.nodes: # Leslie: Here we verify that the quant node has exactly @@ -2907,7 +2914,8 @@ def is_view_op(node): if ( node.op == "call_function" and node.target in _PER_TENSOR_QUANTIZE_OPS - and len(node.all_input_nodes) == 1 + # TODO: len(node.all_input_nodes) == 2 for fp8 quant + #and len(node.all_input_nodes) == 1 and is_view_op(node.all_input_nodes[0]) ): quant_node = node From 8cd1433b99526b8320a995cd1dcd1d63b15514e2 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 15 Jul 2025 11:03:34 +0000 Subject: [PATCH 16/26] add reshape into _VIEW_METHOD_OPS --- torchao/quantization/pt2e/inductor_passes/x86.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 2c6135f187..aa12d25b69 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -40,6 +40,7 @@ 'transpose', 'permute', 'view', + 'reshape', ] """ From ae4f58272026dd9c98fc20f8b259f8ec4d6bc111 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Thu, 17 Jul 2025 09:53:52 +0000 Subject: [PATCH 17/26] add quant_input_check --- torchao/quantization/pt2e/inductor_passes/x86.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index aa12d25b69..f27e20280c 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2906,6 +2906,13 @@ def is_view_op(node): return (node.op == "call_function" and node.target in _VIEW_FUNCTION_OPS) or \ (node.op == "call_method" and node.target in _VIEW_METHOD_OPS) + def quant_input_check(node): + if len(node.all_input_nodes) == 1: + return True + elif node.target == torch.ops.torchao.quantize_affine_float8.default: + # check if scale created by torch.tensor + return len(node.all_input_nodes) == 2 and node.all_input_nodes[1].target == torch.tensor + for node in module_graph.nodes: # Leslie: Here we verify that the quant node has exactly # one input FX node, with constant 
scalar value for scale and zero point. @@ -2915,8 +2922,7 @@ def is_view_op(node): if ( node.op == "call_function" and node.target in _PER_TENSOR_QUANTIZE_OPS - # TODO: len(node.all_input_nodes) == 2 for fp8 quant - #and len(node.all_input_nodes) == 1 + and quant_input_check(node) and is_view_op(node.all_input_nodes[0]) ): quant_node = node From 80263068a5af1bed982b42ad7e02aaae65964bf7 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Thu, 17 Jul 2025 10:30:48 +0000 Subject: [PATCH 18/26] fix lint --- .../pt2e/test_x86inductor_fusion.py | 1 + .../quantization/pt2e/inductor_passes/x86.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 34fccd6364..e87d835e7d 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -141,6 +141,7 @@ def forward(self, input): out = torch.nn.functional.linear(dq_input, weight, self.bias) return out + def qdq(input, scale): dtype = input.dtype q_input = torch.ops.torchao.quantize_affine_float8.default( diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index f27e20280c..755493afbc 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -37,10 +37,10 @@ ] _VIEW_METHOD_OPS = [ - 'transpose', - 'permute', - 'view', - 'reshape', + "transpose", + "permute", + "view", + "reshape", ] """ @@ -2903,15 +2903,19 @@ def quant_lift_up(module_graph: torch.fx.graph.Graph): """ def is_view_op(node): - return (node.op == "call_function" and node.target in _VIEW_FUNCTION_OPS) or \ - (node.op == "call_method" and node.target in _VIEW_METHOD_OPS) + return (node.op == "call_function" and node.target in _VIEW_FUNCTION_OPS) or ( + node.op == "call_method" and node.target in _VIEW_METHOD_OPS + ) def quant_input_check(node): if len(node.all_input_nodes) == 1: return True elif node.target == torch.ops.torchao.quantize_affine_float8.default: # check if scale created by torch.tensor - return len(node.all_input_nodes) == 2 and node.all_input_nodes[1].target == torch.tensor + return ( + len(node.all_input_nodes) == 2 + and node.all_input_nodes[1].target == torch.tensor + ) for node in module_graph.nodes: # Leslie: Here we verify that the quant node has exactly From f735949b2cf76967eaf360949df7a64934b30e49 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Thu, 17 Jul 2025 10:38:12 +0000 Subject: [PATCH 19/26] refine ut --- test/quantization/pt2e/test_x86inductor_fusion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index e87d835e7d..e592516b0a 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -1586,12 +1586,8 @@ def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for is_fp8 in [ - True, - ]: - for bias in [ - False, - ]: + for is_fp8 in [True, False]: + for bias in [False, True]: def matcher_check_fn(): self.assertEqual( From 58035116db5c458b4d72dcd8d608f7312cf6712b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Thu, 17 Jul 2025 13:56:28 +0000 Subject: [PATCH 20/26] remove fp8 dynamic quant ut --- test/quantization/pt2e/test_x86inductor_fusion.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff 
--git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index e592516b0a..cb2e9f9b44 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -211,6 +211,7 @@ def _generate_qdq_quantized_model( maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: if is_fp8: + assert not is_dynamic assert not is_qat fp8_convert_(mod) return mod @@ -2230,8 +2231,7 @@ def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self, is_fp8): @skipIfNoDynamoSupport @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_dequant_promotion_dynamic_cpu(self, is_fp8): + def test_qlinear_dequant_promotion_dynamic_cpu(self): r""" This testcase test if dequant node before linear is promoted correctly: X @@ -2257,7 +2257,6 @@ def matcher_check_fn(): (torch.randn((2, 4)),), matcher_check_fn=matcher_check_fn, is_dynamic=True, - is_fp8=is_fp8, ) @skipIfNoDynamoSupport From 3e37dea32fe13f94cc19c0305972b9f978b0aa98 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 21 Jul 2025 00:55:41 -0400 Subject: [PATCH 21/26] fix output_scale issue --- torchao/quantization/pt2e/inductor_passes/x86.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 755493afbc..41c2bccbc8 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2425,11 +2425,17 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs): b = kwargs["b"] if "b" in kwargs else None # Output QParams - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype in [torch.uint8, torch.int8]) - else 1.0 - ) + if output_dtype == torch.float8_e4m3fn: + # For float8, torchao.quantize_affine_float8 requires tensor as scale + # Support scale node is full firstly + assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default + o_inv_scale = kwargs["o_inv_scale"].args[1] + else: + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 ) From d9ac0924b3c7eda16fc8e378a4615474f22521f5 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 21 Jul 2025 22:29:35 -0400 Subject: [PATCH 22/26] add float8_e4m3fn to dtype_list --- torchao/quantization/pt2e/inductor_passes/x86.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 41c2bccbc8..e5a76b22c5 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -507,7 +507,7 @@ def fn(match): if "other" in match.kwargs else ( match.kwargs["accum"] - if (output_dtype in [torch.uint8, torch.int8]) + if (output_dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]) or (not extra_input_from_dequant) else match.kwargs["accum_after_dequant"] ) From f88db2db2eaa79ee3eb28228ac75f6ab3df0ca43 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 22 Jul 2025 21:23:07 -0400 Subject: [PATCH 23/26] refine code --- .../pt2e/test_x86inductor_fusion.py | 662 ++++++++++-------- 1 file changed, 387 insertions(+), 275 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index cb2e9f9b44..90d6563b8f 
100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -1502,11 +1502,17 @@ def test_qlinear_cpu(self): r""" This testcase will quantize a single Linear Moduel. """ - for is_fp8 in [True, False]: - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 4)),), bias=bias, is_fp8=is_fp8 - ) + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_cpu(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1548,11 +1554,22 @@ def test_qlinear_mixed_bf16(self): r""" This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. """ - for is_fp8 in [True, False]: - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 4)),), mixed_bf16=True, bias=bias, is_fp8=is_fp8 - ) + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, bias=bias + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_mixed_bf16(self): + r""" + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, bias=bias, is_fp8=True + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1560,11 +1577,17 @@ def test_qlinear_input_dim_exceeds_2(self): r""" This testcase will quantize a single Linear Moduel. """ - for is_fp8 in [True, False]: - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 3, 4)),), bias=bias, is_fp8=is_fp8 - ) + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias, is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -1573,11 +1596,22 @@ def test_qlinear_mixed_bf16_input_dim_exceeds_2(self): r""" This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. """ - for is_fp8 in [True, False]: - for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias, is_fp8=is_fp8 - ) + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. 
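# The fp8 test variants go through fp8_convert_(), which swaps each nn.Linear
# for the FP8QDQLinear wrapper defined near the top of this file. A standalone
# smoke test of that wrapper, eager vs. compiled (assumes the torchao fp8 ops
# are registered; sizes and tolerances are placeholders):
def _fp8_qdq_linear_smoke_test():
    mod = FP8QDQLinear(8, 8, has_bias=False).eval()
    x = torch.randn(2, 8)
    with torch.no_grad():
        eager = mod(x)
        compiled = torch.compile(mod)(x)
    torch.testing.assert_close(eager, compiled, atol=1e-2, rtol=1e-2)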
+ """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias, is_fp8=True + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1587,26 +1621,51 @@ def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for is_fp8 in [True, False]: - for bias in [False, True]: + for bias in [False, True]: - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 - ) - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 13 if bias else 12, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 13 if bias else 12, + ) - self._qlinear_test_helper( - (torch.randn((2, 4, 3, 4)),), - do_permute=True, - matcher_check_fn=matcher_check_fn, - bias=bias, - is_fp8=is_fp8, + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module. + * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [False, True]: + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 13 if bias else 12, ) + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=True, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -1616,26 +1675,53 @@ def test_qlinear_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for is_fp8 in [True, False]: - for bias in [True, False]: + for bias in [True, False]: - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 - ) - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 17 if bias else 16, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 17 if bias else 16, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + mixed_bf16=True, + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module for int8_bf16. 
+ * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [True, False]: - self._qlinear_test_helper( - (torch.randn((2, 4, 3, 4)),), - mixed_bf16=True, - do_permute=True, - matcher_check_fn=matcher_check_fn, - bias=bias, - is_fp8=is_fp8, + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 17 if bias else 16, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + mixed_bf16=True, + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=True, + ) def _qlinear_unary_test_helper( self, @@ -1691,8 +1777,15 @@ def test_qlinear_relu_cpu(self): r""" This testcase will quantize a Linear->ReLU pattern. """ - for is_fp8 in [True, False]: - self._qlinear_unary_test_helper((torch.randn((2, 4)),), is_fp8=is_fp8) + self._qlinear_unary_test_helper((torch.randn((2, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_relu_cpu(self): + r""" + This testcase will quantize a Linear->ReLU pattern. + """ + self._qlinear_unary_test_helper((torch.randn((2, 4)),), is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -1701,41 +1794,70 @@ def test_qlinear_relu_mixed_bf16(self): r""" This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ - for is_fp8 in [True, False]: - self._qlinear_unary_test_helper( - (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8 - ) + self._qlinear_unary_test_helper((torch.randn((2, 4)),), mixed_bf16=True) @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_relu_input_dim_exceeds_2(self, is_fp8): + def test_fp8_qlinear_relu_mixed_bf16(self): + r""" + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_relu_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern. + """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_relu_input_dim_exceeds_2(self): r""" This testcase will quantize a Linear->ReLU pattern. """ - self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), is_fp8=is_fp8) + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self, is_fp8): + def test_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self): r""" This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ self._qlinear_unary_test_helper( - (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8 + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=True ) @skipIfNoDynamoSupport @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_gelu_cpu(self, is_fp8): + def test_qlinear_gelu_cpu(self): + r""" + This testcase will quantize a Linear->GELU pattern. 
+ """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu) + + def test_fp8_qlinear_gelu_cpu(self): r""" This testcase will quantize a Linear->GELU pattern. """ for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: - self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu, is_fp8=is_fp8) + self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu, is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @@ -1744,19 +1866,31 @@ def test_qlinear_gelu_mixed_bf16(self): r""" This testcase will quantize a Linear->GELU pattern with mixed_bf16 quantization. """ - for is_fp8 in [True, False]: - for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: - self._qlinear_unary_test_helper( - (torch.randn((2, 4)),), gelu, mixed_bf16=True, is_fp8=is_fp8 - ) + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), gelu, mixed_bf16=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_gelu_mixed_bf16(self): + r""" + This testcase will quantize a Linear->GELU pattern with mixed_bf16 quantization. + """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), gelu, mixed_bf16=True, is_fp8=True + ) def _qlinear_add_test_helper( self, device="cpu", use_relu=False, - int8_mixed_bf16=False, + mixed_bf16=False, is_qat=True, is_dynamic=True, + is_fp8=False, ): r""" This testcase will quantize two consecutive Linear->Add(->relu) patterns as: @@ -1824,13 +1958,17 @@ def forward(self, x): res = self.relu2(res) return res + if is_fp8: + assert not is_dynamic + assert not is_qat + add_fn_list = [ lambda x, y: x + y, lambda x, y: y + x, lambda x, y: x.add_(y), lambda x, y: y.add_(x), ] - fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] + fake_quant_x2_list = [False, True] if mixed_bf16 else [False] shape_list = [(4, 4), (4, 4, 4)] cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list) for add_fn, fq_x2, shape in cases: @@ -1845,7 +1983,7 @@ def matcher_check_fn(): counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 ) # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] - nodes_per_match = 6 if int8_mixed_bf16 else 4 + nodes_per_match = 6 if mixed_bf16 else 4 if len(shape) == 3: # pattern = [dequant_per_tensor, (convert_dtype), (view), \ # dequant_per_channel, (convert_dtype), (view), permute, addmm] @@ -1881,9 +2019,10 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, is_qat=is_qat, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) if TEST_ACL: @@ -1934,163 +2073,24 @@ def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic): @parametrize("is_dynamic", [True, False]) def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic): self._qlinear_add_test_helper( - int8_mixed_bf16=True, + mixed_bf16=True, use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic, ) - def _fp8_qlinear_add_test_helper( - self, - device="cpu", - use_relu=False, - mixed_bf16=False, - ): - r""" - This testcase will quantize two consecutive Linear->Add(->relu) patterns as: - X - / \ - linear(X) linear(X) - \ / - Add - | - Optional(relu) - / \ - linear(X) linear(X) - \ / - Add - | - Optional(relu) - | - Y - """ - - class 
M(torch.nn.Module): - def __init__( - self, - add_fn, - use_relu, - ): - super().__init__() - self.linear1 = torch.nn.Linear(4, 4) - self.linear2 = torch.nn.Linear(4, 4) - self.add_fn = add_fn - self.relu = torch.nn.ReLU() - self.linear3 = torch.nn.Linear(4, 4) - self.linear4 = torch.nn.Linear(4, 4) - self.add_fn2 = add_fn - self.relu2 = torch.nn.ReLU() - self.use_relu = use_relu - - def forward(self, x): - x1 = self.linear1(x) - x2 = self.linear2(x) - tmp = self.add_fn(x1, x2) - if self.use_relu: - tmp = self.relu(tmp) - tmp1 = self.linear3(tmp) - tmp2 = self.linear4(tmp) - res = self.add_fn2(tmp1, tmp2) - if self.use_relu: - res = self.relu2(res) - return res - - add_fn_list = [ - lambda x, y: x + y, - lambda x, y: y + x, - lambda x, y: x.add_(y), - lambda x, y: y.add_(x), - ] - is_fp8 = True - shape_list = [(4, 4), (4, 4, 4)] - cases = itertools.product(add_fn_list, shape_list) - for add_fn, shape in cases: - mod = M(add_fn, use_relu).eval().to(device=device) - v = torch.randn( - shape, dtype=torch.float32, requires_grad=False, device=device - ).add(1) - - def matcher_check_fn(): - # 1. Dequant-linear pattern matched in quantization weight prepack * 4 - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 - ) - # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] - nodes_per_match = 6 if mixed_bf16 else 4 - if len(shape) == 3: - # pattern = [dequant_per_tensor, (convert_dtype), (view), \ - # dequant_per_channel, (convert_dtype), (view), permute, addmm] - nodes_per_match += 2 - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 4 * nodes_per_match, - ) - # 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2 - self.assertEqual( - counters["inductor"]["qlinear_binary_matcher_count"], - 0 if TEST_ACL else 2, - ) - # Two linear-binary patterns are matched - # matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor] - # matched patter2 = [qlinear, add, (convert dtype), (relu)] - # If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary - expected_matcher_nodes = 5 + 2 * use_relu - self.assertEqual( - counters["inductor"]["qlinear_binary_matcher_nodes"], - 0 if TEST_ACL else expected_matcher_nodes, - ) - self.assertEqual( - counters["inductor"]["qlinear_binary_lower_count"], - 0 if TEST_ACL else 2, - ) - - self._test_common( - mod, - (v,), - matcher_check_fn, - check_quantization=True, - check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, - is_fp8=is_fp8, - ) - - if TEST_ACL: - continue - - if torch._inductor.config.cpp_wrapper: - # For CPP wrapper - self._test_code_common( - mod, - (v,), - [ - "aoti_torch_cpu__qlinear_pointwise_tensor", - "aoti_torch_cpu__qlinear_pointwise_binary_tensor", - ], - [], - check_quantization=True, - num_include_ops=[2, 2], - is_fp8=True, - ) - else: - # For python wrapper - self._test_code_common( - mod, - (v,), - [ - "torch.ops.onednn.qlinear_pointwise.tensor", - "torch.ops.onednn.qlinear_pointwise.binary", - ], - [], - check_quantization=True, - num_include_ops=[2, 2], - is_fp8=True, - ) - @skipIfNoDynamoSupport @skipIfNoONEDNN @parametrize("use_relu", [True, False]) @parametrize("mixed_bf16", [True, False]) def test_fp8_qlinear_add_cpu(self, use_relu, mixed_bf16): - self._fp8_qlinear_add_test_helper(use_relu=use_relu, mixed_bf16=mixed_bf16) + self._qlinear_add_test_helper( + use_relu=use_relu, + mixed_bf16=mixed_bf16, + is_qat=False, + is_dynamic=False, + is_fp8=True, + ) def 
_qlinear_dequant_promotion_test_helper( self, @@ -2147,9 +2147,44 @@ def default_matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_dequant_promotion_cpu(self, is_fp8): + def test_qlinear_dequant_promotion_cpu(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_dequant_promotion_cpu(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),), is_fp8=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_mixed_bf16(self): r""" + Test with mixed_bf16 quantization. This testcase test if dequant node before linear is promoted correctly: X | @@ -2162,14 +2197,13 @@ def test_qlinear_dequant_promotion_cpu(self, is_fp8): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 4)),), is_fp8=is_fp8 + (torch.randn((2, 4)),), mixed_bf16=True ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_dequant_promotion_mixed_bf16(self, is_fp8): + def test_fp8_qlinear_dequant_promotion_mixed_bf16(self): r""" Test with mixed_bf16 quantization. This testcase test if dequant node before linear is promoted correctly: @@ -2184,13 +2218,29 @@ def test_qlinear_dequant_promotion_mixed_bf16(self, is_fp8): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=is_fp8 + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=True ) @skipIfNoDynamoSupport @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self, is_fp8): + def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),)) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): r""" This testcase test if dequant node before linear is promoted correctly: X @@ -2204,14 +2254,13 @@ def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self, is_fp8): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 3, 4)),), is_fp8=is_fp8 + (torch.randn((2, 3, 4)),), is_fp8=True ) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self, is_fp8): + def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self): r""" Test with mixed_bf16 quantization. 
This testcase test if dequant node before linear is promoted correctly: @@ -2226,7 +2275,28 @@ def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self, is_fp8): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=is_fp8 + (torch.randn((2, 3, 4)),), mixed_bf16=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_fp8_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self): + r""" + Test with mixed_bf16 quantization. + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=True ) @skipIfNoDynamoSupport @@ -2261,8 +2331,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @parametrize("is_fp8", [True, False]) - def test_qlinear_mul_cpu(self, is_fp8): + def test_qlinear_mul_cpu(self): r""" This testcase will quantize a Linear->Mul pattern. """ @@ -2291,7 +2360,40 @@ def matcher_check_fn(): (x1, x2), matcher_check_fn, check_quantization=True, - is_fp8=is_fp8, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_qlinear_mul_cpu(self): + r""" + This testcase will quantize a Linear->Mul pattern. + """ + + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.linear = torch.nn.Linear(4, 5, use_bias) + + def forward(self, x1, x2): + return torch.mul(self.linear(x1), x2) + + bias_list = [True, False] + for bias in bias_list: + mod = M(bias).eval() + x1 = torch.randn((2, 4)) + x2 = torch.randn((2, 5)) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + + self._test_common( + mod, + (x1, x2), + matcher_check_fn, + check_quantization=True, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -2856,9 +2958,7 @@ def matcher_check_fn(): is_qat=True, ) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - def test_q_attention_block(self): + def _test_q_attention_block_helper(self, annotate_matmul, is_fp8=False): class SelfAttnLikeModule(torch.nn.Module): def __init__( self, @@ -2915,41 +3015,53 @@ def forward(self, x): weighted = torch.matmul(attention, v) return weighted - for is_fp8 in [True, False]: - for annotate_matmul in [True, False]: - mod = SelfAttnLikeModule( - input_dim=64 * 16, - transpose_for_score=True, - num_attention_heads=16, - attention_head_size=64, - annotate_matmul=annotate_matmul and is_fp8, - ).eval() - v = torch.randn(2, 384, 1024) - - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 - ) - self.assertEqual( - counters["inductor"]["qlinear_unary_matcher_count"], - 3 if annotate_matmul and not TEST_ACL else 0, - ) + mod = SelfAttnLikeModule( + input_dim=64 * 16, + transpose_for_score=True, + num_attention_heads=16, + attention_head_size=64, + annotate_matmul=annotate_matmul and is_fp8, + ).eval() + v = torch.randn(2, 384, 1024) - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - if annotate_matmul: - quantizer.set_function_type_qconfig( - torch.matmul, quantizer.get_global_quantization_config() - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 + ) + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 3 if annotate_matmul and not TEST_ACL else 0, + ) - self._test_common( - 
mod, - (v,), - matcher_check_fn, - check_quantization=True, - quantizer=quantizer, - is_fp8=is_fp8, - ) + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + quantizer=quantizer, + is_fp8=is_fp8, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_q_attention_block(self): + for annotate_matmul in [True, False]: + self._test_q_attention_block_helper(annotate_matmul=annotate_matmul) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_fp8_q_attention_block(self): + for annotate_matmul in [True, False]: + self._test_q_attention_block_helper( + annotate_matmul=annotate_matmul, is_fp8=True + ) instantiate_parametrized_tests(TestPatternMatcher) From 4f4eb8bc48139353f6936c9eb4d4e98bae207fa7 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 22 Jul 2025 22:01:13 -0400 Subject: [PATCH 24/26] refine code --- .../pt2e/test_x86inductor_fusion.py | 44 +--- .../quantization/pt2e/inductor_passes/x86.py | 202 +++++++++--------- 2 files changed, 105 insertions(+), 141 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 90d6563b8f..8ced99dbb1 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -1621,7 +1621,7 @@ def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for bias in [False, True]: + for bias in [True, False]: def matcher_check_fn(): self.assertEqual( @@ -1647,7 +1647,7 @@ def test_fp8_qlinear_input_dim_exceeds_2_and_not_contiguous(self): * Input dim exceeds 2 * Input not contiguous """ - for bias in [False, True]: + for bias in [True, False]: def matcher_check_fn(): self.assertEqual( @@ -1959,6 +1959,7 @@ def forward(self, x): return res if is_fp8: + # fp8_convert_ not support dynamic and qat yet assert not is_dynamic assert not is_qat @@ -2828,45 +2829,6 @@ def matcher_check_fn(): if test_for_pointwise_binary: self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) - @skipIfNoONEDNN - @parametrize("has_bias", [True, False]) - @parametrize("dtype", [torch.float32, torch.bfloat16]) - @parametrize("input_dim_exceeds_two", [True, False]) - @parametrize("check_reuse_input", [True, False]) - def test_fp8_qlinear( - self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input - ): - class Mod(torch.nn.Module): - def __init__(self, in_features, out_features, check_reuse_input): - super().__init__() - self.l0 = FP8QDQLinear(in_features, out_features, has_bias) - self.check_reuse_input = check_reuse_input - if self.check_reuse_input: - self.l1 = FP8QDQLinear(in_features, out_features, has_bias) - - def forward(self, x): - y = self.l0(x) - if self.check_reuse_input: - z = self.l1(x) - y = torch.cat([y, z]) - return y - - M1, M2, N, K = 2, 3, 13, 16 - M = M1 * M2 - mod = Mod(N, K, check_reuse_input) - if input_dim_exceeds_two: - v = torch.randn(M1, M2, N) - else: - v = torch.randn(M, N) - - def matcher_check_fn(): - counter = 2 if check_reuse_input else 1 - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], counter - ) - - self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) - @dynamo_config.patch( { diff --git 
a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index e5a76b22c5..b45a1118f7 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2514,113 +2514,114 @@ def _register_qlinear_unary_fusion(): _gelu_fusion_2 as _gelu_fusion_tanh, ) - for is_fp8 in [True, False]: - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: - is_bf16 = original_pattern_output_dtype == torch.bfloat16 - for x_scale_zp_are_tensors in (False, True): - qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) - computation_op = ( - torch.ops.onednn.qlinear_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qlinear_pointwise.default - ) - # Priority 1 to match: QLinear Unary pattern with int8 output - linear_unary_replace_patterns = { - PostOpAttr( - "none", None, "none", [], "" - ): generate_pattern_with_output_quant( - qlinear_pattern, - is_fp8=is_fp8, - ), - PostOpAttr( - "none", None, "relu", [], "" - ): generate_pattern_with_output_quant( - generate_pattern_with_unary(qlinear_pattern, aten.relu.default), - is_fp8=is_fp8, - ), - PostOpAttr( - "none", None, "gelu", [], "none" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 - ), - 2, - is_bf16, - ), - with_dtype_convert=is_bf16, - is_fp8=is_fp8, + combinations = itertools.product( + [torch.float32, torch.bfloat16], [False, True], [True, False] + ) + for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 - ), - 4, - is_bf16, - ), - with_dtype_convert=is_bf16, - is_fp8=is_fp8, + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), - } + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + is_fp8=is_fp8, + ), + } - for unary_attr, patterns in linear_unary_replace_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 3, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # 
unary_attr + ) - # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output - linear_unary_replace_float_out_patterns = { - PostOpAttr( - "none", None, "relu", [], "" - ): generate_pattern_with_unary(qlinear_pattern, aten.relu.default), - PostOpAttr( - "none", None, "gelu", [], "none" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 - ), - 2, - is_bf16, - ), - Arg(), - is_bf16, + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 - ), - 4, - is_bf16, - ), - Arg(), - is_bf16, + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), - } + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } - for ( - unary_attr, - patterns, - ) in linear_unary_replace_float_out_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 4, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for ( + unary_attr, + patterns, + ) in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) def _register_qlinear_binary_fusion(): @@ -2922,6 +2923,7 @@ def quant_input_check(node): len(node.all_input_nodes) == 2 and node.all_input_nodes[1].target == torch.tensor ) + return False for node in module_graph.nodes: # Leslie: Here we verify that the quant node has exactly From 7c3f9f9b46adedcb04fb6a262885329f1dbeaee6 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 22 Jul 2025 22:52:30 -0400 Subject: [PATCH 25/26] fix bugs --- test/quantization/pt2e/test_x86inductor_fusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 8ced99dbb1..a7a9fb58be 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -1969,7 +1969,7 @@ def forward(self, x): lambda x, y: x.add_(y), lambda x, y: y.add_(x), ] - fake_quant_x2_list = [False, True] if mixed_bf16 else [False] + fake_quant_x2_list = [False, True] if mixed_bf16 and not is_fp8 else [False] shape_list = [(4, 4), (4, 4, 4)] cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list) for add_fn, fq_x2, shape in cases: @@ -2041,6 +2041,7 @@ def matcher_check_fn(): [], check_quantization=True, num_include_ops=[2, 2], + is_fp8=is_fp8, ) else: # For python wrapper @@ -2054,6 +2055,7 @@ def matcher_check_fn(): [], check_quantization=True, num_include_ops=[2, 2], + is_fp8=is_fp8, ) @skipIfNoDynamoSupport From 38de0e9b4947829ada3a12c042f8f39f5457d9ba Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 22 Jul 2025 22:56:48 -0400 
Subject: [PATCH 26/26] add comment

---
 test/quantization/pt2e/test_x86inductor_fusion.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index a7a9fb58be..3b0bf2f8d6 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -211,6 +211,7 @@ def _generate_qdq_quantized_model(
     maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
     with maybe_no_grad:
         if is_fp8:
+            # fp8_convert_ does not support dynamic quantization or QAT yet
             assert not is_dynamic
             assert not is_qat
             fp8_convert_(mod)