Enhance save_pretrained #1376
Draft: rahul-tuli wants to merge 4 commits into main from save-pretrained-updates (base: main)
Changes from all commits (4 commits):
- 76fb370: Add: Failing test (rahul-tuli)
- 94f0b62: Add: save_pretrained readme (rahul-tuli)
- b403c25: Update src/llmcompressor/transformers/sparsification/compressed_tenso… (rahul-tuli)
- ad7e45d: Update tests/llmcompressor/transformers/sparsification/test_compress_… (rahul-tuli)
@@ -0,0 +1,107 @@
# Enhanced `save_pretrained` Arguments

The `llmcompressor` library extends Hugging Face's `save_pretrained` method with additional arguments to support model compression functionality. This document explains these extra arguments and how to use them effectively.

## How It Works

`llmcompressor` wraps the model's original `save_pretrained` method with an enhanced version that supports compression. This happens in one of two ways:

1. **Direct modification**: When you call `modify_save_pretrained(model)` directly
2. **Automatic wrapping**: When you call `oneshot(...)`, which wraps `save_pretrained` under the hood

This means that after applying compression with `oneshot`, your model's `save_pretrained` method is already enhanced with compression capabilities, and you can use the additional arguments described below.
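As a quick sanity check, you can verify that the wrapper has been applied by looking at the internal `_overridden` flag that `modify_save_pretrained` sets on the wrapped method. This flag is an implementation detail rather than a public API, so treat the snippet below as a debugging aid only:

```python
from transformers import AutoModelForCausalLM
from llmcompressor.transformers.sparsification import modify_save_pretrained

model = AutoModelForCausalLM.from_pretrained("your-model")
modify_save_pretrained(model)

# Internal marker used to avoid double-wrapping; may change between releases
print(getattr(model.save_pretrained, "_overridden", False))  # True if wrapped
```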
## Additional Arguments

When saving your compressed models, you can use the following extra arguments with the `save_pretrained` method:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sparsity_config` | `Optional[SparsityCompressionConfig]` | `None` | Optional configuration for sparsity compression. If None and `skip_sparsity_compression_stats` is False, configuration will be automatically inferred from the model. |
| `quantization_format` | `Optional[str]` | `None` | Optional format string for quantization. If not provided, it will be inferred from the model. |
| `save_compressed` | `bool` | `True` | Controls whether to save the model in a compressed format. Set to `False` to save in the original dense format. |
| `skip_sparsity_compression_stats` | `bool` | `True` | Controls whether to skip calculating sparsity statistics (e.g., global sparsity and structure) when saving the model. Set to `False` to include these statistics. |
| `disable_sparse_compression` | `bool` | `False` | When set to `True`, skips any sparse compression during save, even if the model has been previously compressed. |
## Examples

### Applying Compression with oneshot

The simplest approach is to use `oneshot`, which handles both compression and wrapping `save_pretrained`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model
model = AutoModelForCausalLM.from_pretrained("your-model")
tokenizer = AutoTokenizer.from_pretrained("your-model")

# Apply compression - this also wraps save_pretrained
oneshot(
    model=model,
    recipe=[GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])],
    # Other oneshot parameters...
)

# Now you can use the enhanced save_pretrained
SAVE_DIR = "your-model-W8A8-compressed"
model.save_pretrained(
    SAVE_DIR,
    save_compressed=True  # Use the enhanced functionality
)
tokenizer.save_pretrained(SAVE_DIR)
```
### Manual Approach (Without oneshot)

If you need more control, you can wrap `save_pretrained` manually:

```python
from transformers import AutoModelForCausalLM
from llmcompressor.transformers.sparsification import modify_save_pretrained

# Load model
model = AutoModelForCausalLM.from_pretrained("your-model")

# Manually wrap save_pretrained
modify_save_pretrained(model)

# Now you can use the enhanced save_pretrained
model.save_pretrained(
    "your-model-path",
    save_compressed=True,
    skip_sparsity_compression_stats=False  # to infer sparsity config
)
```
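### Saving Without Compression

If you need the checkpoint in its original dense layout (for example, for tooling that does not understand compressed formats yet), you can opt out of compression at save time. This is a minimal sketch based on the argument descriptions above; the output directory name is only an example:

```python
# Save the model weights uncompressed (dense format)
model.save_pretrained(
    "your-model-dense",
    save_compressed=False,            # keep the original dense format
    disable_sparse_compression=True,  # also skip any sparse compression step
)
```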
### Saving with Custom Sparsity Configuration

```python
from compressed_tensors import SparsityCompressionConfig

# Create custom sparsity config (field values shown are illustrative;
# see the compressed-tensors documentation for supported formats)
custom_config = SparsityCompressionConfig(
    format="2:4",
    block_size=16
)

# Save with custom config
model.save_pretrained(
    "your-model-custom-sparse",
    sparsity_config=custom_config,
)
```
## Notes

- When loading compressed models with `from_pretrained`, the compression format is automatically detected (see the example after this list).
- To use compressed models with vLLM, simply load them as you would any model:
  ```python
  from vllm import LLM
  model = LLM("./your-model-compressed")
  ```
- Compression configurations are saved in the model's config file and are automatically applied when loading.
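For example, reloading a checkpoint saved with `save_compressed=True` works like loading any other model, assuming a `transformers` version with compressed-tensors support (the directory name below matches the earlier example):

```python
from transformers import AutoModelForCausalLM

# The compression format recorded in the saved config is detected automatically
model = AutoModelForCausalLM.from_pretrained("your-model-W8A8-compressed")
```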
For more information about compression algorithms and formats, please refer to the documentation and examples in the llmcompressor repository.
@@ -1,6 +1,6 @@
import inspect
import os
import re
import weakref
from functools import wraps
from typing import Dict, Optional

@@ -33,115 +33,104 @@
__all__ = ["modify_save_pretrained"]


def modify_save_pretrained(model: PreTrainedModel):
def modify_save_pretrained(model: PreTrainedModel) -> None:
"""
Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
supports compression. The new save_pretrained function performs the following saving
operations:
also supports compression params. The modified save_pretrained function performs the
following operations:

1. Saves the model state, potentially in a compressed format
2. Saves the recipe, appending any current recipes to existing recipe files
3. Copies any necessary python files from the model cache
"""

def save_pretrained_compressed(save_pretrained_method):
if getattr(save_pretrained_method, "_overridden", False):
# `model.save_pretrained` has already been replaced, return.
return save_pretrained_method

# Keep a weak reference to the model class and unbound save_pretrained
# method so we can call the original
model_ref = weakref.ref(save_pretrained_method.__self__)
original_save_pretrained = save_pretrained_method.__func__
model_class = model_ref().__class__
del save_pretrained_method

@wraps(original_save_pretrained)
def save_pretrained_wrapper(
save_directory: str,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = True,
safe_serialization: bool = True,
skip_sparsity_compression_stats: bool = True,
disable_sparse_compression: bool = False,
**kwargs,
):
""" | ||
Wrapper around PreTrainedModel.save_pretrained(), adds functionality for | ||
saving models in a compressed format on disk. The compression format is | ||
saved to the model's config file | ||
|
||
:param save_directory: output directory to save model to | ||
:param sparsity_config: optional sparsity config to compress model with, | ||
if no config is provided it will be inferred from the model | ||
:param quantization_format: optional compression format for quantized | ||
models. If none is provided it will be inferred from the model | ||
:param save_compressed: whether or not to compress the model on disk | ||
:param skip_sparsity_compression_stats: whether to skip the calculation of | ||
sparsity statistics (such as global sparsity and sparsity structure) | ||
when saving a model in dense format | ||
:param disable_sparse_compression: whether to skip sparse compression | ||
during save, default is False | ||
:param kwargs: additional kwargs to pass on to model.save_pretrained | ||
""" | ||
|
||
# HACK: Override the dtype_byte_size function in transformers to | ||
# support float8 types. Fix is posted upstream | ||
# https://github.com/huggingface/transformers/pull/30488 | ||
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size | ||
|
||
# state_dict gets passed in as a kwarg for FSDP models | ||
state_dict = kwargs.pop("state_dict", None) | ||
if state_dict is None: | ||
logger.info("Fetching state_dict - this may take some time") | ||
state_dict = get_state_dict_offloaded_model(model) | ||
|
||
logger.info("Fetching compressor") | ||
compressor = get_model_compressor( | ||
model=model, | ||
sparsity_config=sparsity_config, | ||
quantization_format=quantization_format, | ||
save_compressed=save_compressed, | ||
skip_sparsity_compression_stats=skip_sparsity_compression_stats, | ||
state_dict=state_dict, | ||
disable_sparse_compression=disable_sparse_compression, | ||
For more information on the compression parameters and model saving in | ||
llmcompressor, refer to docs/save_pretrained.md | ||
|
||
:param model: The model whose save_pretrained method will be modified | ||
""" | ||
original = model.save_pretrained
# Avoid double-wrapping if already modified
if getattr(original, "_overridden", False):
return

# Create enhanced signature with compression parameters
orig_sig = inspect.signature(original)
sig_with_compression_params = _create_compression_signature(orig_sig)

@wraps(original)
def save_pretrained_wrapper(
*args,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = True,
skip_sparsity_compression_stats: bool = True,
disable_sparse_compression: bool = False,
**kwargs,
):
"""
Wrapper around PreTrainedModel.save_pretrained() that adds compression
functionality. The compression format is saved to the model's config file

NOTE: If adding parameters here, also update _create_compression_signature()
to maintain signature consistency.

:param sparsity_config: Optional sparsity compression configuration.
If None and `skip_sparsity_compression_stats` is False, a sparsity
config will be inferred from the model.
:param quantization_format: Optional format string for quantization
:param save_compressed: Whether to save the model in compressed format
:param skip_sparsity_compression_stats: Whether to skip calculating
sparsity stats.
:param disable_sparse_compression: Whether to disable sparse compression
entirely
"""
# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

# Extract save_directory from args or kwargs
save_directory = args[0] if args else kwargs.get("save_directory")
if save_directory is None:
raise ValueError(
"`save_directory` must be provided as first positional arg or kwarg"
)

if compressor is None:
# model is not compressed or quantized, save as normal
original_save_pretrained_func = original_save_pretrained.__get__(
model, model_class
)
original_save_pretrained_func(
save_directory, state_dict=state_dict, **kwargs
)
return

# make sure we're on the main process when saving
if state_dict is not None and len(state_dict) > 0:
compressed_state_dict = compressor.compress(model, state_dict)
logger.info("Saving compressed model to disk")
original_save_pretrained.__get__(model, model_class)(
save_directory,
state_dict=compressed_state_dict,
safe_serialization=safe_serialization,
**kwargs,
)
compressor.update_config(save_directory)

# update existing recipe
update_and_save_recipe(model.name_or_path, save_directory)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)

save_pretrained_wrapper._overriden = True
return save_pretrained_wrapper

# wrap save_pretrained if not already
if not getattr(model.save_pretrained, "_overriden", False):
model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
# Get state_dict or fetch it if not provided
state_dict = kwargs.pop("state_dict", None)
if state_dict is None:
logger.info("Fetching state_dict – this may take some time")
state_dict = get_state_dict_offloaded_model(model)

logger.info("Fetching compressor")
compressor = get_model_compressor(
model=model,
sparsity_config=sparsity_config,
quantization_format=quantization_format,
save_compressed=save_compressed,
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
)

if compressor is None:
# No compression needed
original(*args, state_dict=state_dict, **kwargs)
else:
# Compress and save
compressed_state_dict = compressor.compress(model, state_dict)
logger.info("Saving compressed model to disk")
original(*args, state_dict=compressed_state_dict, **kwargs)
compressor.update_config(save_directory)

# These operations happen regardless of compression
update_and_save_recipe(model.name_or_path, save_directory)
copy_python_files_from_model_cache(model, save_directory)

# Apply compression signature
save_pretrained_wrapper.__signature__ = sig_with_compression_params
save_pretrained_wrapper._overridden = True
model.save_pretrained = save_pretrained_wrapper


# HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -306,3 +295,59 @@ def update_and_save_recipe(model_stub: str, save_directory: str):
# save recipe
recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
recipe.yaml(recipe_path)


def _create_compression_signature(orig_sig: inspect.Signature) -> inspect.Signature:
""" | ||
Creates an enhanced signature with compression parameters. | ||
|
||
:param orig_sig: Original function signature | ||
:return: Enhanced signature with compression parameters | ||
""" | ||
# Define compression parameters | ||
compression_params = [ | ||
inspect.Parameter( | ||
name="sparsity_config", | ||
kind=inspect.Parameter.KEYWORD_ONLY, | ||
default=None, | ||
annotation=Optional[SparsityCompressionConfig], | ||
), | ||
inspect.Parameter( | ||
name="quantization_format", | ||
kind=inspect.Parameter.KEYWORD_ONLY, | ||
default=None, | ||
annotation=Optional[str], | ||
), | ||
inspect.Parameter( | ||
name="save_compressed", | ||
kind=inspect.Parameter.KEYWORD_ONLY, | ||
default=True, | ||
annotation=bool, | ||
), | ||
inspect.Parameter( | ||
name="skip_sparsity_compression_stats", | ||
kind=inspect.Parameter.KEYWORD_ONLY, | ||
default=True, | ||
annotation=bool, | ||
), | ||
inspect.Parameter( | ||
name="disable_sparse_compression", | ||
kind=inspect.Parameter.KEYWORD_ONLY, | ||
default=False, | ||
annotation=bool, | ||
), | ||
] | ||
|
||
# Only add parameters that don't exist in the original signature | ||
existing_params = orig_sig.parameters.keys() | ||
new_params = [] | ||
|
||
for param in orig_sig.parameters.values(): | ||
if param.kind == inspect.Parameter.VAR_KEYWORD: | ||
# Add compression params before **kwargs | ||
new_params.extend( | ||
[p for p in compression_params if p.name not in existing_params] | ||
) | ||
new_params.append(param) | ||
|
||
return orig_sig.replace(parameters=new_params)
Review comment: 95% sure `._overridden` is not saved with the model, just want to confirm that this change won't break any previously saved models that had `._overriden`.