Skip to content

Commit 99a86bc

Browse files
committed
Bump version for float8 dynamic quant and weight only quant configs
Summary: This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2 and makes the VERSION 1 config and VERSION 1 quantized models deprecated, more details in: #2649 Also extended current config serialization to work with multiple config versions Deprecation Note: ``` from transformers import AutoModelForCausalLM, AutoTokenizer model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev" quantized_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="bfloat16", device_map="cuda", ) /data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning warnings.warn( /data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details warnings.warn( ``` Suggestion: upgrade torchao to 0.13 and later and generate the checkpoint again: ``` quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) ``` Or download the checkpoint again (please let us know if the checkpoint is not updated) Test Plan: tested with serializing a model with VERSION 1 config and load it, and checks warnings are properly printed ``` python test/integration/test_loading_deprecated_checkpoint.py ``` Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2650, branch: jerryzh168/stack/14
1 parent cb65283 commit 99a86bc

File tree

8 files changed

+130
-62
lines changed

8 files changed

+130
-62
lines changed

test/core/test_config.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
import tempfile
10+
import warnings
1011
from dataclasses import dataclass
1112
from unittest import mock
1213

@@ -15,7 +16,6 @@
1516

1617
from torchao.core.config import (
1718
AOBaseConfig,
18-
VersionMismatchError,
1919
config_from_dict,
2020
config_to_dict,
2121
)
@@ -176,7 +176,7 @@ def test_disallowed_modules():
176176

177177

178178
def test_version_mismatch():
179-
"""Test that version mismatch raises an error during reconstruction."""
179+
"""Test that version mismatch prints a warning during reconstruction."""
180180
# Create a config
181181
dummy_config = DummyNonAllowedConfig()
182182
reconstructable = config_to_dict(dummy_config)
@@ -186,11 +186,13 @@ def test_version_mismatch():
186186

187187
# Patch to allow the module but should still fail due to version mismatch
188188
with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
189-
with pytest.raises(
190-
VersionMismatchError,
191-
match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
192-
):
189+
with warnings.catch_warnings(record=True) as caught_warnings:
193190
config_from_dict(reconstructable)
191+
assert any(
192+
"Stored version is not the same as current default version of the config"
193+
in str(w.message)
194+
for w in caught_warnings
195+
), "Din't get expected warning message for version mismatch"
194196

195197

196198
def test_default_version():

test/dtypes/test_affine_quantized_float.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ def test_mm_float8dq_per_row(
320320
)
321321
test_linear = copy.deepcopy(ref_linear)
322322
quantize_(
323-
test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
323+
test_linear,
324+
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
324325
)
325326

326327
quant_weight = test_linear.weight
@@ -472,7 +473,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
472473
# Create and quantize a model
473474
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
474475
quantize_(
475-
model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
476+
model,
477+
Float8DynamicActivationFloat8WeightConfig(
478+
granularity=granularity, VERSION=1
479+
),
476480
)
477481

478482
weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -506,7 +510,10 @@ def test_float8_tensor_slicing_per_tensor(self):
506510
# Create and quantize with per-tensor granularity
507511
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
508512
quantize_(
509-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
513+
model,
514+
Float8DynamicActivationFloat8WeightConfig(
515+
granularity=PerTensor(), VERSION=1
516+
),
510517
)
511518

512519
original_weight = model.weight
@@ -537,7 +544,8 @@ def test_float8_tensor_slicing_per_row(self):
537544
# Create and quantize with per-row granularity
538545
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
539546
quantize_(
540-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
547+
model,
548+
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
541549
)
542550

543551
original_weight = model.weight # Shape: (32, 64)
@@ -575,7 +583,10 @@ def test_float8_tensor_slicing_edge_cases(self):
575583
# Create and quantize a model
576584
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
577585
quantize_(
578-
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
586+
model,
587+
Float8DynamicActivationFloat8WeightConfig(
588+
granularity=PerTensor(), VERSION=1
589+
),
579590
)
580591

581592
original_weight = model.weight
@@ -613,7 +624,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
613624
quant_model = copy.deepcopy(ref_model)
614625
quantize_(
615626
quant_model,
616-
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
627+
Float8DynamicActivationFloat8WeightConfig(
628+
granularity=granularity, VERSION=1
629+
),
617630
)
618631

