Skip to content

Commit d375bbb

Browse files
committed
Bump version for float8 dynamic quant and weight only quant configs
Summary: This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2 and deprecates both the VERSION 1 configs and models quantized with VERSION 1; more details in #2649. It also extends the current config serialization to work with multiple config versions. Test Plan: serialized a model with a VERSION 1 config, loaded it, and checked that the warnings are properly printed: ``` python test/integration/test_loading_deprecated_checkpoint.py ``` Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2650, branch: jerryzh168/stack/14
1 parent ab9a1c3 commit d375bbb

File tree

7 files changed

+123
-40
lines changed

7 files changed

+123
-40
lines changed

test/core/test_config.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
import tempfile
10+
import warnings
1011
from dataclasses import dataclass
1112
from unittest import mock
1213

@@ -15,7 +16,6 @@
1516

1617
from torchao.core.config import (
1718
AOBaseConfig,
18-
VersionMismatchError,
1919
config_from_dict,
2020
config_to_dict,
2121
)
@@ -170,7 +170,7 @@ def test_disallowed_modules():
170170

171171

172172
def test_version_mismatch():
173-
"""Test that version mismatch raises an error during reconstruction."""
173+
"""Test that version mismatch prints a warning during reconstruction."""
174174
# Create a config
175175
dummy_config = DummyNonAllowedConfig()
176176
reconstructable = config_to_dict(dummy_config)
@@ -180,11 +180,13 @@ def test_version_mismatch():
180180

181181
# Patch to allow the module; reconstruction should still warn about the version mismatch
182182
with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
183-
with pytest.raises(
184-
VersionMismatchError,
185-
match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
186-
):
183+
with warnings.catch_warnings(record=True) as caught_warnings:
187184
config_from_dict(reconstructable)
185+
assert any(
186+
"Stored version is not the same as current default version of the config"
187+
in str(w.message)
188+
for w in caught_warnings
189+
), "Din't get expected warning message for version mismatch"
188190

189191

190192
def test_default_version():

test/dtypes/test_affine_quantized_float.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def test_mm_float8dq_per_row(
319319
)
320320
test_linear = copy.deepcopy(ref_linear)
321321
quantize_(
322-
test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
322+
test_linear,
323+
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
323324
)
324325

325326
quant_weight = test_linear.weight
@@ -471,7 +472,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
471472
# Create and quantize a model
472473
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
473474
quantize_(
474-
model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
475+
model,
476+
Float8DynamicActivationFloat8WeightConfig(
477+
granularity=granularity, VERSION=1
478+
),
475479
)
476480

477481
weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -505,7 +509,10 @@ def test_float8_tensor_slicing_per_tensor(self):
505509
# Create and quantize with per-tensor granularity
506510
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
507511
quantize_(
508-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
512+
model,
513+
Float8DynamicActivationFloat8WeightConfig(
514+
granularity=PerTensor(), VERSION=1
515+
),
509516
)
510517

511518
original_weight = model.weight
@@ -536,7 +543,8 @@ def test_float8_tensor_slicing_per_row(self):
536543
# Create and quantize with per-row granularity
537544
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
538545
quantize_(
539-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
546+
model,
547+
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
540548
)
541549

542550
original_weight = model.weight # Shape: (32, 64)
@@ -574,7 +582,10 @@ def test_float8_tensor_slicing_edge_cases(self):
574582
# Create and quantize a model
575583
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
576584
quantize_(
577-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
585+
model,
586+
Float8DynamicActivationFloat8WeightConfig(
587+
granularity=PerTensor(), VERSION=1
588+
),
578589
)
579590

580591
original_weight = model.weight
@@ -612,7 +623,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
612623
quant_model = copy.deepcopy(ref_model)
613624
quantize_(
614625
quant_model,
615-
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
626+
Float8DynamicActivationFloat8WeightConfig(
627+
granularity=granularity, VERSION=1
628+
),
616629
)
617630

