Commit e19cb46
Bump version for float8 dynamic quant and weight only quant configs

Summary: This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2, and deprecates the VERSION 1 configs and VERSION 1 quantized models; more details in #2649. It also extends the current config serialization to work with multiple config versions.

Test Plan: Serialize a model with a VERSION 1 config, load it back, and check that the deprecation warnings are properly printed:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent: fdcb0c4

File tree: 7 files changed, +120 −34 lines

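To make the serialization change concrete, here is a minimal sketch of the round trip, not part of this commit; it assumes the imports below resolve and that `config_to_dict`/`config_from_dict` accept and return the `{"_type", "_version", "_data"}` payload built in `torchao/core/config.py` further down.

```python
# Minimal sketch; assumes config_to_dict returns the
# {"_type", "_version", "_data"} payload shown in torchao/core/config.py.
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

# The default is now VERSION 2; VERSION=1 opts back into the deprecated
# AffineQuantizedTensor-based path.
old_style = Float8DynamicActivationFloat8WeightConfig(VERSION=1)

data = config_to_dict(old_style)
assert data["_version"] == 1  # the instance VERSION wins over the class default (2)

# Deserializing a stored VERSION 1 config now emits a warning instead of
# raising VersionMismatchError, which this commit removes.
restored = config_from_dict(data)
```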

test/dtypes/test_affine_quantized_float.py

Lines changed: 19 additions & 6 deletions

```diff
@@ -319,7 +319,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )
 
         quant_weight = test_linear.weight
@@ -471,7 +472,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )
 
         weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -505,7 +509,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )
 
         original_weight = model.weight
@@ -536,7 +543,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )
 
         original_weight = model.weight  # Shape: (32, 64)
@@ -574,7 +582,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )
 
         original_weight = model.weight
@@ -612,7 +623,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )
 
         # Create input with batch size that works well with slicing
```
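The edits above all pin `VERSION=1` so the existing AffineQuantizedTensor assertions keep passing. For regular users the practical effect is the inverse, sketched below (a minimal sketch assuming a CUDA device and a torchao build that includes this commit):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(64, 32, bias=False).to("cuda").to(torch.bfloat16)

# New default: VERSION 2, backed by Float8Tensor.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# The old behavior (as pinned in the tests above) now needs VERSION=1
# explicitly and emits a deprecation warning:
#   Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
```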

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -477,10 +477,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_
 
-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
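The rewritten assert reflects where the raw float8 bytes live in each version: `weight.tensor_impl.float8_data` on the VERSION 1 AffineQuantizedTensor versus `weight.qdata` on the VERSION 2 Float8Tensor. A minimal sketch of the new access path (assumes a CUDA device):

```python
import torch
import torch.nn as nn
from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # VERSION 2 is now the default

# VERSION 2 (Float8Tensor) exposes the float8 payload as `qdata`;
# under VERSION 1 the same bytes lived at `weight.tensor_impl.float8_data`.
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn
```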
test/integration/test_loading_deprecated_checkpoint.py (new file; path from the test plan above)

Lines changed: 67 additions & 0 deletions

```diff
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+import warnings
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torchao.utils import is_sm_at_least_89
+
+_MODEL_NAMES = [
+    "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
+]
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+class TestLoadingDeprecatedCheckpoint(TestCase):
+    @common_utils.parametrize("model_name", _MODEL_NAMES)
+    def test_load_model_and_run(self, model_name):
+        """Test that we print the correct warning message when loading a deprecated checkpoint"""
+        # Load and quantize model
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            quantized_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+            got_expected_message = any(
+                "Stored version is not the same as current default version of the config"
+                in str(w.message)
+                for w in caught_warnings
+            )
+            assert got_expected_message, "Didn't get expected message"
+
+            got_expected_message = any(
+                "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig are deprecated"
+                in str(w.message)
+                for w in caught_warnings
+            )
+            assert got_expected_message, "Didn't get expected message"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+
+if __name__ == "__main__":
+    run_tests()
```

torchao/core/config.py

Lines changed: 13 additions & 21 deletions

```diff
@@ -8,13 +8,13 @@
 import enum
 import importlib
 import json
+import warnings
 from typing import Any, ClassVar, Dict
 
 import torch
 
 __all__ = [
     "AOBaseConfig",
-    "VersionMismatchError",
     "config_from_dict",
     "config_to_dict",
     "ALLOWED_AO_MODULES",
@@ -50,20 +50,6 @@ def _transform(
     VERSION: ClassVar[int] = 1
 
 
-class VersionMismatchError(Exception):
-    """Raised when trying to deserialize a config with a different version"""
-
-    def __init__(self, type_path, stored_version, current_version):
-        self.type_path = type_path
-        self.stored_version = stored_version
-        self.current_version = current_version
-        message = (
-            f"Version mismatch for {type_path}: "
-            f"stored version {stored_version} != current version {current_version}"
-        )
-        super().__init__(message)
-
-
 class ConfigJSONEncoder(json.JSONEncoder):
     """Custom JSON encoder for AOBaseConfig objects"""
 
@@ -80,7 +66,9 @@ def default(self, o):
         return {
             # Only store the class name, not the full module path
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": data_dict,
         }
 
@@ -94,7 +82,9 @@ def default(self, o):
 
         return {
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": processed_data,
         }
 
@@ -109,7 +99,9 @@ def default(self, o):
         return {
             # Only store the class name for dataclasses too
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": data_dict,
         }
 
@@ -206,7 +198,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
         An instance of the appropriate AOBaseConfig subclass
 
     Raises:
-        VersionMismatchError: If the stored version doesn't match the class version
         ValueError: If deserialization fails for other reasons
     """
     if not isinstance(data, dict):
@@ -241,10 +232,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
         )
 
-    # Check version - require exact match
     current_version = getattr(cls, "VERSION", 1)
     if stored_version != current_version:
-        raise VersionMismatchError(type_path, stored_version, current_version)
+        warnings.warn(
+            f"Stored version is not the same as current default version of the config: {stored_version=}, {current_version=}, please check the deprecation warning"
+        )
 
     # Handle the case where obj_data is not a dictionary
     if not isinstance(obj_data, dict):
```
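The encoder now reads `VERSION` off the instance rather than the class, because an instance can pin a version different from the class default. A self-contained sketch with a hypothetical config class:

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:  # hypothetical stand-in for an AOBaseConfig subclass
    VERSION: int = 2  # class-level default, like the bumped configs


pinned_old = DemoConfig(VERSION=1)

# Instance lookup sees the explicitly pinned version...
assert getattr(pinned_old, "VERSION", 1) == 1
# ...while a class lookup always reports the default and would mis-serialize it.
assert getattr(DemoConfig, "VERSION", 1) == 2
```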

torchao/dtypes/floatx/float8_layout.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -109,6 +110,9 @@ def __init__(
         transposed: bool,
         _layout: Layout,
     ):
+        warnings.warn(
+            "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig are deprecated and will no longer be supported in March 2026 (9 months), please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
         self.float8_data = float8_data
         self.scale = scale
         self.transposed = transposed
```

torchao/quantization/quant_api.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -1489,15 +1489,15 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
-        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
+        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
     """
 
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
-    VERSION: int = 1
+    VERSION: int = 2
 
 
 # for BC
@@ -1506,6 +1506,9 @@
 
 def _float8_weight_only_quant_tensor(weight, config):
     if config.VERSION == 1:
+        warnings.warn(
+            "VERSION 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in March 2026 (9 months), please use VERSION 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
         from torchao.dtypes import to_affine_quantized_floatx
 
         block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
@@ -1629,7 +1632,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
         activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale
         kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc., by default (KernelPreference.AUTO) it will be chosen for the user based on hardware or other information, this only needs to be set in weight
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
-        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
+        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
 
     """
 
@@ -1641,7 +1644,7 @@
     activation_value_ub: Optional[float] = None
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
-    VERSION: int = 1
+    VERSION: int = 2
 
     def __post_init__(self):
         if self.mm_config is None:
@@ -1680,6 +1683,10 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         )
 
     if config.VERSION == 1:
+        warnings.warn(
+            "VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in March 2026 (9 months), please use VERSION 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
+
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
```
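Both transform functions follow the same dispatch shape: warn and keep the legacy path on VERSION 1, otherwise take the new Float8Tensor path. A simplified sketch of that pattern; `_quantize_v1` and `_quantize_v2` are hypothetical stand-ins for the real code paths:

```python
import warnings


def _quantize_v1(weight, config):
    """Hypothetical stand-in for the AffineQuantizedTensor (VERSION 1) path."""
    raise NotImplementedError


def _quantize_v2(weight, config):
    """Hypothetical stand-in for the Float8Tensor (VERSION 2) path."""
    raise NotImplementedError


def _float8_quantize_tensor(weight, config):
    # Legacy path: keep working for now, but tell the user it is going away.
    if config.VERSION == 1:
        warnings.warn(
            "VERSION 1 is deprecated and will no longer be supported in "
            "March 2026, please use VERSION 2, see "
            "https://github.com/pytorch/ao/issues/2649 for more details"
        )
        return _quantize_v1(weight, config)
    return _quantize_v2(weight, config)
```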

torchao/quantization/utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -681,3 +681,6 @@ def recommended_inductor_config_setter():
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.triton.unique_kernel_names = True
     torch.set_float32_matmul_precision("high")
+
+
+AQT_FLOAT8_DEPRECATION_WARNING = ""
```
