From 24f6104baa30f1730f448eb00e0aae34239894e1 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 11 Sep 2025 15:58:04 -0400
Subject: [PATCH 1/3] remove state dict compress and disk decompress

Signed-off-by: Kyle Sayers
---
 .../model_compressors/model_compressor.py | 267 +-----------------
 .../quantization/lifecycle/apply.py       |  58 ----
 .../utils/safetensors_load.py             |  45 +--
 .../test_model_compressor.py              | 112 +-------
 4 files changed, 13 insertions(+), 469 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 8896c060d..12d7e11a8 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -13,17 +13,12 @@
 # limitations under the License.
 
 import json
-import logging
-import operator
 import os
-import re
-from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union
 
 import compressed_tensors
 import torch
-import transformers
 from compressed_tensors.base import (
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
@@ -39,8 +34,6 @@
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
-    apply_quantization_config,
-    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
@@ -48,17 +41,13 @@
     delete_offload_parameter,
     get_execution_device,
     get_offloaded_device,
-    get_safetensors_folder,
-    has_offloaded_params,
     register_offload_parameter,
-    update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
 from compressed_tensors.utils.match import match_named_modules
-from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -71,8 +60,6 @@
 
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 if TYPE_CHECKING:
     # dummy type if not available from transformers
@@ -488,153 +475,6 @@ def decompress_model(self, model: Module):
 
                 module.quantization_status = QuantizationStatus.FROZEN
-
-    # ----- state dict compression pathways ----- #
-
-    def compress(
-        self,
-        model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
-        show_progress: bool = False,
-    ) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict or model with sparsity and/or quantization
-
-        :param model: uncompressed model to compress
-        :param state_dict: optional uncompressed state_dict to insert into model
-        :return: compressed state dict
-        """
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            # Note - compress only supports one compression format atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
-            state_dict = quant_compressor.compress(
-                state_dict,
-                names_to_scheme=module_to_scheme,
-                show_progress=show_progress,
-            )
-
-            # TODO: consider sparse compression to also be compression
-            if self.quantization_config.format != CompressionFormat.dense.value:
-                self.quantization_config.quantization_status = (
-                    QuantizationStatus.COMPRESSED
-                )
-
-        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = {
-                module_name
-                for module_name, _module in match_named_modules(
-                    model=model,
-                    targets=self.sparsity_config.targets,
-                    ignore=self.sparsity_config.ignore,
-                )
-            }
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
-                compression_targets=sparse_compression_targets,
-                show_progress=show_progress,
-            )
-
-        # HACK: Override the dtype_byte_size function in transformers to
-        # support float8 types. Fix is posted upstream
-        # https://github.com/huggingface/transformers/pull/30488
-        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-        return state_dict
-
-    # ----- disk decompression pathways ----- #
-
-    def decompress(self, model_path: str, model: Module):
-        """
-        Overwrites the weights in model with weights decompressed from model_path
-
-        :param model_path: path to compressed weights
-        :param model: pytorch model to load decompressed weights into
-
-        Note: decompress makes use of both _replace_sparsity_weights and
-        _replace_weights. The variations in these methods are a result of the subtle
-        variations between the sparsity and quantization compressors. Specifically,
-        quantization compressors return not just the decompressed weight, but the
-        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
-        only return the decompressed weight.
-
-        """
-        model_path = get_safetensors_folder(model_path)
-        sparse_decompressed = False
-        quant_compressor = (
-            next(iter(self.quantization_compressor.values()))
-            if self.quantization_compressor is not None
-            else None
-        )
-
-        if (
-            self.sparsity_compressor is not None
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            # note - decompress only supports one compressor atm
-            params_to_ignore = None
-            if quant_compressor is not None:
-                params_to_ignore = quant_compressor.compression_param_names
-            # Sparse decompression is applied on the model_path
-            # The compressor will try and load any quantization parameters as well
-            # params_to_skip_load will skip over quantization params from being loaded
-            dense_gen = self.sparsity_compressor.decompress(
-                model_path, params_to_skip_load=params_to_ignore
-            )
-            self._replace_sparsity_weights(dense_gen, model)
-            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
-            sparse_decompressed = True
-
-        if quant_compressor is not None:
-            # Temporarily set quantization status to FROZEN to prevent
-            # quantization during apply_quantization_config. This ensures
-            # that the dtypes of the weights are not unintentionally updated.
-            # The status is restored after quantization params are loaded.
-
-            with override_quantization_status(
-                self.quantization_config, QuantizationStatus.FROZEN
-            ):
-                apply_quantization_config(model, self.quantization_config)
-                names_to_scheme: Set[QuantizationScheme] = {
-                    name: getattr(module, "quantization_scheme")
-                    for name, module in model.named_modules()
-                    if getattr(module, "quantization_scheme", None) is not None
-                }
-                # Load activation scales/zp or any other quantization parameters
-                # Conditionally load the weight quantization parameters if we have a
-                # dense compressor or if a sparsity compressor has already been applied
-                load_weight_qparams = sparse_decompressed or isinstance(
-                    quant_compressor, DenseCompressor
-                )
-                load_pretrained_quantization_parameters(
-                    model,
-                    model_path,
-                    # TODO: all weight quantization params will be moved to the
-                    # compressor in a follow-up including initialization
-                    load_weight_qparams=load_weight_qparams,
-                )
-            model_path_or_state_dict = (
-                model.state_dict() if sparse_decompressed else model_path
-            )
-
-            dense_gen = quant_compressor.decompress(
-                model_path_or_state_dict, names_to_scheme=names_to_scheme
-            )
-            # TODO: all weight quantization params will be moved to the compressor
-            # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(
-                dense_gen, model, load_weight_qparams=not load_weight_qparams
-            )
-
-            def freeze_quantization_status(module):
-                module.quantization_status = QuantizationStatus.FROZEN
-
-            model.apply(freeze_quantization_status)
-            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
         """
         Update the model config located at save_directory with compression configs
@@ -688,79 +528,6 @@
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
-
-    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            split_name = name.split(".")
-            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
-            module = operator.attrgetter(prefix)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-            delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
-            register_offload_parameter(module, param_name, param)
-
-    def _replace_weights(
-        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
-    ):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-
-        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(mod_path)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-
-            for param_name, param_data in data.items():
-                if hasattr(module, param_name):
-                    # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in
-                    # decompression in init to be consistent with loading which happens
-                    # later as well however, update_data does a good shape check -
-                    # should be moved to the compressor
-
-                    if param_name == "weight":
-                        delattr(module, param_name)
-                        requires_grad = param_data.dtype in (
-                            torch.float16,
-                            torch.float32,
-                            torch.bfloat16,
-                        )
-                        param = torch.nn.Parameter(
-                            param_data.to(device), requires_grad=requires_grad
-                        )
-                        register_offload_parameter(module, param_name, param)
-                    elif load_weight_qparams:
-                        # Should already be registered to the correct device for
-                        # for scales/zero-points
-                        update_parameter_data(module, param_data, param_name)
 
 
 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
@@ -775,35 +542,3 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
             and module.quantization_scheme.weights is not None
         )
     }
-
-
-# HACK: Override the dtype_byte_size function in transformers to support float8 types
-# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
-def new_dtype_byte_size(dtype):
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
-@contextmanager
-def override_quantization_status(
-    config: QuantizationConfig, status: QuantizationStatus
-):
-    """
-    Within this context, the quantization status will be set to the
-    supplied status. After the context exits, the original status
-    will be restored.
-
-    :param config: the quantization config to override
-    :param status: the status to temporarily set
-    """
-    original_status = config.quantization_status
-    config.quantization_status = status
-    try:
-        yield
-    finally:
-        config.quantization_status = original_status
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index faa48df20..399308b18 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -41,78 +41,20 @@
 from compressed_tensors.utils.helpers import deprecated, replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
-from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
 from torch.nn import Module
 
 
 __all__ = [
-    "load_pretrained_quantization_parameters",
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
 ]
 
 
-from compressed_tensors.quantization.utils.helpers import is_module_quantized
-from compressed_tensors.utils.safetensors_load import (
-    get_quantization_parameter_to_path_mapping,
-)
-
 _LOGGER = logging.getLogger(__name__)
 
 
-def load_pretrained_quantization_parameters(
-    model: Module,
-    model_name_or_path: Optional[str] = None,
-    load_weight_qparams: Optional[bool] = False,
-):
-    """
-    Loads the quantization parameters (scale and zero point) from model_name_or_path to
-    a model that has already been initialized with a quantization config.
-
-    NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
-    parameters, if load_weight_qparams is set to True.
-
-    :param model: model to load pretrained quantization parameters to
-    :param model_name_or_path: Hugging Face stub or local folder containing a quantized
-        model, which is used to load quantization parameters
-    :param load_weight_qparams: whether or not the weight quantization parameters
-        should be loaded
-    """
-    model_path = get_safetensors_folder(model_name_or_path)
-    mapping = get_quantization_parameter_to_path_mapping(model_path)
-
-    for name, submodule in model.named_modules():
-        if not is_module_quantized(submodule):
-            continue
-        if submodule.quantization_scheme.input_activations is not None:
-            base_name = "input"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-        if submodule.quantization_scheme.output_activations is not None:
-            base_name = "output"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-
-        if load_weight_qparams and submodule.quantization_scheme.weights:
-            base_name = "weight"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-
-
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ):
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index cb2b913bb..0463addcf 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -19,11 +19,10 @@
 from typing import Dict, Iterable, Optional, Tuple, Union
 
 from torch import Tensor
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 
 
 __all__ = [
-    "get_safetensors_folder",
     "get_safetensors_header",
     "match_param_name",
     "merge_names",
@@ -39,48 +38,6 @@
 NestedWeightMappingType = Dict[str, WeightMappingType]
 
 
-def get_safetensors_folder(
-    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
-) -> str:
-    """
-    Given a Hugging Face stub or a local path, return the folder containing the
-    safetensors weight files
-
-    :param pretrained_model_name_or_path: local path to model or HF stub
-    :param cache_dir: optional cache dir to search through, if none is specified the
-        model will be searched for in the default TRANSFORMERS_CACHE
-    :return: local folder containing model data
-    """
-    if os.path.exists(pretrained_model_name_or_path):
-        # argument is a path to a local folder
-        return os.path.abspath(pretrained_model_name_or_path)
-
-    safetensors_path = cached_file(
-        pretrained_model_name_or_path,
-        SAFE_WEIGHTS_NAME,
-        cache_dir=cache_dir,
-        _raise_exceptions_for_missing_entries=False,
-    )
-    index_path = cached_file(
-        pretrained_model_name_or_path,
-        SAFE_WEIGHTS_INDEX_NAME,
-        cache_dir=cache_dir,
-        _raise_exceptions_for_missing_entries=False,
-    )
-    if safetensors_path is not None:
-        # found a single cached safetensors file
-        return os.path.split(safetensors_path)[0]
-    if index_path is not None:
-        # found a cached safetensors weight index file
-        return os.path.split(index_path)[0]
-
-    # model weights could not be found locally or cached from HF Hub
-    raise ValueError(
-        "Could not locate safetensors weight or index file from "
-        f"{pretrained_model_name_or_path}."
-    )
-
-
 def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
     """
     Extracts the metadata from a safetensors file as JSON
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 115cf3f5a..55c75867a 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -175,70 +175,24 @@ def create_quantization_config(bits=8, type="int", strategy="tensor"):
         create_quantization_config(bits=8, type="float", strategy="channel"),
     ],
 )
-def test_composability(tmp_path, sparsity_config, quantization_config):
-
+def test_composability(sparsity_config, quantization_config):
     model_compressor = ModelCompressor(
         sparsity_config=sparsity_config, quantization_config=quantization_config
     )
-    fake_oneshot_model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
+    model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
         sparsity_config=sparsity_config, quantization_config=quantization_config
     )
-    fake_oneshot_model = fake_oneshot_model.to(torch.float32)
+    model = model.to(torch.float32)
+
     # does both sparse and quantization compression
-    compressed_state_dict = model_compressor.compress(fake_oneshot_model)
+    model_compressor.compress_model(model)
+    compressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
 
-    save_dir = tmp_path / "model"
-    save_dir = _create_dummy_checkpoint(
-        compressed_state_dict, save_dir, model_compressor
-    )
-
-    decompressed_model = DummyLinearModel(
-        torch.zeros_like(fake_oneshot_model.linear.weight)
-    )
-    decompressed_model = decompressed_model.float()
-    model_compressor.decompress(model=decompressed_model, model_path=save_dir)
+    model_compressor.decompress_model(model)
+    decompressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
 
     # check that the decompressed model is the same as the original model
-    _check_state_dicts(fake_oneshot_model.state_dict(), decompressed_model.state_dict())
-
-
-@pytest.mark.parametrize(
-    "sparsity_config, quantization_config, missing, unexpected",
-    [
-        (
-            get_bitmask_sparsity_config(),
-            create_quantization_config(bits=8, type="int", strategy="channel"),
-            {"linear.weight"},
-            {
-                "linear.bitmask",
-                "linear.compressed",
-                "linear.row_offsets",
-                "linear.shape",
-                "linear.weight_scale",
-            },
-        )
-    ],
-)
-def test_missing_and_unexpected_keys_on_compression(
-    tmp_path, sparsity_config, quantization_config, missing, unexpected
-):
-
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
-    fake_oneshot_model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
-
-    og_state_dict_keys = set(
-        DummyLinearModel(weights=torch.randn(10, 5)).state_dict().keys()
-    )
-    compressed_state_dict_keys = set(
-        model_compressor.compress(fake_oneshot_model).keys()
-    )
-
-    assert og_state_dict_keys - compressed_state_dict_keys == missing
-    assert compressed_state_dict_keys - og_state_dict_keys == unexpected
+    _check_state_dicts(compressed_state_dict, decompressed_state_dict)
 
 
 class TwoLayerModel(nn.Module):
@@ -317,50 +271,6 @@ def _get_combined_config(s_config, q_config):
     return combined
 
 
-@pytest.mark.parametrize(
-    "model_stub,q_format,s_config",
-    [
-        (
-            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
-            "float-quantized",
-            None,
-        ),
-        (
-            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
-            None,
-            "sparse-24-bitmask",
-        ),
-        (
-            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
-            "float-quantized",
-            "sparse-24-bitmask",
-        ),
-        (
-            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
-            "pack-quantized",
-            None,
-        ),
-    ],
-)
-def test_compress_model(model_stub, q_format, s_config, tmpdir):
-    model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])
-
-    # compress model by eagerly compressing state dict
-    true_compressed = dict(compressor.compress(model))
-    true_compressed = {key: value.clone() for key, value in true_compressed.items()}
-
-    # compress model directly
-    compressor.compress_model(model)
-    compressed = dict(model.state_dict())
-
-    # equivalent to eagerly compressing state dict
-    assert compressed.keys() == true_compressed.keys()
-    for key in compressed.keys():
-        assert compressed[key].dtype == true_compressed[key].dtype
-        assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
-
-
 @pytest.mark.parametrize(
     "model_stub,q_format,s_config",
     [
@@ -395,13 +305,13 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
-    expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
+    reference_compressor.compress_model(cpu_model)
+    expected = {k: v.dtype for k, v in cpu_model.state_dict()}
 
     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(
         model_stub,
         torch_dtype=torch.float32,
-        low_cpu_mem_usage=True,
     )
     for module in meta_model.modules():
         if hasattr(module, "to_empty"):

From 25bd87a6924a0321c933f98693a6539cc2141503 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 11 Sep 2025 16:20:47 -0400
Subject: [PATCH 2/3] fix zero points initialize

Signed-off-by: Kyle Sayers
---
 .../test_model_compressor.py | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 55c75867a..ba2680fae 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -123,11 +123,11 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None):
         # Attach weight_scale and weight_zero_point as parameters
         if weight_scale is not None:
             self.linear.weight_scale = nn.Parameter(
-                torch.tensor(weight_scale), requires_grad=False
+                weight_scale.detach().clone(), requires_grad=False
             )
         if weight_zero_point is not None:
             self.linear.weight_zero_point = nn.Parameter(
-                torch.tensor(weight_zero_point), requires_grad=False
+                weight_zero_point.detach().clone(), requires_grad=False
             )
 
     def forward(self, x):
@@ -176,23 +176,21 @@ def create_quantization_config(bits=8, type="int", strategy="tensor"):
     ],
 )
 def test_composability(sparsity_config, quantization_config):
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
+    model_compressor = ModelCompressor(sparsity_config, quantization_config)
     model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
+        quantization_config,
+        sparsity_config,
     )
     model = model.to(torch.float32)
 
     # does both sparse and quantization compression
+    original_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
     model_compressor.compress_model(model)
-    compressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
-
     model_compressor.decompress_model(model)
-    decompressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
+    decompressed_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
 
     # check that the decompressed model is the same as the original model
-    _check_state_dicts(compressed_state_dict, decompressed_state_dict)
+    _check_state_dicts(original_state_dict, decompressed_state_dict)
 
 
 class TwoLayerModel(nn.Module):
@@ -252,6 +250,9 @@ def _get_fake_oneshot_sparse_quantized_model(quantization_config, sparsity_confi
         args=quantization_args,
     )
+    if quantization_args.symmetric:
+        zero_point = None  # do not include in model
+
     fake_oneshot_model = DummyLinearModel(quantized_weights, scale, zero_point)
     fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[
         "group_0"
     ]
@@ -306,7 +307,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
     )
     # Only stores dtype because meta model does not store values
     reference_compressor.compress_model(cpu_model)
-    expected = {k: v.dtype for k, v in cpu_model.state_dict()}
+    expected = {k: v.dtype for k, v in cpu_model.state_dict().items()}
 
     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(

From 22bedc9d9bf7927fca2e703681ab7047cc9456b0 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 11 Sep 2025 16:24:14 -0400
Subject: [PATCH 3/3] remove function

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/apply.py | 56 +------------------
 1 file changed, 1 insertion(+), 55 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 399308b18..61f28461d 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -15,7 +15,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Dict, Iterable, List, Optional
+from typing import Iterable, List
 from typing import OrderedDict as OrderedDictType
 from typing import Union
 
@@ -40,8 +40,6 @@
 )
 from compressed_tensors.utils.helpers import deprecated, replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
-from compressed_tensors.utils.offload import update_parameter_data
-from safetensors import safe_open
 from torch.nn import Module
 
 
@@ -196,58 +194,6 @@ def find_name_or_class_matches(
     return match_targets(name, module, targets)
 
 
-def _infer_status(model: Module) -> Optional[QuantizationStatus]:
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
-def _load_quant_args_from_mapping(
-    base_name: str, module_name: str, module: Module, mapping: Dict
-):
-    # TODO: skip update and just register here, don't do it in initialize
-    """
-    Loads scale and zero point from a state_dict into the specified module
-
-    :param base_name: quantization target, one of: weights, input_activations or
-        output_activations
-    :param module_name: pytorch module name to look up in state_dict
-    :module: pytorch module associated with module_name
-    :mapping: mapping to search fetch paths on disk for a given parameter
-    """
-    scale_name = f"{base_name}_scale"
-    zp_name = f"{base_name}_zero_point"
-    g_idx_name = f"{base_name}_g_idx"
-
-    state_dict_scale_path = mapping.get(f"{module_name}.{scale_name}", None)
-    state_dict_zp_path = mapping.get(f"{module_name}.{zp_name}", None)
-    state_dict_g_idx_path = mapping.get(f"{module_name}.{g_idx_name}", None)
-
-    if state_dict_g_idx_path is not None:
-        with safe_open(state_dict_g_idx_path, framework="pt", device="cpu") as f:
-            state_dict_g_idx = f.get_tensor(f"{module_name}.{g_idx_name}")
-
-        update_parameter_data(module, state_dict_g_idx, g_idx_name)
-
-    if state_dict_scale_path is not None:
-        # module is quantized
-        with safe_open(state_dict_scale_path, framework="pt", device="cpu") as f:
-            state_dict_scale = f.get_tensor(f"{module_name}.{scale_name}")
-
-        update_parameter_data(module, state_dict_scale, scale_name)
-
-        if state_dict_zp_path is None:
-            # fill in zero point for symmetric quantization
-            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
-        else:
-            with safe_open(state_dict_zp_path, framework="pt", device="cpu") as f:
-                state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}")
-
-        update_parameter_data(module, state_dict_zp, zp_name)
-
-
 def _scheme_from_targets(
     target_to_scheme: OrderedDictType[str, QuantizationScheme],
     targets: List[str],
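
For callers of the removed pathways, the in-place methods exercised by the updated test_composability cover the same round trip: compress_model() replaces the eager state-dict compress(), and decompress_model() replaces the disk-based decompress(). Below is a minimal sketch of that usage, not part of the patch itself; the import path and the roundtrip helper name are assumptions based on the source tree layout, while the ModelCompressor constructor keywords and the two in-place methods come from the tests above.

# Minimal sketch (assumed usage): in-place compression round trip that
# replaces the removed state-dict compress() / disk decompress() pathways.
from compressed_tensors.compressors import ModelCompressor  # import path assumed


def compression_roundtrip(model, sparsity_config=None, quantization_config=None):
    """Compress and then decompress `model` in place, returning both snapshots."""
    compressor = ModelCompressor(
        sparsity_config=sparsity_config, quantization_config=quantization_config
    )

    # snapshot the dense parameters before compression
    original = {k: v.clone() for k, v in model.state_dict().items()}

    # formerly: state_dict = compressor.compress(model)
    compressor.compress_model(model)

    # formerly: compressor.decompress(model_path, model) after saving to disk
    compressor.decompress_model(model)

    decompressed = {k: v.clone() for k, v in model.state_dict().items()}
    return original, decompressed

As in the retained decompress_model pathway, modules come out of the round trip with their quantization status left at FROZEN.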