619632
# Create input with batch size that works well with slicing

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,10 +473,10 @@ def test_quantize(self):
473473
m = nn.Sequential(nn.Linear(32, 32)).cuda()
474474
m = convert_to_float8_training(m)
475475
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
476-
from torchao.quantization.quant_api import float8_weight_only, quantize_
476+
from torchao.quantization import Float8WeightOnlyConfig, quantize_
477477

478-
quantize_(m, float8_weight_only())
479-
assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
478+
quantize_(m, Float8WeightOnlyConfig())
479+
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
480480
"Post quantization dtype should be torch.float8_e4m3fn"
481481
)
482482
with torch.no_grad():
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import unittest
7+
import warnings
8+
9+
import torch
10+
from torch.testing._internal import common_utils
11+
from torch.testing._internal.common_utils import (
12+
TestCase,
13+
run_tests,
14+
)
15+
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
17+
from torchao.utils import is_sm_at_least_89
18+
19+
_MODEL_NAMES = [
20+
"torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
21+
]
22+
23+
24+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
25+
@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
26+
class TestLoadingDeprecatedCheckpoint(TestCase):
27+
@common_utils.parametrize("model_name", _MODEL_NAMES)
28+
def test_load_model_and_run(self, model_name):
29+
"""Test that we print correct warning message when loading a deprecated checkpoint"""
30+
# Load and quantize model
31+
with warnings.catch_warnings(record=True) as caught_warnings:
32+
quantized_model = AutoModelForCausalLM.from_pretrained(
33+
model_name,
34+
torch_dtype="bfloat16",
35+
device_map="cuda",
36+
)
37+
assert any(
38+
"Stored version is not the same as current default version of the config"
39+
in str(w.message)
40+
for w in caught_warnings
41+
), "Din't get expected warning message for version mismatch"
42+
43+
assert any(
44+
"Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
45+
in str(w.message)
46+
for w in caught_warnings
47+
), "Din't get expected warning message for deprecation"
48+
49+
tokenizer = AutoTokenizer.from_pretrained(model_name)
50+
prompt = ("Hello, my name is",)
51+
inputs = tokenizer(
52+
prompt,
53+
return_tensors="pt",
54+
).to("cuda")
55+
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
56+
# make sure it runs
57+
_ = tokenizer.batch_decode(
58+
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
59+
)
60+
61+
62+
common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
63+
64+
if __name__ == "__main__":
65+
run_tests()

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ def test_fp8_linear_variants(
184184
config = Float8DynamicActivationFloat8WeightConfig(
185185
granularity=granularity,
186186
kernel_preference=kernel_preference,
187-
VERSION=2,
188187
)
189188
else:
190189
assert mode == "weight-only", f"Unsupported mode: {mode}"
@@ -210,9 +209,7 @@ def test_fp8_linear_variants(
210209
"AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
211210
)
212211
def test_slice(self, granularity):
213-
config = Float8DynamicActivationFloat8WeightConfig(
214-
granularity=granularity, VERSION=2
215-
)
212+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
216213
dtype = torch.bfloat16
217214
device = "cuda"
218215
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -273,9 +270,7 @@ def test_slice(self, granularity):
273270

274271
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
275272
def test_slice_preserves_aliasing(self, granularity):
276-
config = Float8DynamicActivationFloat8WeightConfig(
277-
granularity=granularity, VERSION=2
278-
)
273+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
279274
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
280275
l.weight = torch.nn.Parameter(
281276
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
@@ -296,9 +291,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
296291

297292
dtype = torch.bfloat16
298293
device = "cuda"
299-
config = Float8DynamicActivationFloat8WeightConfig(
300-
granularity=granularity, VERSION=2
301-
)
294+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
302295
l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
303296
quantize_(l, config)
304297

@@ -335,9 +328,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
335328
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
336329
def test_bmm(self):
337330
# only support per row quantization
338-
config = Float8DynamicActivationFloat8WeightConfig(
339-
granularity=PerRow(), VERSION=2
340-
)
331+
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
341332

342333
class M(torch.nn.Module):
343334
def __init__(self, weight):
@@ -369,9 +360,7 @@ def forward(self, x):
369360
],
370361
)
371362
def test_to_device(self, granularity, sizes):
372-
config = Float8DynamicActivationFloat8WeightConfig(
373-
granularity=granularity, VERSION=2
374-
)
363+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
375364
M, N, K = sizes
376365
dtype = torch.bfloat16
377366
for device in self.GPU_DEVICES:
@@ -401,9 +390,7 @@ def test_to_device(self, granularity, sizes):
401390
],
402391
)
403392
def test_cat(self, granularity, sizes):
404-
config = Float8DynamicActivationFloat8WeightConfig(
405-
granularity=granularity, VERSION=2
406-
)
393+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
407394
dtype = torch.bfloat16
408395
device = "cuda"
409396
M, N, K = sizes
@@ -461,9 +448,7 @@ def test_moe_weight_reshape_ops(self):
461448
dtype = torch.bfloat16
462449
device = "cuda"
463450

