Commit c464d5b
Bump version for float8 dynamic quant and weight only quant configs
Summary:

This PR changes the default VERSION of Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2, and deprecates the VERSION 1 configs and models quantized with them (more details in #2649). It also extends config serialization to work with multiple config versions.

Deprecation Note:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)
```

Loading such a checkpoint now prints:

```
/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details
  warnings.warn(
```

Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:

```
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Or download the checkpoint again (please let us know if the checkpoint has not been updated).

Test Plan:

Serialize a model with a VERSION 1 config, load it, and check that the warnings are printed properly:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2650, branch: jerryzh168/stack/14
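As context for the serialization change, here is a minimal sketch of the round trip that now warns instead of raising (assumes torchao 0.13+ with this PR applied; the `VERSION=1` keyword and the warning text are taken from the diffs below):

```python
import warnings

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# Serialize a config instance explicitly pinned to the deprecated VERSION 1.
old_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
stored = config_to_dict(old_config)  # {"_type": ..., "_version": 1, "_data": {...}}

# Deserializing under the new class default (VERSION 2) emits a warning
# instead of the old VersionMismatchError.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config_from_dict(stored)

assert any(
    "Stored version is not the same as current default version of the config"
    in str(w.message)
    for w in caught
)
```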
1 parent: ab9a1c3

File tree: 8 files changed (+130 −62 lines)

test/core/test_config.py

Lines changed: 8 additions & 6 deletions

```diff
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock

@@ -15,7 +16,6 @@

 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )
@@ -170,7 +170,7 @@ def test_disallowed_modules():


 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)
@@ -180,11 +180,13 @@ def test_version_mismatch():

     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"


 def test_default_version():
```
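An equivalent check could also be written with `pytest.warns` (an alternative phrasing, not what this PR uses):

```python
with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
    with pytest.warns(
        UserWarning,
        match="Stored version is not the same as current default version",
    ):
        config_from_dict(reconstructable)
```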

test/dtypes/test_affine_quantized_float.py

Lines changed: 19 additions & 6 deletions

```diff
@@ -319,7 +319,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )

         quant_weight = test_linear.weight
@@ -471,7 +472,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )

         weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -505,7 +509,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )

         original_weight = model.weight
@@ -536,7 +543,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )

         original_weight = model.weight  # Shape: (32, 64)
@@ -574,7 +582,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )

         original_weight = model.weight
@@ -612,7 +623,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )

         # Create input with batch size that works well with slicing
```

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -477,10 +477,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_

-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
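The two changes here track this PR's migration: the `float8_weight_only()` shortcut is replaced by `Float8WeightOnlyConfig`, and the VERSION 2 weight exposes its raw float8 payload as `qdata` instead of `tensor_impl.float8_data`. A minimal sketch of the new-style check (assumes CUDA; mirrors the test above):

```python
import torch
import torch.nn as nn

from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # class default is VERSION 2 as of this PR

# VERSION 2 tensor subclass: the float8 payload lives on `qdata`
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn
```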
test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 65 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.utils import is_sm_at_least_89

_MODEL_NAMES = [
    "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestLoadingDeprecatedCheckpoint(TestCase):
    @common_utils.parametrize("model_name", _MODEL_NAMES)
    def test_load_model_and_run(self, model_name):
        """Test that we print the correct warning messages when loading a deprecated checkpoint"""
        # Load and quantize model
        with warnings.catch_warnings(record=True) as caught_warnings:
            quantized_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="bfloat16",
                device_map="cuda",
            )
            assert any(
                "Stored version is not the same as current default version of the config"
                in str(w.message)
                for w in caught_warnings
            ), "Didn't get expected warning message for version mismatch"

            assert any(
                "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
                in str(w.message)
                for w in caught_warnings
            ), "Didn't get expected warning message for deprecation"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        prompt = ("Hello, my name is",)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to("cuda")
        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
        # make sure it runs
        _ = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )


common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)

if __name__ == "__main__":
    run_tests()
```
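To run this new integration test directly (as in the Test Plan above):

```
python test/integration/test_loading_deprecated_checkpoint.py
```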

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 7 additions & 22 deletions

```diff
@@ -177,7 +177,6 @@ def test_fp8_linear_variants(
             config = Float8DynamicActivationFloat8WeightConfig(
                 granularity=granularity,
                 kernel_preference=kernel_preference,
-                VERSION=2,
             )
         else:
             assert mode == "weight-only", f"Unsupported mode: {mode}"
@@ -198,9 +197,7 @@ def test_fp8_linear_variants(

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -261,9 +258,7 @@ def test_slice(self, granularity):

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
@@ -284,9 +279,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):

         dtype = torch.bfloat16
         device = "cuda"
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
         quantize_(l, config)

@@ -323,9 +316,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_bmm(self):
         # only support per row quantization
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(), VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -357,9 +348,7 @@ def forward(self, x):
         ],
     )
     def test_to_device(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         M, N, K = sizes
         dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
@@ -389,9 +378,7 @@ def test_to_device(self, granularity, sizes):
         ],
     )
     def test_cat(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         M, N, K = sizes
@@ -449,9 +436,7 @@ def test_moe_weight_reshape_ops(self):
         dtype = torch.bfloat16
         device = "cuda"

-        bmm_config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         moe_config = MoEQuantConfig(bmm_config)

         batch_size = 4
```

torchao/core/config.py

Lines changed: 13 additions & 21 deletions

```diff
@@ -8,13 +8,13 @@
 import enum
 import importlib
 import json
+import warnings
 from typing import Any, ClassVar, Dict

 import torch

 __all__ = [
     "AOBaseConfig",
-    "VersionMismatchError",
     "config_from_dict",
     "config_to_dict",
     "ALLOWED_AO_MODULES",
@@ -61,20 +61,6 @@ def _transform(
     VERSION: ClassVar[int] = _DEFAULT_VERSION


-class VersionMismatchError(Exception):
-    """Raised when trying to deserialize a config with a different version"""
-
-    def __init__(self, type_path, stored_version, current_version):
-        self.type_path = type_path
-        self.stored_version = stored_version
-        self.current_version = current_version
-        message = (
-            f"Version mismatch for {type_path}: "
-            f"stored version {stored_version} != current version {current_version}"
-        )
-        super().__init__(message)
-
-
 class ConfigJSONEncoder(json.JSONEncoder):
     """Custom JSON encoder for AOBaseConfig objects"""

@@ -91,7 +77,9 @@ def default(self, o):
         return {
             # Only store the class name, not the full module path
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": data_dict,
         }

@@ -105,7 +93,9 @@ def default(self, o):

         return {
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": processed_data,
         }

@@ -120,7 +110,9 @@ def default(self, o):
         return {
             # Only store the class name for dataclasses too
             "_type": o.__class__.__name__,
-            "_version": getattr(o.__class__, "VERSION", 1),
+            # not using class VERSION since we might be explicitly
+            # setting a different VERSION for the object itself
+            "_version": getattr(o, "VERSION", 1),
             "_data": data_dict,
         }

@@ -217,7 +209,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
         An instance of the appropriate AOBaseConfig subclass

     Raises:
-        VersionMismatchError: If the stored version doesn't match the class version
         ValueError: If deserialization fails for other reasons
     """
     if not isinstance(data, dict):
@@ -252,10 +243,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
         )

-    # Check version - require exact match
     current_version = getattr(cls, "VERSION", 1)
     if stored_version != current_version:
-        raise VersionMismatchError(type_path, stored_version, current_version)
+        warnings.warn(
+            f"Stored version is not the same as current default version of the config: {stored_version=}, {current_version=}, please check the deprecation warning"
+        )

     # Handle the case where obj_data is not a dictionary
     if not isinstance(obj_data, dict):
```
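The encoder change from `getattr(o.__class__, "VERSION", 1)` to `getattr(o, "VERSION", 1)` is what lets one class serialize multiple config versions: an instance that pins `VERSION=1` is stored as version 1 even though the class default is now 2. A small sketch (assumes the config accepts `VERSION` as shown in the test diffs, and that `config_to_dict` returns the `_type`/`_version`/`_data` dict produced by this encoder):

```python
from torchao.core.config import config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

pinned = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
default = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Instance-level VERSION wins over the class default in the stored payload.
assert config_to_dict(pinned)["_version"] == 1
assert config_to_dict(default)["_version"] == 2
```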

torchao/dtypes/floatx/float8_layout.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -109,6 +110,9 @@ def __init__(
         transposed: bool,
         _layout: Layout,
     ):
+        warnings.warn(
+            "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
        self.float8_data = float8_data
        self.scale = scale
        self.transposed = transposed
```
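Since the warning is emitted in the tensor impl's `__init__`, it fires whenever a VERSION 1 float8 weight is materialized: both when quantizing fresh with `VERSION=1` and when loading an old checkpoint. A hedged sketch of catching it (model and config values illustrative; assumes CUDA):

```python
import warnings

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(64, 32, bias=False).cuda().to(torch.bfloat16)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    quantize_(
        model,
        Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
    )

assert any(
    "VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
    in str(w.message)
    for w in caught
)
```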
