
Commit 8fb7215

Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig
Summary:

We will:
* deprecate FbgemmConfig since it's a single kernel (later)
* categorize things by derived dtype + packed format, e.g. int4 preshuffled, float8 plain
* Added PackingFormat that has preshuffled, plain in Version 2 of Int4WeightOnlyConfig; the older AQT tensor will remain in Version 1

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2474, branch: jerryzh168/stack/10
1 parent 1506c0d commit 8fb7215
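For reference, the new configs exercised by the tests in this commit are applied with quantize_ roughly as follows. This is a minimal sketch only: the model is a placeholder, and it assumes a CUDA device with the required fbgemm int4 kernels available.

import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# Placeholder module; any model containing nn.Linear layers works the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# bf16 activation + int4 weight, plain packing (new Version 2 path of Int4WeightOnlyConfig)
quantize_(model, Int4WeightOnlyConfig(group_size=128, packing_format="plain", VERSION=2))

# Alternatives shown in the tests below (each would be applied to a fresh model):
#   Int4WeightOnlyConfig(group_size=128, packing_format="preshuffled", VERSION=2)
#   Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled")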

File tree: 13 files changed (+239, -139 lines)


docs/source/api_ref_quantization.rst

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ Inference APIs for quantize\_
     :nosignatures:
 
     Int4WeightOnlyConfig
+    Float8ActivationInt4WeightConfig
     Float8DynamicActivationFloat8WeightConfig
     Float8WeightOnlyConfig
     Float8StaticActivationFloat8WeightConfig

test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 4 additions & 1 deletion
@@ -26,7 +26,9 @@
 class TestLoadingDeprecatedCheckpoint(TestCase):
     @common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
     def test_load_model_and_run(self, model_name_and_version):
-        """Test that we print correct warning message when loading a deprecated checkpoint"""
+        """Test that we print correct warning message when loading a deprecated checkpoint
+        and making sure the deprecated checkpoints can still be loaded
+        """
         # Load and quantize model
         model_name, version = model_name_and_version
         with warnings.catch_warnings(record=True) as caught_warnings:
@@ -41,6 +43,7 @@ def test_load_model_and_run(self, model_name_and_version):
             for w in caught_warnings
         ), "Didn't get expected warning message for version mismatch"
 
+        # TODO: generalize when we test more checkpoints
         assert any(
             "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
             in str(w.message)

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 13 additions & 41 deletions
@@ -15,9 +15,9 @@
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -27,44 +27,16 @@
     is_sm_at_least_90,
 )
 
-if TORCH_VERSION_AT_LEAST_2_8:
-    BF16_ACT_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    BF16_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-else:
-    BF16_ACT_CONFIG = None
-    BF16_ACT_BMM_CONFIG = None
-    FP8_ACT_CONFIG = None
-    FP8_ACT_BMM_CONFIG = None
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    packing_format="preshuffled",
+    VERSION=2,
+)
+
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    packing_format="preshuffled",
+)
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +62,7 @@ def test_linear(self, config):
 
     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):

test/dtypes/test_fbgemm_int4.py renamed to test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 12 additions & 25 deletions
@@ -13,7 +13,7 @@
 )
 
 from torchao.quantization import (
-    FbgemmConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-class TestFbgemmInt4Tensor(TestCase):
+class TestInt4Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
+        self.config = Int4WeightOnlyConfig(
+            group_size=128,
+            packing_format="plain",
+            VERSION=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -68,13 +61,9 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(
-            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
-        )
+        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(
-            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
-        )
+        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
 
         # check for sliced weight, before and after float8 quantization
@@ -100,12 +89,10 @@ def test_slice_and_copy_(self):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert (
-            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
-        )
+        assert param.data._data.data_ptr() == param_data._data.data_ptr()
        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data.packed_weight[0][0].item()
+        orig_value = param.data._data[0][0].item()
 
         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
@@ -116,7 +103,7 @@ def test_slice_and_copy_(self):
         param_data.copy_(quantized)
 
         # making sure param.data is updated
-        assert param.data.packed_weight[0][0] != orig_value
+        assert param.data._data[0][0] != orig_value
 
     def test_bmm(self):
         class M(torch.nn.Module):
@@ -135,7 +122,7 @@ def forward(self, x):
         original = m(input)
         # we need to transpose the weight first for bmm
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 

torchao/dtypes/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -9,7 +9,6 @@
     to_affine_quantized_intx_static,
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
-from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,8 +63,6 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",

torchao/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -90,6 +91,7 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
     "ModuleFqnToConfig",
     "FbgemmConfig",
     # tensor subclasses
+    "Int4Tensor",
     "Int4PreshuffledTensor",
     "Float8Tensor",
     # smooth quant - subject to change

torchao/quantization/quant_api.py

Lines changed: 70 additions & 3 deletions
@@ -49,7 +49,6 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -71,10 +70,12 @@
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
 from torchao.quantization.quantize_.common import (
     KernelPreference,
+    PackingFormat,
 )
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -1119,6 +1120,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
         `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
+        `packing_format`: the packing format for int4 tensor, available from VERSION 2 and above
     """
 
     group_size: int = 128
@@ -1127,6 +1129,9 @@
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    # only used in VERSION >= 2
+    packing_format: PackingFormat = PackingFormat.PLAIN
+    VERSION: int = 1
 
 
 # for BC
@@ -1144,15 +1149,36 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    packing_format = config.packing_format
 
     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight
 
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if config.VERSION == 2:
+        if packing_format == PackingFormat.PRESHUFFLED:
+            new_weight = Int4PreshuffledTensor.from_float(
+                weight,
+                block_size,
+                activation_dtype=torch.bfloat16,
+            )
+            return new_weight
+        elif packing_format == PackingFormat.PLAIN:
+            new_weight = Int4Tensor.from_float(
+                weight,
+                block_size,
+            )
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    assert config.VERSION == 1
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1224,6 +1250,46 @@ def _int4_weight_only_transform(
     return module
 
 
+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    """Configuration for apply float8 dynamic per row quantization and int4
+    per group weight quantization to linear
+
+    Args:
+        `group_size`: group size for groupwise quantization for weight
+        `packing_format`: how the weight is packed, only preshuffled is supported
+    """
+
+    group_size: int = 128
+    packing_format: PackingFormat = "preshuffled"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _float8_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying int8 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    group_size = config.group_size
+    packing_format = config.packing_format
+
+    assert packing_format == "preshuffled", (
+        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    )
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype=torch.float8_e4m3fn,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -1677,6 +1743,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     # TODO(future PR): this should really throw an exception instead of silently
     # not doing what the user asked
     return weight
+
     if isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
@@ -2145,7 +2212,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
         )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4Tensor.from_float(
             module.weight,
             config.block_size,
         )
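
One detail worth noting in the quant_api.py change above: the new Version 2 path no longer takes an explicit block_size the way FbgemmConfig did; it derives it from the weight's dimensionality and group_size, which is why the separate bmm configs in the tests collapse into the regular ones. A small illustration of that expression (a standalone sketch; shapes and group_size are arbitrary examples):

import torch

group_size = 128
w2d = torch.empty(1024, 1024)      # ordinary linear weight
w3d = torch.empty(8, 1024, 1024)   # batched weight used in the bmm tests

# Same expression as in _int4_weight_only_quantize_tensor above
print(tuple([1 for _ in range(w2d.ndim - 1)] + [group_size]))  # (1, 128)
print(tuple([1 for _ in range(w3d.ndim - 1)] + [group_size]))  # (1, 1, 128)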

torchao/quantization/quantize_/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .kernel_preference import KernelPreference
2+
from .packing_format import PackingFormat
23
from .quantize_tensor_kwargs import (
34
QuantizeTensorKwargs,
45
_choose_quant_func_and_quantize_tensor,
@@ -7,5 +8,6 @@
78
__all__ = [
89
"QuantizeTensorKwargs",
910
"KernelPreference",
11+
"PackingFormat",
1012
"_choose_quant_func_and_quantize_tensor",
1113
]
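
The new packing_format.py module itself is not among the hunks shown here. Judging from how PackingFormat is used in quant_api.py (string literals such as "preshuffled" are compared against PackingFormat.PRESHUFFLED, and the configs accept either form), it is presumably a string-valued enum along these lines; this is an inferred sketch, not the actual definition:

from enum import Enum

class PackingFormat(str, Enum):
    # Inferred sketch: how the quantized int4 weight is laid out in memory.
    PLAIN = "plain"
    PRESHUFFLED = "preshuffled"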