464-
bmm_config = Float8DynamicActivationFloat8WeightConfig(
465-
granularity=granularity, VERSION=2
466-
)
451+
bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
467452
moe_config = MoEQuantConfig(bmm_config)
468453

469454
batch_size = 4

torchao/core/config.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import enum
99
import importlib
1010
import json
11+
import warnings
1112
from typing import Any, ClassVar, Dict
1213

1314
import torch
1415

1516
__all__ = [
1617
"AOBaseConfig",
17-
"VersionMismatchError",
1818
"config_from_dict",
1919
"config_to_dict",
2020
"ALLOWED_AO_MODULES",
@@ -61,20 +61,6 @@ def _transform(
6161
VERSION: ClassVar[int] = _DEFAULT_VERSION
6262

6363

64-
class VersionMismatchError(Exception):
65-
"""Raised when trying to deserialize a config with a different version"""
66-
67-
def __init__(self, type_path, stored_version, current_version):
68-
self.type_path = type_path
69-
self.stored_version = stored_version
70-
self.current_version = current_version
71-
message = (
72-
f"Version mismatch for {type_path}: "
73-
f"stored version {stored_version} != current version {current_version}"
74-
)
75-
super().__init__(message)
76-
77-
7864
class ConfigJSONEncoder(json.JSONEncoder):
7965
"""Custom JSON encoder for AOBaseConfig objects"""
8066

@@ -91,7 +77,9 @@ def default(self, o):
9177
return {
9278
# Only store the class name, not the full module path
9379
"_type": o.__class__.__name__,
94-
"_version": getattr(o.__class__, "VERSION", 1),
80+
# not using class VERSION since we might be explicitly
81+
# setting a different VERSION for the object itself
82+
"_version": getattr(o, "VERSION", 1),
9583
"_data": data_dict,
9684
}
9785

@@ -105,7 +93,9 @@ def default(self, o):
10593

10694
return {
10795
"_type": o.__class__.__name__,
108-
"_version": getattr(o.__class__, "VERSION", 1),
96+
# not using class VERSION since we might be explicitly
97+
# setting a different VERSION for the object itself
98+
"_version": getattr(o, "VERSION", 1),
10999
"_data": processed_data,
110100
}
111101

@@ -120,7 +110,9 @@ def default(self, o):
120110
return {
121111
# Only store the class name for dataclasses too
122112
"_type": o.__class__.__name__,
123-
"_version": getattr(o.__class__, "VERSION", 1),
113+
# not using class VERSION since we might be explicitly
114+
# setting a different VERSION for the object itself
115+
"_version": getattr(o, "VERSION", 1),
124116
"_data": data_dict,
125117
}
126118

@@ -218,7 +210,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
218210
An instance of the appropriate AOBaseConfig subclass
219211
220212
Raises:
221-
VersionMismatchError: If the stored version doesn't match the class version
222213
ValueError: If deserialization fails for other reasons
223214
"""
224215
if not isinstance(data, dict):
@@ -253,10 +244,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
253244
f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
254245
)
255246

256-
# Check version - require exact match
257247
current_version = getattr(cls, "VERSION", 1)
258248
if stored_version != current_version:
259-
raise VersionMismatchError(type_path, stored_version, current_version)
249+
warnings.warn(
250+
f"Stored version is not the same as current default version of the config: {stored_version=}, {current_version=}, please check the deprecation warning"
251+
)
260252

261253
# Handle the case where obj_data is not a dictionary
262254
if not isinstance(obj_data, dict):

torchao/dtypes/floatx/float8_layout.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from dataclasses import dataclass
78
from typing import Any, Dict, List, Optional, Tuple, Union
89

@@ -109,6 +110,9 @@ def __init__(
109110
transposed: bool,
110111
_layout: Layout,
111112
):
113+
warnings.warn(
114+
"Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
115+
)
112116
self.float8_data = float8_data
113117
self.scale = scale
114118
self.transposed = transposed

0 commit comments

Comments
 (0)