Commit 5016603
Bump version for float8 dynamic quant and weight only quant configs
Summary: This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2, and deprecates the VERSION 1 configs and VERSION 1 quantized models; more details in #2649. It also extends the current config serialization to work with multiple config versions.

Deprecation Note: loading a VERSION 1 checkpoint now prints warnings:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)
```

```
/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details
  warnings.warn(
```

Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:

```
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Or download the checkpoint again (please let us know if the checkpoint is not updated).

Test Plan: tested by serializing a model with the VERSION 1 config, loading it, and checking that the warnings are properly printed:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2650, branch: jerryzh168/stack/14
1 parent f426b52 commit 5016603
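For reference, a minimal sketch of the upgrade path described in the Suggestion above. The toy module is illustrative; `quantize_`, `Float8DynamicActivationFloat8WeightConfig`, and `PerRow` are the torchao APIs exercised throughout this PR:

```python
# Sketch: quantize with the new default (VERSION 2) config.
# The toy module below is illustrative; any nn.Module with Linear layers works.
import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# After this PR, VERSION defaults to 2, so no explicit VERSION argument is
# needed; passing VERSION=1 opts back into the deprecated flow.
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```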

File tree

8 files changed: +141 −75 lines

test/core/test_config.py

Lines changed: 8 additions & 6 deletions

```diff
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock

@@ -15,7 +16,6 @@

 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )

@@ -176,7 +176,7 @@ def test_disallowed_modules():


 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)

@@ -186,11 +186,13 @@ def test_version_mismatch():

     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"


 def test_default_version():
```
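A sketch of the round trip this test exercises, assuming `config_to_dict` records the instance's `VERSION` (as the stored_version/current_version fields in the warning suggest), so that deserializing a `VERSION=1` config under the new default of 2 warns instead of raising `VersionMismatchError`:

```python
import warnings

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

# Serialize a config pinned to the deprecated version.
old_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
stored = config_to_dict(old_config)

# Reconstruction now warns about the version mismatch instead of raising.
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    restored = config_from_dict(stored)
assert any(
    "Stored version is not the same as current default version of the config"
    in str(w.message)
    for w in caught_warnings
)
```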

test/dtypes/test_affine_quantized_float.py

Lines changed: 30 additions & 19 deletions

```diff
@@ -30,8 +30,6 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
     quantize_,
 )
 from torchao.quantization.granularity import (

@@ -119,9 +117,9 @@ def test_fp8_linear_variants(
         )
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=granularity
+                Float8DynamicActivationFloat8WeightConfig, granularity=granularity, VERSION=1
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, VERSION=1),
             "static": partial(
                 float8_static_activation_float8_weight,
                 scale=scale,

@@ -152,7 +150,7 @@
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            Float8DynamicActivationFloat8WeightConfig(granularity="invalid", VERSION=1)

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"

@@ -162,7 +160,7 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            Float8DynamicActivationFloat8WeightConfig(granularity=(PerTensor(), PerRow()), VERSION=1)

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"

@@ -172,8 +170,8 @@ class UnsupportedGranularity:
             pass

         with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity()), VERSION=1
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

@@ -187,7 +185,7 @@ def test_per_row_with_float32(self):
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+                model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

@@ -201,9 +199,9 @@ def test_serialization(self, mode: str):

         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+                Float8DynamicActivationFloat8WeightConfig, granularity=PerTensor(), VERSION=1
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, VERSION=1),
             "static": partial(
                 float8_static_activation_float8_weight,
                 scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),

@@ -275,7 +273,7 @@ def test_fp8_weight_dimension_warning(self):
             "torchao.quantization.quant_api", level="INFO"
         ) as log_context:
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
+                model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor(), VERSION=1)
             )
             print(model)

@@ -320,7 +318,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )

         quant_weight = test_linear.weight

@@ -472,7 +471,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )

         weight_impl = model.weight.original_weight_tensor.tensor_impl

@@ -506,7 +508,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )

         original_weight = model.weight

@@ -537,7 +542,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )

         original_weight = model.weight  # Shape: (32, 64)

@@ -575,7 +581,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )

         original_weight = model.weight

@@ -613,7 +622,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )

         # Create input with batch size that works well with slicing

@@ -742,7 +753,7 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
         m = torch.nn.Sequential(
             torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
         )
-        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity, VERSION=1))
         m = torch.compile(m, mode=torch_compile_mode)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
```
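These tests pin `VERSION=1` so the deprecated path stays covered until it is removed. In isolation that pinning looks like the sketch below, assuming the `float8_layout.py` deprecation warning quoted in the commit message fires whenever a VERSION 1 float8 weight is constructed:

```python
import warnings

import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

m = torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    # Pinning VERSION=1 opts back into the deprecated quantization flow.
    quantize_(
        m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
    )
assert any(
    "VERSION 1 of Float8DynamicActivationFloat8WeightConfig" in str(w.message)
    for w in caught_warnings
)
```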

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -473,10 +473,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_

-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
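The hunk above also captures the tensor-subclass change between versions: the VERSION 1 path reached the raw float8 payload via `weight.tensor_impl.float8_data`, while the VERSION 2 tensor exposes it as `weight.qdata`. A minimal sketch mirroring the updated assertion:

```python
import torch
import torch.nn as nn

from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # VERSION 2 is now the default
# The VERSION 2 tensor subclass stores the raw float8 payload in `qdata`.
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn
```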
test/integration/test_loading_deprecated_checkpoint.py (new file)

Lines changed: 65 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.utils import is_sm_at_least_89

_MODEL_NAMES = [
    "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestLoadingDeprecatedCheckpoint(TestCase):
    @common_utils.parametrize("model_name", _MODEL_NAMES)
    def test_load_model_and_run(self, model_name):
        """Test that we print the correct warning message when loading a deprecated checkpoint"""
        # Load and quantize model
        with warnings.catch_warnings(record=True) as caught_warnings:
            quantized_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="bfloat16",
                device_map="cuda",
            )
        assert any(
            "Stored version is not the same as current default version of the config"
            in str(w.message)
            for w in caught_warnings
        ), "Didn't get expected warning message for version mismatch"

        assert any(
            "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
            in str(w.message)
            for w in caught_warnings
        ), "Didn't get expected warning message for deprecation"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        prompt = ("Hello, my name is",)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to("cuda")
        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
        # make sure it runs
        _ = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )


common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)

if __name__ == "__main__":
    run_tests()
```
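As noted in the Test Plan, this file runs standalone:

```
python test/integration/test_loading_deprecated_checkpoint.py
```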

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 7 additions & 22 deletions

```diff
@@ -184,7 +184,6 @@ def test_fp8_linear_variants(
             config = Float8DynamicActivationFloat8WeightConfig(
                 granularity=granularity,
                 kernel_preference=kernel_preference,
-                VERSION=2,
             )
         else:
             assert mode == "weight-only", f"Unsupported mode: {mode}"

@@ -210,9 +209,7 @@
         "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
     )
     def test_slice(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)

@@ -273,9 +270,7 @@ def test_slice(self, granularity):

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")

@@ -296,9 +291,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):

         dtype = torch.bfloat16
         device = "cuda"
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
         quantize_(l, config)

@@ -335,9 +328,7 @@
     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         # only support per row quantization
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(), VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

         class M(torch.nn.Module):
             def __init__(self, weight):

@@ -369,9 +360,7 @@ def forward(self, x):
         ],
     )
     def test_to_device(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         M, N, K = sizes
         dtype = torch.bfloat16
         for device in self.GPU_DEVICES:

@@ -401,9 +390,7 @@ def test_to_device(self, granularity, sizes):
         ],
     )
     def test_cat(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         M, N, K = sizes

@@ -461,9 +448,7 @@ def test_moe_weight_reshape_ops(self):
         dtype = torch.bfloat16
         device = "cuda"

-        bmm_config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         moe_config = MoEQuantConfig(bmm_config)

         batch_size = 4
```
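The deletions above are the payoff of the version bump: explicit `VERSION=2` arguments become redundant once 2 is the default. Assuming `VERSION` is an ordinary config field, as its keyword usage in these tests suggests, the two spellings are interchangeable:

```python
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

explicit = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=2)
implicit = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
assert explicit.VERSION == implicit.VERSION == 2
```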