618631
# Create input with batch size that works well with slicing

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,10 @@ def test_quantize(self):
477477
m = nn.Sequential(nn.Linear(32, 32)).cuda()
478478
m = convert_to_float8_training(m)
479479
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
480-
from torchao.quantization.quant_api import float8_weight_only, quantize_
480+
from torchao.quantization import Float8WeightOnlyConfig, quantize_
481481

482-
quantize_(m, float8_weight_only())
483-
assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
482+
quantize_(m, Float8WeightOnlyConfig())
483+
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
484484
"Post quantization dtype should be torch.float8_e4m3fn"
485485
)
486486
with torch.no_grad():
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import unittest
7+
import warnings
8+
9+
import torch
10+
from torch.testing._internal import common_utils
11+
from torch.testing._internal.common_utils import (
12+
TestCase,
13+
run_tests,
14+
)
15+
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
17+
from torchao.utils import is_sm_at_least_89
18+
19+
_MODEL_NAMES = [
20+
"torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
21+
]
22+
23+
24+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
25+
@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
26+
class TestLoadingDeprecatedCheckpoint(TestCase):
27+
@common_utils.parametrize("model_name", _MODEL_NAMES)
28+
def test_load_model_and_run(self, model_name):
29+
"""Test that we print correct warning message when loading a deprecated checkpoint"""
30+
# Load the already-quantized (deprecated) checkpoint and record any warnings emitted
31+
with warnings.catch_warnings(record=True) as caught_warnings:
32+
quantized_model = AutoModelForCausalLM.from_pretrained(
33+
model_name,
34+
torch_dtype="bfloat16",
35+
device_map="cuda",
36+
)
37+
assert any(
38+
"Stored version is not the same as current default version of the config"
39+
in str(w.message)
40+
for w in caught_warnings
41+
), "Din't get expected warning message for version mismatch"
42+
43+
assert any(
44+
"Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
45+
in str(w.message)
46+
for w in caught_warnings
47+
), "Din't get expected warning message for deprecation"
48+
49+
tokenizer = AutoTokenizer.from_pretrained(model_name)
50+
prompt = ("Hello, my name is",)
51+
inputs = tokenizer(
52+
prompt,
53+
return_tensors="pt",
54+
).to("cuda")
55+
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
56+
# make sure it runs
57+
_ = tokenizer.batch_decode(
58+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
59+
)
60+
61+
62+
common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
63+
64+
if __name__ == "__main__":
65+
run_tests()

torchao/core/config.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import enum
99
import importlib
1010
import json
11+
import warnings
1112
from typing import Any, ClassVar, Dict
1213

1314
import torch
1415

