From 25aa861f41d20ae911d39fd361647b48d655ce35 Mon Sep 17 00:00:00 2001
From: Gauri Sahnan
Date: Fri, 18 Jul 2025 10:28:27 +0000
Subject: [PATCH 1/6] [Feat] Restore and validate KleidiAI INT4 quantization
 path using updated quantizer API

- Switched to quantize_() with Int8DynamicActivationIntxWeightConfig
- Validated the move of packed_linear_int8_dynamic_activation_intx_weight_layout.py to torchao/dtypes/uintx
- Fixed handling of the SYMMETRIC_NO_CLIPPING_ERR mapping type
- Validated the INT4 path on a 2-layer nn.Sequential model with torch.int4 weights
- Compared SYMMETRIC vs SYMMETRIC_NO_CLIPPING_ERR across PerAxis and PerGroup granularities
---
 torchao/experimental/docs/readme.md | 32 ---------------------------
 torchao/quantization/README.md      | 34 +++++++++++++++++++++++++++++
 torchao/quantization/quant_api.py   |  8 ++++++-
 3 files changed, 41 insertions(+), 33 deletions(-)

diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md
index a178c9b328..0f61a89c0f 100644
--- a/torchao/experimental/docs/readme.md
+++ b/torchao/experimental/docs/readme.md
@@ -96,38 +96,6 @@ quantize_(
         layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() is also supported, but much slower on CPU
     ),
 )
-```
-
-KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows:
-
-```python
-from torchao.dtypes import PlainLayout
-from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
-    PackedLinearInt8DynamicActivationIntxWeightLayout,
-)
-from torchao.experimental.quant_api import (
-    int8_dynamic_activation_intx_weight,
-)
-from torchao.quantization.granularity import (
-    PerGroup,
-    PerRow,
-)
-from torchao.quantization.quant_api import quantize_
-from torchao.quantization.quant_primitives import MappingType
-
-my_model = Model()
-
-quantize_(
-    my_model,
-    int8_dynamic_activation_intx_weight(
-        weight_dtype=torch.int4,
-        granularity=PerGroup(32), # PerRow() is also supported
-        has_weight_zeros=True, # Should be True
-        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR # MappingType.SYMMETRIC can also be used but increases error
-        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
-    ),
-)
-```
 
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
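For readers tracking the API move, the removed experimental call above maps almost argument-for-argument onto the config-based API that the quantization README diff below introduces. A hedged sketch of that correspondence (my own mapping, not part of the patch; it assumes the config's remaining parameters have usable defaults):

```python
import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.quant_primitives import MappingType

# Old experimental call (removed above)  ->  new config (added below):
#   weight_dtype=torch.int4              ->  weight_dtype=torch.int4
#   granularity=PerGroup(32)             ->  weight_granularity=PerGroup(32)
#   has_weight_zeros=True                ->  no direct kwarg; zero-point handling
#                                            follows weight_mapping_type
#   weight_mapping_type=...              ->  weight_mapping_type=...
#   layout=...(target="aten")            ->  layout=...(target=Target.ATEN)
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
)
```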
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 83caffdc09..86809d204a 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -205,6 +205,40 @@ quantize_(model, FPXWeightOnlyConfig(3, 2))
 ```
 
 You can find more information [here](../dtypes/floatx/README.md). It should be noted that, where most other TorchAO APIs and benchmarks have focused on applying techniques on top of a bf16 model, fp6 works primarily with the fp16 dtype.
+
+KleidiAI Int4 kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+import torch
+
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    quantize_,
+)
+from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+    Target,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+from torchao.quantization.quant_primitives import MappingType
+
+my_model = Model()
+
+# Set quantization layout
+layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN)
+
+quantize_(
+    my_model,
+    Int8DynamicActivationIntxWeightConfig(
+        weight_dtype=torch.int4,
+        weight_granularity=PerGroup(32),  # PerAxis(0) is also supported
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC can also be used but increases error
+        weight_scale_dtype=torch.float32,
+        layout=layout,
+    ),
+)
+```
+
 ## Affine Quantization Details
 Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualifies as Affine Quantization.
 

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 71ea0abe41..ece3fa6555 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -917,6 +917,12 @@ def _int8_dynamic_activation_intx_weight_transform(
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
 
     # We quantize with QDQLayout, and then construct the packed weight tensor later
+    # Set preserve_zero based on the weight mapping type (symmetric variants require an exact zero)
+    preserve_zero = weight_mapping_type in [
+        MappingType.SYMMETRIC,
+        MappingType.SYMMETRIC_NO_CLIPPING_ERR,
+    ]
+
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=weight_mapping_type,
@@ -926,7 +932,7 @@ def _int8_dynamic_activation_intx_weight_transform(
         quant_max=quant_max,
         scale_dtype=weight_scale_dtype,
         zero_point_dtype=torch.int8,
-        preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC),
+        preserve_zero=preserve_zero,
         zero_point_domain=ZeroPointDomain.INT,
         _layout=QDQLayout(),
     )
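The README comment above ("MappingType.SYMMETRIC can also be used but increases error") and the preserve_zero change both hinge on how the two symmetric mapping types pick their scale. A self-contained sketch of that difference (my own simplification to a single int4 channel, not torchao's exact qparam code):

```python
import torch

# One int4 channel with an asymmetric value range; hypothetical weights.
w = torch.tensor([-1.0, 0.2, 0.9])
qmin, qmax = -8, 7  # int4 qvalue bounds

# SYMMETRIC: one scale from the absolute max over a symmetric span of
# (qmax - qmin) / 2 = 7.5 steps, so one end can land outside [qmin, qmax].
scale_sym = w.abs().max() / ((qmax - qmin) / 2)

# SYMMETRIC_NO_CLIPPING_ERR: smallest scale such that neither the most
# negative nor the most positive value clips (assumes w spans both signs).
scale_nce = torch.maximum(w.min() / qmin, w.max() / qmax)

def quant_dequant(w, scale):
    q = torch.clamp(torch.round(w / scale), qmin, qmax)
    return q * scale

for name, scale in [("SYMMETRIC", scale_sym), ("SYMMETRIC_NO_CLIPPING_ERR", scale_nce)]:
    err = (w - quant_dequant(w, scale)).abs().max().item()
    print(f"{name:>26}: scale={scale.item():.4f} max_abs_err={err:.4f}")
```

Running this prints a smaller maximum error for SYMMETRIC_NO_CLIPPING_ERR, which is the trade-off the README comment alludes to.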
From e9746dff4d330680e4b8b3d3ab4f6f9b04b9c31c Mon Sep 17 00:00:00 2001
From: Gauri Sahnan
Date: Tue, 22 Jul 2025 19:22:49 +0000
Subject: [PATCH 2/6] [Fix]: Allow "SYMMETRIC_NO_CLIPPING_ERR" in
 Int8DynamicActivationIntxWeightConfig

---
 LICENSE                           | 2 ++
 torchao/quantization/quant_api.py | 5 +++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/LICENSE b/LICENSE
index 56f4d62a47..44018e4daf 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,6 @@
 Copyright 2023 Meta
+All contributions by Arm:
+Copyright (c) 2024-2025 Arm Limited and/or its affiliates
 
 Redistribution and use in source and binary forms, with or without modification, are permitted
 provided that the following conditions are met:

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ece3fa6555..2afbf85aa4 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and affiliates.
 # All rights reserved.
-
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
@@ -862,8 +862,9 @@ def __post_init__(self):
         assert self.weight_mapping_type in [
             MappingType.ASYMMETRIC,
             MappingType.SYMMETRIC,
+            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
         ], (
-            f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.weight_mapping_type}"
+            f"weight_mapping_type must be one of MappingType.ASYMMETRIC, MappingType.SYMMETRIC, or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}"
         )
         assert self.act_mapping_type in [
             MappingType.ASYMMETRIC,

From 88cb2652423bde80556e0f4e2a059dc5f5974b11 Mon Sep 17 00:00:00 2001
From: Gauri Sahnan
Date: Fri, 25 Jul 2025 10:22:34 +0000
Subject: [PATCH 3/6] [FEAT]: Add SYMMETRIC_NO_CLIPPING_ERR to tests

---
 .../tests/test_int8_dynamic_activation_intx_weight.py | 2 ++
 torchao/quantization/quant_api.py                      | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
index 08548b9e9e..5d61e58186 100644
--- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
+++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
@@ -54,6 +55,7 @@ class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
         for weight_mapping_type in [
             MappingType.SYMMETRIC,
             MappingType.ASYMMETRIC,
+            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
         ]
         for weight_granularity in [
             PerGroup(128),
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 2afbf85aa4..ab820193b8 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -923,7 +923,7 @@ def _int8_dynamic_activation_intx_weight_transform(
         MappingType.SYMMETRIC,
         MappingType.SYMMETRIC_NO_CLIPPING_ERR,
     ]
-    
+
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=weight_mapping_type,

From 14d8ba6945902889a3f3f2a89ab2ab40b2decac7 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Fri, 25 Jul 2025 17:37:08 -0700
Subject: [PATCH 4/6] Update test_int8_dynamic_activation_intx_weight.py

---
 .../tests/test_int8_dynamic_activation_intx_weight.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
index 5d61e58186..29420ad702 100644
--- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
+++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -73,6 +73,9 @@ def test_accuracy(
     """
     Checks the accuracy of packed layouts
     """
+    if weight_dtype == torch.int1 and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR:
+        return
+
     m = 3
     n = 1071
     k = 2048
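The skip added above is not explained in the commit, but a plausible reading (an assumption on my part, based on intx qvalue bounds of the form [-2^(b-1), 2^(b-1)-1]) is that torch.int1 has quant_max == 0, which makes the no-clipping scale formula degenerate:

```python
# Sketch under the assumption that torch.int1 maps to qvalue bounds (-1, 0):
# SYMMETRIC_NO_CLIPPING_ERR picks scale = max(min_val / qmin, max_val / qmax),
# which divides by zero when qmax == 0 -- hence the skipped test combination.
qmin, qmax = -1, 0            # 1-bit two's-complement range
min_val, max_val = -1.0, 0.9  # hypothetical channel extrema
smin = min_val / qmin         # fine: 1.0
try:
    smax = max_val / qmax     # ZeroDivisionError
except ZeroDivisionError:
    print("no-clipping scale undefined for int1 (quant_max == 0)")
```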
From c4c1e50256ffb799e12196d8a9b37de9c81208b0 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Fri, 25 Jul 2025 18:34:03 -0700
Subject: [PATCH 5/6] Update test_int8_dynamic_activation_intx_weight.py

---
 .../tests/test_int8_dynamic_activation_intx_weight.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
index 29420ad702..4f2991362c 100644
--- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
+++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -73,7 +73,10 @@ def test_accuracy(
     """
     Checks the accuracy of packed layouts
     """
-    if weight_dtype == torch.int1 and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR:
+    if (
+        weight_dtype == torch.int1
+        and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR
+    ):
         return
 
     m = 3

From 3b255a232bc326a2e1ecc986fa092f7492c700d0 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Fri, 25 Jul 2025 22:01:20 -0700
Subject: [PATCH 6/6] Update test_int8_dynamic_activation_intx_weight.py

---
 .../tests/test_int8_dynamic_activation_intx_weight.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
index 4f2991362c..5cba538068 100644
--- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
+++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -78,7 +78,7 @@ def test_accuracy(
         and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR
     ):
         return
-    
+
     m = 3
     n = 1071
     k = 2048
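Taken together, the series enables the path that patch 1's commit message describes. A minimal repro sketch of that validation, using exactly the kwargs from the README diff above (assumes an aarch64 machine with PyTorch 2.6.0 or later so that Target.ATEN dispatches to the KleidiAI kernels):

```python
# Quantize a 2-layer nn.Sequential with torch.int4 weights and run a forward
# pass, mirroring the validation named in patch 1's commit message.
import torch
import torch.nn as nn

from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    Target,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType

# Feature sizes chosen to be divisible by the group size of 32.
model = nn.Sequential(nn.Linear(512, 256), nn.Linear(256, 64))

quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
        weight_scale_dtype=torch.float32,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN),
    ),
)

with torch.no_grad():
    out = model(torch.randn(1, 512))
print(out.shape)  # torch.Size([1, 64])
```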