107 changes: 107 additions & 0 deletions docs/save_pretrained.md
@@ -0,0 +1,107 @@
# Enhanced `save_pretrained` Arguments

The `llmcompressor` library extends Hugging Face's `save_pretrained` method with additional arguments to support model compression functionality. This document explains these extra arguments and how to use them effectively.

## How It Works

The `llmcompressor` library wraps the model's original `save_pretrained` method with an enhanced version that supports compression. This happens in one of two ways:

1. **Direct modification**: When you call `modify_save_pretrained(model)` directly
2. **Automatic wrapping**: When you call `oneshot(...)`, which wraps `save_pretrained` under the hood

This means that after applying compression with `oneshot`, your model's `save_pretrained` method is already enhanced with compression capabilities, and you can use the additional arguments described below.
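Based on the wrapper implementation in this PR, the wrapped method marks itself with a private `_overridden` flag so that wrapping is applied at most once. A minimal sketch of that behavior (the flag is an internal detail and the model name is a placeholder):

```python
from transformers import AutoModelForCausalLM
from llmcompressor.transformers.sparsification import modify_save_pretrained

model = AutoModelForCausalLM.from_pretrained("your-model")

modify_save_pretrained(model)
modify_save_pretrained(model)  # second call is a no-op; save_pretrained is already wrapped

# Internal marker used by the wrapper to avoid double-wrapping
print(getattr(model.save_pretrained, "_overridden", False))  # True
```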

## Additional Arguments

When saving your compressed models, you can use the following extra arguments with the `save_pretrained` method:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sparsity_config` | `Optional[SparsityCompressionConfig]` | `None` | Optional sparsity compression configuration. If `None` and `skip_sparsity_compression_stats` is `False`, the configuration is inferred automatically from the model. |
| `quantization_format` | `Optional[str]` | `None` | Optional format string for quantization. If not provided, it will be inferred from the model. |
| `save_compressed` | `bool` | `True` | Controls whether to save the model in a compressed format. Set to `False` to save in the original dense format. |
| `skip_sparsity_compression_stats` | `bool` | `True` | Controls whether to skip calculating sparsity statistics (e.g., global sparsity and structure) when saving the model. Set to `False` to include these statistics. |
| `disable_sparse_compression` | `bool` | `False` | When set to `True`, skips any sparse compression during save, even if the model has been previously compressed. |

## Examples

### Applying Compression with oneshot

The simplest approach is to use `oneshot`, which handles both compression and wrapping `save_pretrained`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model
model = AutoModelForCausalLM.from_pretrained("your-model")
tokenizer = AutoTokenizer.from_pretrained("your-model")

# Apply compression - this also wraps save_pretrained
oneshot(
    model=model,
    recipe=[GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])],
    # Other oneshot parameters...
)

# Now you can use the enhanced save_pretrained
SAVE_DIR = "your-model-W8A8-compressed"
model.save_pretrained(
    SAVE_DIR,
    save_compressed=True  # Use the enhanced functionality
)
tokenizer.save_pretrained(SAVE_DIR)
```

### Manual Approach (Without oneshot)

If you need more control, you can wrap `save_pretrained` manually:

```python
from transformers import AutoModelForCausalLM
from llmcompressor.transformers.sparsification import modify_save_pretrained

# Load model
model = AutoModelForCausalLM.from_pretrained("your-model")

# Manually wrap save_pretrained
modify_save_pretrained(model)

# Now you can use the enhanced save_pretrained
model.save_pretrained(
    "your-model-path",
    save_compressed=True,
    skip_sparsity_compression_stats=False  # to infer sparsity config
)
```

### Saving with Custom Sparsity Configuration

```python
from compressed_tensors.sparsification import SparsityCompressionConfig

# Create custom sparsity config
custom_config = SparsityCompressionConfig(
    format="2:4",
    block_size=16
)

# Save with custom config
model.save_pretrained(
    "your-model-custom-sparse",
    sparsity_config=custom_config,
)
```
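### Saving in the Original Dense Format

If you prefer to write the checkpoint without on-disk compression, for example to compare against the compressed artifact, the arguments from the table above can be combined as in this sketch (the output path is a placeholder):

```python
# Save the model weights in the original dense serialization format
model.save_pretrained(
    "your-model-dense",
    save_compressed=False,            # do not compress the saved weights
    disable_sparse_compression=True,  # also skip any sparse compression during save
)
```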

## Notes

- When loading compressed models with `from_pretrained`, the compression format is automatically detected.
- To use compressed models with vLLM, simply load them as you would any model:
  ```python
  from vllm import LLM
  model = LLM("./your-model-compressed")
  ```
- Compression configurations are saved in the model's config file and are applied automatically when loading; see the loading sketch below.
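For example, reloading a compressed checkpoint with `transformers` needs no extra arguments; this short sketch mirrors the auto-detection behavior described above (the path is a placeholder):

```python
from transformers import AutoModelForCausalLM

# The compression format recorded in the saved config is detected automatically
model = AutoModelForCausalLM.from_pretrained("./your-model-W8A8-compressed")
```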

For more information about compression algorithms and formats, please refer to the documentation and examples in the llmcompressor repository.
247 changes: 146 additions & 101 deletions src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -1,6 +1,6 @@
import inspect
import os
import re
import weakref
from functools import wraps
from typing import Dict, Optional

@@ -33,115 +33,104 @@
__all__ = ["modify_save_pretrained"]


def modify_save_pretrained(model: PreTrainedModel):
def modify_save_pretrained(model: PreTrainedModel) -> None:
"""
Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
supports compression. The new save_pretrained function performs the following saving
operations:
also supports compression params. The modified save_pretrained function performs the
following operations:

1. Saves the model state, potentially in a compressed format
2. Saves the recipe, appending any current recipes to existing recipe files
3. Copies any necessary python files from the model cache
"""

def save_pretrained_compressed(save_pretrained_method):
if getattr(save_pretrained_method, "_overridden", False):
# `model.save_pretrained` has already been replaced, return.
return save_pretrained_method

# Keep a weak reference to the model class and unbound save_pretrained
# method so we can call the original
model_ref = weakref.ref(save_pretrained_method.__self__)
original_save_pretrained = save_pretrained_method.__func__
model_class = model_ref().__class__
del save_pretrained_method

@wraps(original_save_pretrained)
def save_pretrained_wrapper(
save_directory: str,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = True,
safe_serialization: bool = True,
skip_sparsity_compression_stats: bool = True,
disable_sparse_compression: bool = False,
**kwargs,
):
"""
Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
saving models in a compressed format on disk. The compression format is
saved to the model's config file

:param save_directory: output directory to save model to
:param sparsity_config: optional sparsity config to compress model with,
if no config is provided it will be inferred from the model
:param quantization_format: optional compression format for quantized
models. If none is provided it will be inferred from the model
:param save_compressed: whether or not to compress the model on disk
:param skip_sparsity_compression_stats: whether to skip the calculation of
sparsity statistics (such as global sparsity and sparsity structure)
when saving a model in dense format
:param disable_sparse_compression: whether to skip sparse compression
during save, default is False
:param kwargs: additional kwargs to pass on to model.save_pretrained
"""

# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

# state_dict gets passed in as a kwarg for FSDP models
state_dict = kwargs.pop("state_dict", None)
if state_dict is None:
logger.info("Fetching state_dict - this may take some time")
state_dict = get_state_dict_offloaded_model(model)

logger.info("Fetching compressor")
compressor = get_model_compressor(
model=model,
sparsity_config=sparsity_config,
quantization_format=quantization_format,
save_compressed=save_compressed,
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
For more information on the compression parameters and model saving in
llmcompressor, refer to docs/save_pretrained.md

:param model: The model whose save_pretrained method will be modified
"""
original = model.save_pretrained
# Avoid double-wrapping if already modified
if getattr(original, "_overridden", False):
return

# Create enhanced signature with compression parameters
orig_sig = inspect.signature(original)
sig_with_compression_params = _create_compression_signature(orig_sig)

@wraps(original)
def save_pretrained_wrapper(
*args,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = True,
skip_sparsity_compression_stats: bool = True,
disable_sparse_compression: bool = False,
**kwargs,
):
"""
Wrapper around PreTrainedModel.save_pretrained() that adds compression
functionality. The compression format is saved to the model's config file

NOTE: If adding parameters here, also update _create_compression_signature()
to maintain signature consistency.

:param sparsity_config: Optional sparsity compression configuration.
If None and `skip_sparsity_compression_stats` is False, a sparsity
config will be inferred from the model.
:param quantization_format: Optional format string for quantization
:param save_compressed: Whether to save the model in compressed format
:param skip_sparsity_compression_stats: Whether to skip calculating
sparsity stats.
:param disable_sparse_compression: Whether to disable sparse compression
entirely
"""
# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

# Extract save_directory from args or kwargs
save_directory = args[0] if args else kwargs.get("save_directory")
if save_directory is None:
raise ValueError(
"`save_directory` must be provided as first positional arg or kwarg"
)

if compressor is None:
# model is not compressed or quantized, save as normal
original_save_pretrained_func = original_save_pretrained.__get__(
model, model_class
)
original_save_pretrained_func(
save_directory, state_dict=state_dict, **kwargs
)
return

# make sure we're on the main process when saving
if state_dict is not None and len(state_dict) > 0:
compressed_state_dict = compressor.compress(model, state_dict)
logger.info("Saving compressed model to disk")
original_save_pretrained.__get__(model, model_class)(
save_directory,
state_dict=compressed_state_dict,
safe_serialization=safe_serialization,
**kwargs,
)
compressor.update_config(save_directory)

# update existing recipe
update_and_save_recipe(model.name_or_path, save_directory)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)

save_pretrained_wrapper._overriden = True
return save_pretrained_wrapper

# wrap save_pretrained if not already
if not getattr(model.save_pretrained, "_overriden", False):
model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
# Get state_dict or fetch it if not provided
state_dict = kwargs.pop("state_dict", None)
if state_dict is None:
logger.info("Fetching state_dict – this may take some time")
state_dict = get_state_dict_offloaded_model(model)

logger.info("Fetching compressor")
compressor = get_model_compressor(
model=model,
sparsity_config=sparsity_config,
quantization_format=quantization_format,
save_compressed=save_compressed,
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
)

if compressor is None:
# No compression needed
original(*args, state_dict=state_dict, **kwargs)
else:
# Compress and save
compressed_state_dict = compressor.compress(model, state_dict)
logger.info("Saving compressed model to disk")
original(*args, state_dict=compressed_state_dict, **kwargs)
compressor.update_config(save_directory)

# These operations happen regardless of compression
update_and_save_recipe(model.name_or_path, save_directory)
copy_python_files_from_model_cache(model, save_directory)

# Apply compression signature
save_pretrained_wrapper.__signature__ = sig_with_compression_params
save_pretrained_wrapper._overridden = True
[Review comment, Collaborator]: 95% sure ._overridden is not saved with the model, just want to confirm that this change won't break any previously saved models that had ._overriden
model.save_pretrained = save_pretrained_wrapper


# HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -306,3 +295,59 @@ def update_and_save_recipe(model_stub: str, save_directory: str):
# save recipe
recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
recipe.yaml(recipe_path)


def _create_compression_signature(orig_sig: inspect.Signature) -> inspect.Signature:
[Review comment, Collaborator]: very nice!
"""
Creates an enhanced signature with compression parameters.

:param orig_sig: Original function signature
:return: Enhanced signature with compression parameters
"""
# Define compression parameters
compression_params = [
inspect.Parameter(
name="sparsity_config",
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=Optional[SparsityCompressionConfig],
),
inspect.Parameter(
name="quantization_format",
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=Optional[str],
),
inspect.Parameter(
name="save_compressed",
kind=inspect.Parameter.KEYWORD_ONLY,
default=True,
annotation=bool,
),
inspect.Parameter(
name="skip_sparsity_compression_stats",
kind=inspect.Parameter.KEYWORD_ONLY,
default=True,
annotation=bool,
),
inspect.Parameter(
name="disable_sparse_compression",
kind=inspect.Parameter.KEYWORD_ONLY,
default=False,
annotation=bool,
),
]

# Only add parameters that don't exist in the original signature
existing_params = orig_sig.parameters.keys()
new_params = []

for param in orig_sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
# Add compression params before **kwargs
new_params.extend(
[p for p in compression_params if p.name not in existing_params]
)
new_params.append(param)

return orig_sig.replace(parameters=new_params)
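A brief usage sketch of the enhanced signature (assuming the wrapping shown in docs/save_pretrained.md above; the model name is a placeholder):

```python
import inspect

from transformers import AutoModelForCausalLM
from llmcompressor.transformers.sparsification import modify_save_pretrained

model = AutoModelForCausalLM.from_pretrained("your-model")
modify_save_pretrained(model)

# The wrapper exposes the compression parameters via __signature__,
# so inspect.signature() (and tooling built on it) can discover them
params = inspect.signature(model.save_pretrained).parameters
print("save_compressed" in params)                  # True
print("skip_sparsity_compression_stats" in params)  # True
```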