1516
__all__ = [
1617
"AOBaseConfig",
17-
"VersionMismatchError",
1818
"config_from_dict",
1919
"config_to_dict",
2020
"ALLOWED_AO_MODULES",
@@ -61,20 +61,6 @@ def _transform(
6161
VERSION: ClassVar[int] = _DEFAULT_VERSION
6262

6363

64-
class VersionMismatchError(Exception):
65-
"""Raised when trying to deserialize a config with a different version"""
66-
67-
def __init__(self, type_path, stored_version, current_version):
68-
self.type_path = type_path
69-
self.stored_version = stored_version
70-
self.current_version = current_version
71-
message = (
72-
f"Version mismatch for {type_path}: "
73-
f"stored version {stored_version} != current version {current_version}"
74-
)
75-
super().__init__(message)
76-
77-
7864
class ConfigJSONEncoder(json.JSONEncoder):
7965
"""Custom JSON encoder for AOBaseConfig objects"""
8066

@@ -91,7 +77,9 @@ def default(self, o):
9177
return {
9278
# Only store the class name, not the full module path
9379
"_type": o.__class__.__name__,
94-
"_version": getattr(o.__class__, "VERSION", 1),
80+
# not using class VERSION since we might be explicitly
81+
# setting a different VERSION for the object itself
82+
"_version": getattr(o, "VERSION", 1),
9583
"_data": data_dict,
9684
}
9785

@@ -105,7 +93,9 @@ def default(self, o):
10593

10694
return {
10795
"_type": o.__class__.__name__,
108-
"_version": getattr(o.__class__, "VERSION", 1),
96+
# not using class VERSION since we might be explicitly
97+
# setting a different VERSION for the object itself
98+
"_version": getattr(o, "VERSION", 1),
10999
"_data": processed_data,
110100
}
111101

@@ -120,7 +110,9 @@ def default(self, o):
120110
return {
121111
# Only store the class name for dataclasses too
122112
"_type": o.__class__.__name__,
123-
"_version": getattr(o.__class__, "VERSION", 1),
113+
# not using class VERSION since we might be explicitly
114+
# setting a different VERSION for the object itself
115+
"_version": getattr(o, "VERSION", 1),
124116
"_data": data_dict,
125117
}
126118

@@ -217,7 +209,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
217209
An instance of the appropriate AOBaseConfig subclass
218210
219211
Raises:
220-
VersionMismatchError: If the stored version doesn't match the class version
221212
ValueError: If deserialization fails for other reasons
222213
"""
223214
if not isinstance(data, dict):
@@ -252,10 +243,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
252243
f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
253244
)
254245

255-
# Check version - require exact match
256246
current_version = getattr(cls, "VERSION", 1)
257247
if stored_version != current_version:
258-
raise VersionMismatchError(type_path, stored_version, current_version)
248+
warnings.warn(
249+
f"Stored version is not the same as current default version of the config: {stored_version=}, {current_version=}, please check the deprecation warning"
250+
)
259251

260252
# Handle the case where obj_data is not a dictionary
261253
if not isinstance(obj_data, dict):

torchao/dtypes/floatx/float8_layout.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from dataclasses import dataclass
78
from typing import Any, Dict, List, Optional, Tuple, Union
89

@@ -109,6 +110,9 @@ def __init__(
109110
transposed: bool,
110111
_layout: Layout,
111112
):
113+
warnings.warn(
114+
"Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in March 2026 (9 months), please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
115+
)
112116
self.float8_data = float8_data
113117
self.scale = scale
114118
self.transposed = transposed

torchao/quantization/quant_api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,15 +1489,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
14891489
Args:
14901490
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
14911491
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
1492-
VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
1492+
VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
14931493
14941494
Note:
14951495
The actual matmul will be computed in original precision of the weight tensor.
14961496
"""
14971497

14981498
weight_dtype: torch.dtype = e4m3_dtype
14991499
set_inductor_config: bool = True
1500-
VERSION: int = 1
1500+
VERSION: int = 2
15011501

15021502

15031503
# for BC
@@ -1506,6 +1506,9 @@ class Float8WeightOnlyConfig(AOBaseConfig):
15061506

15071507
def _float8_weight_only_quant_tensor(weight, config):
15081508
if config.VERSION == 1:
1509+
warnings.warn(
1510+
"VERSION 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in March 2026 (9 months), please use VERSION 2, see https://github.com/pytorch/ao/issues/2649 for more details"
1511+
)
15091512
from torchao.dtypes import to_affine_quantized_floatx
15101513

15111514
block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
@@ -1629,7 +1632,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
16291632
activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale
16301633
kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by default (KernelPreference.AUTO) it will be chosen for user based on hardware or other information, this only needs to be set in weight
16311634
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
1632-
VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
1635+
VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
16331636
16341637
"""
16351638

@@ -1641,7 +1644,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
16411644
activation_value_ub: Optional[float] = None
16421645
kernel_preference: KernelPreference = KernelPreference.AUTO
16431646
set_inductor_config: bool = True
1644-
VERSION: int = 1
1647+
VERSION: int = 2
16451648

16461649
def __post_init__(self):
16471650
if self.mm_config is None:
@@ -1680,6 +1683,10 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
16801683
)
16811684

16821685
if config.VERSION == 1:
1686+
warnings.warn(
1687+
"VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in March 2026 (9 months), please use VERSION 2, see https://github.com/pytorch/ao/issues/2649 for more details"
1688+
)
1689+
16831690
block_size = get_block_size(weight.shape[-2:], weight_granularity)
16841691
if weight.dim() == 3:
16851692
block_size = tuple([1] + list(block_size))

0 commit comments

Comments
 (0)