Commit cfa2abc

Add: Failing test
Refactor: modify_save_pretrained

Signed-off-by: Rahul Tuli <[email protected]>

1 parent: 90c4075

File tree: 2 files changed, +210 additions, -102 deletions

  src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
  tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 143 additions & 101 deletions
@@ -1,6 +1,6 @@
+import inspect
 import os
 import re
-import weakref
 from functools import wraps
 from typing import Dict, Optional

@@ -33,115 +33,101 @@
 __all__ = ["modify_save_pretrained"]


-def modify_save_pretrained(model: PreTrainedModel):
+def modify_save_pretrained(model: PreTrainedModel) -> None:
     """
     Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
-    supports compression. The new save_pretrained function performs the following saving
-    operations:
+    also supports compression params. The modified save_pretrained function performs the
+    following operations:

     1. Saves the model state, potentially in a compressed format
     2. Saves the recipe, appending any current recipes to existing recipe files
     3. Copies any necessary python files from the model cache
-    """

-    def save_pretrained_compressed(save_pretrained_method):
-        if getattr(save_pretrained_method, "_overridden", False):
-            # `model.save_pretrained` has already been replaced, return.
-            return save_pretrained_method
-
-        # Keep a weak reference to the model class and unbound save_pretrained
-        # method so we can call the original
-        model_ref = weakref.ref(save_pretrained_method.__self__)
-        original_save_pretrained = save_pretrained_method.__func__
-        model_class = model_ref().__class__
-        del save_pretrained_method
-
-        @wraps(original_save_pretrained)
-        def save_pretrained_wrapper(
-            save_directory: str,
-            sparsity_config: Optional[SparsityCompressionConfig] = None,
-            quantization_format: Optional[str] = None,
-            save_compressed: bool = True,
-            safe_serialization: bool = True,
-            skip_sparsity_compression_stats: bool = True,
-            disable_sparse_compression: bool = False,
-            **kwargs,
-        ):
-            """
-            Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
-            saving models in a compressed format on disk. The compression format is
-            saved to the model's config file
-
-            :param save_directory: output directory to save model to
-            :param sparsity_config: optional sparsity config to compress model with,
-                if no config is provided it will be inferred from the model
-            :param quantization_format: optional compression format for quantized
-                models. If none is provided it will be inferred from the model
-            :param save_compressed: whether or not to compress the model on disk
-            :param skip_sparsity_compression_stats: whether to skip the calculation of
-                sparsity statistics (such as global sparsity and sparsity structure)
-                when saving a model in dense format
-            :param disable_sparse_compression: whether to skip sparse compression
-                during save, default is False
-            :param kwargs: additional kwargs to pass on to model.save_pretrained
-            """
-
-            # HACK: Override the dtype_byte_size function in transformers to
-            # support float8 types. Fix is posted upstream
-            # https://github.com/huggingface/transformers/pull/30488
-            transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-            # state_dict gets passed in as a kwarg for FSDP models
-            state_dict = kwargs.pop("state_dict", None)
-            if state_dict is None:
-                logger.info("Fetching state_dict - this may take some time")
-                state_dict = get_state_dict_offloaded_model(model)
-
-            logger.info("Fetching compressor")
-            compressor = get_model_compressor(
-                model=model,
-                sparsity_config=sparsity_config,
-                quantization_format=quantization_format,
-                save_compressed=save_compressed,
-                skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-                state_dict=state_dict,
-                disable_sparse_compression=disable_sparse_compression,
+    :param model: The model whose save_pretrained method will be modified
+    """
+    original = model.save_pretrained
+    # Avoid double-wrapping if already modified
+    if getattr(original, "_overridden", False):
+        return
+
+    # Create enhanced signature with compression parameters
+    orig_sig = inspect.signature(original)
+    sig_with_compression_params = _create_compression_signature(orig_sig)
+
+    @wraps(original)
+    def save_pretrained_wrapper(
+        *args,
+        sparsity_config: Optional[SparsityCompressionConfig] = None,
+        quantization_format: Optional[str] = None,
+        save_compressed: bool = True,
+        skip_sparsity_compression_stats: bool = True,
+        disable_sparse_compression: bool = False,
+        **kwargs,
+    ):
+        """
+        Wrapper around PreTrainedModel.save_pretrained() that adds compression
+        functionality. The compression format is saved to the model's config file
+
+        NOTE: If adding parameters here, also update _create_compression_signature()
+        to maintain signature consistency.
+
+        :param sparsity_config: Optional sparsity compression configuration.
+            If None and `skip_sparsity_compression_stats` is False, a sparsity
+            config will be inferred from the model.
+        :param quantization_format: Optional format string for quantization
+        :param save_compressed: Whether to save the model in compressed format
+        :param skip_sparsity_compression_stats: Whether to skip calculating
+            sparsity stats.
+        :param disable_sparse_compression: Whether to disable sparse compression
+            entirely
+        """
+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
+        # Extract save_directory from args or kwargs
+        save_directory = args[0] if args else kwargs.get("save_directory")
+        if save_directory is None:
+            raise ValueError(
+                "`save_directory` must be provided as first positional arg or kwarg"
             )

-            if compressor is None:
-                # model is not compressed or quantized, save as normal
-                original_save_pretrained_func = original_save_pretrained.__get__(
-                    model, model_class
-                )
-                original_save_pretrained_func(
-                    save_directory, state_dict=state_dict, **kwargs
-                )
-                return
-
-            # make sure we're on the main process when saving
-            if state_dict is not None and len(state_dict) > 0:
-                compressed_state_dict = compressor.compress(model, state_dict)
-                logger.info("Saving compressed model to disk")
-                original_save_pretrained.__get__(model, model_class)(
-                    save_directory,
-                    state_dict=compressed_state_dict,
-                    safe_serialization=safe_serialization,
-                    **kwargs,
-                )
-                compressor.update_config(save_directory)
-
-            # update existing recipe
-            update_and_save_recipe(model.name_or_path, save_directory)
-
-            # copy python files from cache dir to save_path if any
-            copy_python_files_from_model_cache(model, save_directory)
-
-        save_pretrained_wrapper._overriden = True
-        return save_pretrained_wrapper
-
-    # wrap save_pretrained if not already
-    if not getattr(model.save_pretrained, "_overriden", False):
-        model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
+        # Get state_dict or fetch it if not provided
+        state_dict = kwargs.pop("state_dict", None)
+        if state_dict is None:
+            logger.info("Fetching state_dict - this may take some time")
+            state_dict = get_state_dict_offloaded_model(model)
+
+        logger.info("Fetching compressor")
+        compressor = get_model_compressor(
+            model=model,
+            sparsity_config=sparsity_config,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+            state_dict=state_dict,
+            disable_sparse_compression=disable_sparse_compression,
+        )
+
+        if compressor is None:
+            # No compression needed
+            original(*args, state_dict=state_dict, **kwargs)
+        else:
+            # Compress and save
+            compressed_state_dict = compressor.compress(model, state_dict)
+            logger.info("Saving compressed model to disk")
+            original(*args, state_dict=compressed_state_dict, **kwargs)
+            compressor.update_config(save_directory)
+
+        # These operations happen regardless of compression
+        update_and_save_recipe(model.name_or_path, save_directory)
+        copy_python_files_from_model_cache(model, save_directory)
+
+    # Apply compression signature
+    save_pretrained_wrapper.__signature__ = sig_with_compression_params
+    save_pretrained_wrapper._overridden = True
+    model.save_pretrained = save_pretrained_wrapper


 # HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -306,3 +292,59 @@ def update_and_save_recipe(model_stub: str, save_directory: str):
     # save recipe
     recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
     recipe.yaml(recipe_path)
+
+
+def _create_compression_signature(orig_sig: inspect.Signature) -> inspect.Signature:
+    """
+    Creates an enhanced signature with compression parameters.
+
+    :param orig_sig: Original function signature
+    :return: Enhanced signature with compression parameters
+    """
+    # Define compression parameters
+    compression_params = [
+        inspect.Parameter(
+            name="sparsity_config",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=None,
+            annotation=Optional[SparsityCompressionConfig],
+        ),
+        inspect.Parameter(
+            name="quantization_format",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=None,
+            annotation=Optional[str],
+        ),
+        inspect.Parameter(
+            name="save_compressed",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=True,
+            annotation=bool,
+        ),
+        inspect.Parameter(
+            name="skip_sparsity_compression_stats",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=True,
+            annotation=bool,
+        ),
+        inspect.Parameter(
+            name="disable_sparse_compression",
+            kind=inspect.Parameter.KEYWORD_ONLY,
+            default=False,
+            annotation=bool,
+        ),
+    ]
+
+    # Only add parameters that don't exist in the original signature
+    existing_params = orig_sig.parameters.keys()
+    new_params = []
+
+    for param in orig_sig.parameters.values():
+        if param.kind == inspect.Parameter.VAR_KEYWORD:
+            # Add compression params before **kwargs
+            new_params.extend(
+                [p for p in compression_params if p.name not in existing_params]
+            )
+        new_params.append(param)
+
+    return orig_sig.replace(parameters=new_params)
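
The signature splice above can be exercised in isolation. Below is a minimal, self-contained sketch of the same technique, using only the standard library; the names demo_save and with_extra_kwarg are illustrative stand-ins, not part of this commit. A wrapper accepts one extra keyword-only argument, and inspect.Parameter / Signature.replace() advertise that argument in the wrapper's public signature just before **kwargs.

import inspect
from functools import wraps


def demo_save(save_directory: str, safe_serialization: bool = True, **kwargs):
    """Stand-in for the original save_pretrained method."""
    print(f"saving to {save_directory}")


def with_extra_kwarg(func):
    """Wrap func and advertise one extra keyword-only parameter in its signature."""
    extra = inspect.Parameter(
        name="save_compressed",
        kind=inspect.Parameter.KEYWORD_ONLY,
        default=True,
        annotation=bool,
    )

    @wraps(func)
    def wrapper(*args, save_compressed: bool = True, **kwargs):
        print(f"save_compressed={save_compressed}")
        return func(*args, **kwargs)

    orig_sig = inspect.signature(func)
    params = []
    for param in orig_sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # Splice the new keyword-only parameter in just before **kwargs,
            # mirroring what _create_compression_signature does above.
            if extra.name not in orig_sig.parameters:
                params.append(extra)
        params.append(param)

    # functools.wraps sets __wrapped__, so inspect.signature would otherwise keep
    # reporting the original signature; assigning __signature__ explicitly makes
    # the added parameter visible to introspection (and to the new test below).
    wrapper.__signature__ = orig_sig.replace(parameters=params)
    return wrapper


wrapped = with_extra_kwarg(demo_save)
print(inspect.signature(wrapped))
# (save_directory: str, safe_serialization: bool = True, *,
#  save_compressed: bool = True, **kwargs)
wrapped("out/", save_compressed=False)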

tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py

Lines changed: 67 additions & 1 deletion
@@ -1,3 +1,4 @@
+import inspect
 import math
 import os
 import shutil
@@ -8,7 +9,11 @@
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
+from compressed_tensors.config import (
+    BitmaskConfig,
+    DenseSparsityConfig,
+    SparsityCompressionConfig,
+)
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationStatus,
@@ -708,3 +713,64 @@ def test_correct_compressor_inferred(
         )
     else:
         assert compressor.sparsity_config.format == expected_sparsity_compressor
+
+
+@pytest.mark.parametrize(
+    "sparse_uncompressed_model",
+    ["nm-testing/llama2.c-stories15M-pruned_50.2of4-uncompressed"],
+)
+@pytest.mark.parametrize("save_compressed", [True, False])
+def test_modify_save_pretrained(sparse_uncompressed_model, save_compressed, tmp_path):
+    """
+    Test if the `modify_save_pretrained` function correctly modifies the model's
+    `save_pretrained` method.
+    """
+    model = AutoModelForCausalLM.from_pretrained(sparse_uncompressed_model)
+
+    modify_save_pretrained(model)
+
+    # Get the actual function object (handle both bound and unbound methods)
+    modified_func = getattr(
+        model.save_pretrained,
+        "__func__",
+        model.save_pretrained,
+    )
+
+    # Check that the method was properly modified
+    assert hasattr(model, "save_pretrained")
+    assert callable(model.save_pretrained)
+    assert getattr(modified_func, "_overridden", True)
+
+    # Verify the signature contains expected compression parameters
+    expected_params = {
+        "sparsity_config",
+        "quantization_format",
+        "save_compressed",
+        "skip_sparsity_compression_stats",
+        "disable_sparse_compression",
+    }
+    sig = inspect.signature(model.save_pretrained)
+    actual_params = set(sig.parameters.keys())
+
+    # Check that all expected parameters are present
+    assert expected_params.issubset(
+        actual_params
+    ), f"Missing parameters: {expected_params - actual_params}"
+
+    # Test the actual functionality
+    save_dir = tmp_path / "compressed_model"
+    model.save_pretrained(
+        save_dir,
+        save_compressed=save_compressed,
+        skip_sparsity_compression_stats=not save_compressed,
+    )
+
+    # Verify the model was saved correctly
+    assert (save_dir / "recipe.yaml").exists()
+
+    # Additional checks when saving in compressed format
+    if save_compressed:
+        # Verify we can load a compressor from the saved model config
+        compressor = ModelCompressor.from_pretrained(save_dir)
+        assert compressor is not None
+        assert isinstance(compressor.sparsity_config, SparsityCompressionConfig)
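
For reference, a minimal usage sketch of the modified entry point, mirroring the save_compressed=True branch of the new test. The model stub is taken from the test above; the import path is inferred from the file location in this commit, the output directory name is illustrative, and running this downloads the stub model.

import inspect

from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
)

# Same sparse stub model the new test parametrizes over
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories15M-pruned_50.2of4-uncompressed"
)

# After modification, the compression kwargs appear in the public signature
modify_save_pretrained(model)
print(inspect.signature(model.save_pretrained))

# Saving compressed (with sparsity stats enabled, as in the test) writes the
# compressed weights, recipe.yaml, and a compression config that
# ModelCompressor.from_pretrained() can read back
model.save_pretrained(
    "compressed_model_out",
    save_compressed=True,
    skip_sparsity_compression_stats=False,
)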
