@@ -13,17 +13,12 @@
# limitations under the License.

import json
import logging
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union

import compressed_tensors
import torch
import transformers
from compressed_tensors.base import (
    COMPRESSION_VERSION_NAME,
    QUANTIZATION_CONFIG_NAME,
@@ -39,26 +34,20 @@
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
    apply_quantization_config,
    load_pretrained_quantization_parameters,
)
from compressed_tensors.transform import TransformConfig
from compressed_tensors.utils import (
    align_module_device,
    delete_offload_parameter,
    get_execution_device,
    get_offloaded_device,
    get_safetensors_folder,
    has_offloaded_params,
    register_offload_parameter,
    update_parameter_data,
)
from compressed_tensors.utils.helpers import (
    fix_fsdp_module_name,
    is_compressed_tensors_config,
)
from compressed_tensors.utils.match import match_named_modules
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm
from transformers import AutoConfig
@@ -71,8 +60,6 @@

__all__ = ["ModelCompressor", "map_module_to_scheme"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    # dummy type if not available from transformers
@@ -488,153 +475,6 @@ def decompress_model(self, model: Module):

                module.quantization_status = QuantizationStatus.FROZEN

    # ----- state dict compression pathways ----- #

    def compress(
        self,
        model: Module,
        state_dict: Optional[Dict[str, Tensor]] = None,
        show_progress: bool = False,
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict or model with sparsity and/or quantization

        :param model: uncompressed model to compress
        :param state_dict: optional uncompressed state_dict to insert into model
        :param show_progress: whether to display a progress bar during compression
        :return: compressed state dict
        """

        if state_dict is None:
            state_dict = model.state_dict()

        if self.quantization_compressor is not None:
            module_to_scheme = map_module_to_scheme(model)
            # Note - compress only supports one compression format atm
            quant_compressor = next(iter(self.quantization_compressor.values()))
            state_dict = quant_compressor.compress(
                state_dict,
                names_to_scheme=module_to_scheme,
                show_progress=show_progress,
            )

            # TODO: consider sparse compression to also be compression
            if self.quantization_config.format != CompressionFormat.dense.value:
                self.quantization_config.quantization_status = (
                    QuantizationStatus.COMPRESSED
                )

        if self.sparsity_compressor is not None:
            sparse_compression_targets: Set[str] = {
                module_name
                for module_name, _module in match_named_modules(
                    model=model,
                    targets=self.sparsity_config.targets,
                    ignore=self.sparsity_config.ignore,
                )
            }
            state_dict = self.sparsity_compressor.compress(
                state_dict,
                compression_targets=sparse_compression_targets,
                show_progress=show_progress,
            )

        # HACK: Override the dtype_byte_size function in transformers to
        # support float8 types. Fix is posted upstream
        # https://github.com/huggingface/transformers/pull/30488
        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

        return state_dict
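
# A minimal usage sketch of the state-dict pathway above (an illustration, not part
# of the original file). It assumes a compressor has already been constructed, e.g.
# via ModelCompressor.from_pretrained_model; that helper's exact signature, the
# example function name, and the file layout below are assumptions.
def _example_compress_to_disk(compressor, model: Module, save_dir: str):
    from safetensors.torch import save_file

    # produce the compressed on-disk representation of the quantized/sparse weights
    compressed_state_dict = compressor.compress(model, show_progress=True)
    # persist the compressed tensors, then record the compression configs in the
    # config.json located at save_dir via update_config()
    save_file(compressed_state_dict, os.path.join(save_dir, "model.safetensors"))
    compressor.update_config(save_dir)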

    # ----- disk decompression pathways ----- #

    def decompress(self, model_path: str, model: Module):
        """
        Overwrites the weights in model with weights decompressed from model_path

        :param model_path: path to compressed weights
        :param model: pytorch model to load decompressed weights into

        Note: decompress makes use of both _replace_sparsity_weights and
        _replace_weights. The differences between these methods stem from subtle
        differences between the sparsity and quantization compressors. Specifically,
        quantization compressors return not just the decompressed weight, but also
        the quantization parameters (e.g. scales, zero_point), whereas sparsity
        compressors only return the decompressed weight.
        """
        model_path = get_safetensors_folder(model_path)
        sparse_decompressed = False
        quant_compressor = (
            next(iter(self.quantization_compressor.values()))
            if self.quantization_compressor is not None
            else None
        )

        if (
            self.sparsity_compressor is not None
            and self.sparsity_config.format != CompressionFormat.dense.value
        ):
            # note - decompress only supports one compressor atm
            params_to_ignore = None
            if quant_compressor is not None:
                params_to_ignore = quant_compressor.compression_param_names
            # Sparse decompression is applied on the model_path
            # The compressor will try and load any quantization parameters as well
            # params_to_skip_load will skip over quantization params from being loaded
            dense_gen = self.sparsity_compressor.decompress(
                model_path, params_to_skip_load=params_to_ignore
            )
            self._replace_sparsity_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
            sparse_decompressed = True

        if quant_compressor is not None:
            # Temporarily set quantization status to FROZEN to prevent
            # quantization during apply_quantization_config. This ensures
            # that the dtypes of the weights are not unintentionally updated.
            # The status is restored after quantization params are loaded.

            with override_quantization_status(
                self.quantization_config, QuantizationStatus.FROZEN
            ):
                apply_quantization_config(model, self.quantization_config)
                names_to_scheme: Dict[str, QuantizationScheme] = {
                    name: getattr(module, "quantization_scheme")
                    for name, module in model.named_modules()
                    if getattr(module, "quantization_scheme", None) is not None
                }
                # Load activation scales/zp or any other quantization parameters
                # Conditionally load the weight quantization parameters if we have a
                # dense compressor or if a sparsity compressor has already been applied
                load_weight_qparams = sparse_decompressed or isinstance(
                    quant_compressor, DenseCompressor
                )
                load_pretrained_quantization_parameters(
                    model,
                    model_path,
                    # TODO: all weight quantization params will be moved to the
                    # compressor in a follow-up including initialization
                    load_weight_qparams=load_weight_qparams,
                )
            model_path_or_state_dict = (
                model.state_dict() if sparse_decompressed else model_path
            )

            dense_gen = quant_compressor.decompress(
                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
            # TODO: all weight quantization params will be moved to the compressor
            # to prevent duplicate parameter updates in update_parameter_data
            self._replace_weights(
                dense_gen, model, load_weight_qparams=not load_weight_qparams
            )

            def freeze_quantization_status(module):
                module.quantization_status = QuantizationStatus.FROZEN

            model.apply(freeze_quantization_status)
            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
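
# A minimal sketch of the disk decompression pathway above (illustration only). It
# assumes the compressor is rebuilt from the saved compression config, e.g. via
# ModelCompressor.from_pretrained; treat that constructor's exact name and signature
# as an assumption, while decompress(model_path, model) matches the method above.
def _example_decompress_from_disk(model_path: str, model: Module):
    compressor = ModelCompressor.from_pretrained(model_path)  # assumed constructor
    if compressor is not None:
        # overwrites the dense weights in `model` with decompressed values and
        # attaches the sparsity/quantization configs to the model
        compressor.decompress(model_path, model)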

    def update_config(self, save_directory: str):
        """
        Update the model config located at save_directory with compression configs
@@ -688,79 +528,6 @@ def update_config(self, save_directory: str):
        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
        """
        Replace the weights of the model with the
        provided dense weights.

        This method iterates over the dense_weight_generator and
        updates the corresponding weights in the model. If a parameter
        name does not exist in the model, it will be skipped.

        :param dense_weight_generator: generator that yields tuples of (name, data),
            where 'name' is the parameter name and 'data' is the updated param data
        :param model: The model whose weights are to be updated.
        """
        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
            split_name = name.split(".")
            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
            module = operator.attrgetter(prefix)(model)

            params_device = next(module.parameters()).device
            device = "cpu" if has_offloaded_params(module) else params_device
            delattr(module, param_name)
            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
            register_offload_parameter(module, param_name, param)

    def _replace_weights(
        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
    ):
        """
        Replace the weights of the model with the
        provided dense weights.

        This method iterates over the dense_weight_generator and
        updates the corresponding weights in the model. If a parameter
        name does not exist in the model, it will be skipped.

        :param dense_weight_generator: generator that yields tuples of
            (module_path, data), where 'module_path' is the path to the module and
            'data' is a dict mapping parameter names to their updated param data
        :param model: The model whose weights are to be updated.
        :param load_weight_qparams: if True, also update non-weight quantization
            parameters (e.g. scales and zero points) yielded by the generator
        """

        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
            module = operator.attrgetter(mod_path)(model)

            params_device = next(module.parameters()).device
            device = "cpu" if has_offloaded_params(module) else params_device

            for param_name, param_data in data.items():
                if hasattr(module, param_name):
                    # If compressed, will have an incorrect dtype for transformers >4.49
                    # TODO: we could also skip initialization of scales/zp during
                    # decompression in init to be consistent with loading, which
                    # happens later as well. However, update_data does a good shape
                    # check, which should be moved to the compressor

                    if param_name == "weight":
                        delattr(module, param_name)
                        requires_grad = param_data.dtype in (
                            torch.float16,
                            torch.float32,
                            torch.bfloat16,
                        )
                        param = torch.nn.Parameter(
                            param_data.to(device), requires_grad=requires_grad
                        )
                        register_offload_parameter(module, param_name, param)
                    elif load_weight_qparams:
                        # Scales/zero-points should already be registered
                        # to the correct device
                        update_parameter_data(module, param_data, param_name)
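
# Illustration of the items each helper above consumes (hypothetical module and
# parameter names, mirroring the decompress() docstring): the quantization pathway
# yields a per-module dict of parameters, while the sparsity pathway yields one flat
# (parameter_path, tensor) pair at a time, hence the two separate helpers.
_example_quant_item = (
    "model.layers.0.self_attn.q_proj",
    {"weight": torch.zeros(1), "weight_scale": torch.ones(1)},  # assumed param names
)
_example_sparse_item = ("model.layers.0.self_attn.q_proj.weight", torch.zeros(1))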


def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
"""
@@ -775,35 +542,3 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
            and module.quantization_scheme.weights is not None
        )
    }


# HACK: Override the dtype_byte_size function in transformers to support float8 types
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
def new_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
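
# Worked example of the override above (sanity checks only, assuming a torch build
# that defines the float8_e4m3fn dtype): str(torch.float8_e4m3fn) contains "8", so
# the regex captures bit_size = 8 and the function returns 8 // 8 = 1 byte, which is
# the float8 case the HACK comment refers to; torch.bool keeps its fractional size.
assert new_dtype_byte_size(torch.float16) == 2
assert new_dtype_byte_size(torch.float8_e4m3fn) == 1
assert new_dtype_byte_size(torch.bool) == 1 / 8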


@contextmanager
def override_quantization_status(
    config: QuantizationConfig, status: QuantizationStatus
):
    """
    Within this context, the quantization status will be set to the
    supplied status. After the context exits, the original status
    will be restored.

    :param config: the quantization config to override
    :param status: the status to temporarily set
    """
    original_status = config.quantization_status
    config.quantization_status = status
    try:
        yield
    finally:
        config.quantization_status = original_status
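
# A minimal usage sketch of the context manager above (illustration only), mirroring
# how decompress() uses it: the original status is restored even if the body raises,
# because the reset happens in the `finally` clause. `config` and `model` stand in
# for an existing QuantizationConfig and torch module.
def _example_override_status(config: QuantizationConfig, model: Module):
    with override_quantization_status(config, QuantizationStatus.FROZEN):
        # applied while the status reads FROZEN, so weight dtypes are not
        # unintentionally updated (see the decompress() pathway above)
        apply_quantization_config(model, config)
    # after the context exits, config.quantization_status is back to its original value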