Commit e66af4f
Bump version for float8 dynamic quant and weight only quant configs
Summary: This PR changes the default version for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2, and deprecates the version 1 configs and models quantized with them; more details in #2649. It also extends the current config serialization to work with multiple config versions.

Deprecation note: loading a version 1 checkpoint now emits warnings:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)
```

```
/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details
  warnings.warn(
```

Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:

```
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Or download the checkpoint again (please let us know if the checkpoint has not been updated).

Test Plan: serialize a model with a version 1 config, load it, and check that the warnings are printed correctly:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2650, branch: jerryzh168/stack/14
1 parent 3b4bc98 · commit e66af4f
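For reference, a minimal sketch of the upgrade path suggested above: load the bfloat16 base model, quantize it with the now-default version 2 config through transformers' TorchAoConfig, and save a fresh checkpoint. The base model name and output directory are placeholders, and this assumes torchao 0.13+ with a recent transformers build:

```
# Hypothetical upgrade script: re-quantize with the version 2 config
# (now the default) and write a new checkpoint.
from transformers import AutoModelForCausalLM, TorchAoConfig

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

base_model = "facebook/opt-125m"  # placeholder base model
quantized = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="bfloat16",
    device_map="cuda",
    # omitting `version` picks up the new default, version 2
    quantization_config=TorchAoConfig(
        quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    ),
)
# torchao tensor subclasses are not safetensors-serializable,
# hence safe_serialization=False
quantized.save_pretrained("opt-125m-float8dq-row-v2", safe_serialization=False)
```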

File tree: 8 files changed, +196 −102 lines changed

test/core/test_config.py

Lines changed: 15 additions & 11 deletions
```
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock
 
@@ -15,7 +16,6 @@
 
 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )
@@ -151,7 +151,9 @@ def test_reconstructable_dict_file_round_trip(config):
 # Define a dummy config in a non-allowed module
 @dataclass
 class DummyNonAllowedConfig(AOBaseConfig):
-    VERSION = 2
+    # NOTE: must be `version: int` (with type annotations) to
+    # overload the version variable from AOBaseConfig
+    version: int = 2
     value: int = 42
 
 
@@ -172,11 +174,11 @@ def test_disallowed_modules():
         reconstructed = config_from_dict(reconstructable)
         assert isinstance(reconstructed, DummyNonAllowedConfig)
         assert reconstructed.value == 42
-        assert reconstructed.VERSION == 2
+        assert reconstructed.version == 2
 
 
 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)
@@ -186,25 +188,27 @@ def test_version_mismatch():
 
     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+            assert any(
+                "Stored version is not the same as current default version of the config"
+                in str(w.message)
+                for w in caught_warnings
+            ), "Didn't get expected warning message for version mismatch"
 
 
 def test_default_version():
     """Making sure the default version for a new config inheriting from AOBaseConfig is always 1
-    because it's the default VERSION that all children has when they haven't explicitly
-    defined a VERSION class variable
+    because it's the default version that all children has when they haven't explicitly
+    defined a version class variable
     """
 
     @dataclass
     class DummyConfig(AOBaseConfig):
         pass
 
     config = DummyConfig()
-    assert config.VERSION == 1, "Default version must be 1"
+    assert config.version == 1, "Default version must be 1"
 
 
 if __name__ == "__main__":
```
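The NOTE in the hunk above is load-bearing: in a dataclass, a bare assignment without a type annotation creates a plain class attribute, not a field, so it cannot override the `version` field inherited from the base config. A self-contained illustration with a stand-in base class (not the real AOBaseConfig):

```
from dataclasses import dataclass


@dataclass
class DummyBase:  # stand-in for AOBaseConfig
    version: int = 1


@dataclass
class WithoutAnnotation(DummyBase):
    version = 2  # plain class attribute; NOT a dataclass field override


@dataclass
class WithAnnotation(DummyBase):
    version: int = 2  # annotated, so it redefines the field's default


assert WithoutAnnotation().version == 1  # __init__ still uses the base default
assert WithAnnotation().version == 2
```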

test/dtypes/test_affine_quantized_float.py

Lines changed: 52 additions & 24 deletions
```
@@ -30,17 +30,14 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import (
     PerRow,
     PerTensor,
 )
-from torchao.quantization.quant_api import (
-    float8_static_activation_float8_weight,
-)
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_scale_float8,
@@ -119,11 +116,13 @@ def test_fp8_linear_variants(
         )
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=granularity
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=granularity,
+                version=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, version=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=scale,
                 granularity=granularity,
             ),
@@ -152,7 +151,7 @@
     )
     def test_invalid_granularity(self):
        with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            Float8DynamicActivationFloat8WeightConfig(granularity="invalid")
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +161,9 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerTensor(), PerRow())
+            )
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +173,8 @@ class UnsupportedGranularity:
             pass
 
         with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -187,7 +188,8 @@ def test_per_row_with_float32(self):
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -201,15 +203,18 @@ def test_serialization(self, mode: str):
 
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=PerTensor(),
+                version=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, version=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
                 granularity=PerTensor(),
             ),
         }
+
         factory = mode_map[mode]()
         quantize_(model, factory)
 
@@ -275,7 +280,10 @@ def test_fp8_weight_dimension_warning(self):
             "torchao.quantization.quant_api", level="INFO"
         ) as log_context:
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), version=1
+                ),
             )
         print(model)
 
@@ -320,7 +328,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
         )
 
         quant_weight = test_linear.weight
@@ -472,7 +481,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
         )
 
         weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -506,7 +518,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), version=1
+            ),
         )
 
         original_weight = model.weight
@@ -537,7 +552,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
         )
 
         original_weight = model.weight  # Shape: (32, 64)
@@ -575,7 +591,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), version=1
+            ),
         )
 
         original_weight = model.weight
@@ -613,7 +632,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
         )
 
         # Create input with batch size that works well with slicing
@@ -720,6 +741,7 @@ def test_preprocess_scale_3d_reshape(self):
         self.assertEqual(result.shape, expected_shape)
 
     @torch.no_grad()
+    @unittest.skip("test is flaky in CI, will turn on a bit later")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
         not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
@@ -743,7 +765,13 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
         m = torch.nn.Sequential(
             torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
         )
-        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        quantize_(
+            m,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
+        )
+
         m = torch.compile(m, mode=torch_compile_mode)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
```

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions
```
@@ -473,10 +473,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_
 
-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
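This hunk also shows where the raw float8 payload moved: version 1 wraps the weight in an AffineQuantizedTensor reached through `weight.tensor_impl.float8_data`, while the version 2 tensor subclass exposes it directly as `weight.qdata`. A short sketch of the new access path, assuming a CUDA build of torchao 0.13+ on sm89+ hardware:

```
import torch

from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # default is now version=2

# Version 2 tensors carry the float8 payload on `qdata`.
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn
```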
test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 70 additions & 0 deletions
```
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+import warnings
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+
+from torchao.utils import is_sm_at_least_89
+
+_MODEL_NAME_AND_VERSIONS = [
+    ("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
+]
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
+class TestLoadingDeprecatedCheckpoint(TestCase):
+    @common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
+    def test_load_model_and_run(self, model_name_and_version):
+        """Test that we print correct warning message when loading a deprecated checkpoint"""
+        # Load and quantize model
+        model_name, version = model_name_and_version
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            quantized_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+            assert any(
+                "Stored version is not the same as current default version of the config"
+                in str(w.message)
+                for w in caught_warnings
+            ), "Didn't get expected warning message for version mismatch"
+
+            assert any(
+                "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
+                in str(w.message)
+                for w in caught_warnings
+            ), "Didn't get expected warning message for deprecation"
+        assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
+        assert (
+            quantized_model.config.quantization_config.quant_type.version == version
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+
+if __name__ == "__main__":
+    run_tests()
```
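Beyond checkpoint loading, the extended multi-version serialization can be exercised directly: `config_to_dict` records the config's `version`, and `config_from_dict` warns rather than raises when the stored version differs from the current default. A sketch of that round trip, with the warning behavior inferred from the messages quoted in the commit summary:

```
import warnings

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8WeightOnlyConfig

stored = config_to_dict(Float8WeightOnlyConfig(version=1))  # embeds "version": 1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    restored = config_from_dict(stored)

assert restored.version == 1  # the stored version is preserved on load
assert any(
    "Stored version is not the same as current default version" in str(w.message)
    for w in caught
)
```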
