diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index fa19948e8..97a946c6e 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -3,8 +3,8 @@ import torch from compressed_tensors.quantization import ( KVCacheScaleType, + QuantizationScheme, QuantizationStatus, - is_attention_module, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme @@ -14,6 +14,7 @@ from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer +from llmcompressor.utils.helpers import getattr_chain __all__ = [ "initialize_observer", @@ -22,9 +23,10 @@ "calibrate_output_hook", "calibrate_kv_cache_input_hook", "calibrate_kv_cache_output_hook", - "set_unset_kv_cache", + "initialize_quantized_kv_cache", "freeze_module_quantization", "apply_calibration_status", + "reset_quantization_status", ] @@ -49,10 +51,6 @@ def initialize_observer( # no quantization scheme nothing to do return - # observers have a different lifecycle for kv_cache - if is_attention_module(module): - return - quantization_args = getattr(quantization_scheme, arg_name, None) # dont need observers for dynamic if quantization_args is not None and not quantization_args.dynamic: @@ -102,25 +100,15 @@ def update_weight_zp_scale(module: Module): :param quantize_weights_upfront: whether to automatically run weight quantization at the start of calibration """ - if not getattr(module, "quantization_scheme", None): - # no quantization scheme nothing to do + if getattr_chain(module, "quantization_scheme.weights", None) is None: return - status = getattr(module, "quantization_status", None) - if not status: - # not set to initialize; no scales/zp to update - return - if status != QuantizationStatus.INITIALIZED: + if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION: logger.warning( - f"Attempting set module with status {status} to calibration mode. " - f"but status is not {QuantizationStatus.INITIALIZED} - you may " - "be calibrating an uninitialized module which may fail or attempting " - "to re-calibrate a frozen module" + "Attempting to calibrate weights of a module not in calibration mode" ) - if module.quantization_scheme.weights is not None: - # set weight scale and zero_point up front, calibration data doesn't affect it - call_observer(module=module, base_name="weight") + call_observer(module=module, base_name="weight") def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): @@ -200,21 +188,26 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value) -def set_unset_kv_cache(module: Module): +def initialize_quantized_kv_cache(module: Module): """ - Set or unset singleton QuantizedKVParameterCache for each - attn module when running kv_cache quantization. + Initialize a quantized kv_cache on a module (analogous to initializing an observer) + When a config specifying kv_cache quantization is applied to a model, the kv_cache + args are redefined as the output_activations targeting attention modules. 
+ + This function should be called on attention modules with output_activations """ - if not hasattr(module, "quantization_scheme"): + scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) + existing_kv_cache = getattr(module, "kv_cache", None) + + if ( + scheme is None + or not is_kv_cache_quant_scheme(scheme) + or isinstance(existing_kv_cache, QuantizedKVParameterCache) + ): return - if is_kv_cache_quant_scheme(module.quantization_scheme): - output_args = module.quantization_scheme.output_activations - kv_cache = QuantizedKVParameterCache(output_args) - if hasattr(module, "kv_cache"): - delattr(module, "kv_cache") - else: - setattr(module, "kv_cache", kv_cache) + quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations) + setattr(module, "kv_cache", quantized_kv_cache) def apply_calibration_status(module: Module): @@ -242,9 +235,21 @@ def freeze_module_quantization(module: Module): # nothing to do, already frozen return + # remove observers for name in ("input", "weight", "output"): obs_name = f"{name}_observer" if hasattr(module, obs_name): delattr(module, obs_name) + # remove quantized kv_cache + kv_cache = getattr(module, "kv_cache", None) + if isinstance(kv_cache, QuantizedKVParameterCache): + delattr(module, "kv_cache") + module.quantization_status = QuantizationStatus.FROZEN + + +def reset_quantization_status(model: Module): + for module in model.modules(): + if hasattr(module, "quantization_status"): + delattr(module, "quantization_status") diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index de428cce7..d9c74a496 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,9 +1,9 @@ import contextlib import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.quantization import QuantizationScheme +from compressed_tensors.quantization import disable_quantization from compressed_tensors.utils import ( align_module_device, get_execution_device, @@ -11,30 +11,28 @@ update_offload_parameter, ) from loguru import logger -from pydantic import Field, PrivateAttr, field_validator +from pydantic import PrivateAttr, field_validator from llmcompressor.core import State -from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization +from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.gptq.gptq_quantize import ( accumulate_hessian, make_empty_hessian, quantize_weight, ) -from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.pipelines.basic import run_pipeline as run_basic from llmcompressor.pipelines.layer_sequential import ( run_pipeline as run_layer_sequential, ) from llmcompressor.pipelines.sequential import run_pipeline as run_sequential from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active +from llmcompressor.utils.pytorch.module import get_no_split_params __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier, HooksMixin): +class GPTQModifier(Modifier, QuantizationMixin): """ Implements 
the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier uses activations to calibrate a hessian matrix, which is then used to determine @@ -79,32 +77,31 @@ class GPTQModifier(Modifier, HooksMixin): :param block_size: Used to determine number of columns to compress in one pass :param dampening_frac: Amount of dampening to apply to H, as a fraction of the diagonal norm - :param quantize: Set to True to quantize using an existing quantization modifier, - or pass in the configuration for a quantization modifier if one does not - already exist in the recipe :param offload_hessians: Set to True for decreased memory usage but increased runtime. - :param config_groups: [Used, if a quantization modifier is not specified], - dictionary specifying quantization schemes to apply to target + + :param config_groups: dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. - :param scheme: [Used, if a quantization modifier is not specified], the quantization - scheme to apply to the model, this is a dictionary that supports all keys from - QuantizationScheme except targets, which will be set to the targets parameter - set at the modifier level. Can also be set to a dictionary of the format - `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit - or a string of a preset scheme if targets is provided - and activation 8 bit quantization on the Linear layers. :param targets: list of layer names to quantize if a scheme is provided. Defaults to Linear layers - :param ignore: [Used, if a quantization modifier is not specified] - optional list of module class names or submodule names to not + :param ignore: optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list. - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - :param disable_quantization_observer_epoch: [Used, if a quantization modifier is - not specified] Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. Leave None to not disable observers during QAT. Default is None + :param scheme: a single quantization scheme to apply to the model. This is a + dictionary that supports all keys from QuantizationScheme except targets, which + will be set to the targets parameter set at the modifier level. Can also be set + to a dictionary of the format `preset_scheme_name: targets` for example: + `W8A8: ['Linear']` for weight and activation 8-bit. + :param kv_cache_scheme: optional QuantizationArgs that specify the + quantization of the kv cache. If None, kv cache is not quantized. + When applying kv cache quantization to transformer AutoModelForCausalLM, + the kv_cache_scheme gets converted into a QuantizationScheme that: + - targets the `k_proj` and `v_proj` modules of the model. The outputs + of those modules are the keys and values that might be cached + - quantizes the outputs of the aforementioned layers, so that + keys and values are compressed before storing them in the cache + There is an explicit assumption that the model contains modules with + `k_proj` and `v_proj` in their names. 
If this is not the case + and kv_cache_scheme != None, the quantization of kv cache will fail """ # gptq modifier arguments @@ -115,16 +112,7 @@ class GPTQModifier(Modifier, HooksMixin): quantize: Union[bool, Dict] = True offload_hessians: bool = False - # arguments used for attached quant modifier - config_groups: Optional[Dict[str, QuantizationScheme]] = None - scheme: Optional[Union[str, Dict[str, Any]]] = None - targets: Union[str, List[str], None] = None - ignore: List[str] = Field(default_factory=list) - num_calibration_steps: Optional[int] = None - disable_quantization_observer_epoch: Optional[float] = None - # private variables - _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr(default=None) _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) @@ -140,74 +128,27 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def _check_build_quant_modifier(self, model: torch.nn.Module): - """ - Check the model's quantization state matches that expected by this modifier, - adding a default quantization scheme if needed - - # TODO: build modifier during recipe validation - - :param state: session state storing input model and calibration data - """ - quantization_already_active = qat_active(model) - if isinstance(self.quantize, bool): - if not self.quantize and quantization_already_active: - logger.warning( - "GPTQ quantization is set to False, but a " - "quantization modifier is already active on the model " - "resetting quantize to True" - ) - self.quantize = True - elif self.quantize and not quantization_already_active: - logger.warning( - "GPTQ quantization is set to True without an " - "active quantization modifier." - ) - self._build_quant_modifier() - return # use existing quantization modifier if there is one - else: - if not isinstance(self.quantize, Dict): - raise ValueError( - "GPTQModifier.quantize accepts only a single " - "quantization modifier or a boolean. Found " - f"type {type(self.quantize)}" - ) - if len(self.quantize) != 1: - raise ValueError( - "GPTQModifier.quantize accepts only a single " - "quantization modifier or a boolean. Found " - f"{len(self.quantize)} modifiers" - ) - if quantization_already_active: - logger.warning( - "Attempting to initialize quantization for GPTQ " - "but a quantization modifier has already been applied. " - "The quantization configuration defined under the " - "GPTQ modifier will be ignored." 
- ) - self.quantize = True - return - self._build_quant_modifier_from_dict(self.quantize) - self.quantize = True - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize and run the GPTQ algorithm on the current state :param state: session state storing input model and calibration data """ - # build quantization modifier - self._check_build_quant_modifier(state.model) + # apply config to model and prepare calibration hooks + if QuantizationMixin.has_config(self): + QuantizationMixin.initialize_quantization(self, state.model) - if self._quantization_modifier: - self._quantization_modifier.initialize(state, **kwargs) - if not self.quantize: - raise ValueError("To use the GPTQModifier, quantization must be enabled.") + # assume quantization has been initialized by this modifier or one before it + QuantizationMixin.start_calibration(self, state.model) + # Unlike qmod, do not quantize as we calibrate + # This choice does not seem to have a meaningful impact on accuracy + state.model.apply(disable_quantization) # prepare module names self._module_names = {m: name for name, m in state.model.named_modules()} # register hooks + added_hook = False for module in state.model.modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: # HACK: previously, embeddings were not quantized because they were not @@ -215,6 +156,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: # but in the FUTURE this should be ignored by the user if not isinstance(module, torch.nn.Embedding): self.register_hook(module, self.calibrate_module, "forward") + added_hook = True + + if not added_hook: + raise ValueError( + "GPTQModifier requires a quantization config be specified by this " + "modifier or a modifier preceding it" + ) # infer sequential targets if self.sequential_targets is None: @@ -281,13 +229,11 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: session state storing input model and calibration data """ - if self._quantization_modifier: - self._quantization_modifier.finalize(state, **kwargs) - - self.remove_hooks() self._hessians = dict() self._num_samples = dict() - state.model.apply(freeze_module_quantization) + + QuantizationMixin.end_calibration(self, state.model) + self.remove_hooks() # remove gptq hooks return True @@ -371,41 +317,3 @@ def _maybe_onload_hessian(self, module: torch.nn.Module): if self.offload_hessians: if module in self._hessians: # may have been deleted in context self._hessians[module] = self._hessians[module].to(device="cpu") - - def _build_quant_modifier(self): - """ - Build a quantization modifier based on the specified config_groups, - ignore list, and num_calibration_steps. 
- - :postcondition: self._quantization_modifier is set to the built - quantization modifier - """ - - quantization_args_names = [ - "config_groups", - "targets", - "scheme", - "num_calibration_steps", - "ignore", - "disable_quantization_observer_epoch", - ] - - quant_args = { - key: getattr(self, key) - for key in quantization_args_names - if getattr(self, key, False) - } - - logger.info(f"Building quantization modifier with args: {quant_args}") - vllm_quant_config = {"QuantizationModifier": quant_args} - self._build_quant_modifier_from_dict(vllm_quant_config) - - def _build_quant_modifier_from_dict(self, quant_config): - modifier_type = list(quant_config.keys())[0] - modifier_args = quant_config[modifier_type] - self._quantization_modifier = ModifierFactory.create( - modifier_type, - allow_registered=True, - allow_experimental=True, - **modifier_args, - ) diff --git a/src/llmcompressor/modifiers/quantization/quantization/__init__.py b/src/llmcompressor/modifiers/quantization/quantization/__init__.py index 8bdc93d14..f268f065f 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/__init__.py +++ b/src/llmcompressor/modifiers/quantization/quantization/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .base import * +from .mixin import * diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 3a8946aef..3c309a074 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -1,43 +1,18 @@ -from typing import Any, Dict, List, Optional, Union - -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationConfig, - QuantizationScheme, - QuantizationStatus, - apply_quantization_config, - is_attention_module, - is_preset_scheme, - preset_name_to_scheme, -) +import torch +import tqdm from loguru import logger -from pydantic import Field, field_validator -from torch.nn import Module -from llmcompressor.core import Event, EventType, State +from llmcompressor.core import Event, State from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.quantization.calibration import ( - apply_calibration_status, - calibrate_input_hook, - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, - calibrate_output_hook, - freeze_module_quantization, - initialize_observer, - set_unset_kv_cache, - update_weight_zp_scale, -) -from llmcompressor.modifiers.utils.pytorch_helpers import ( - is_moe_model, - run_calibration_forward, -) -from llmcompressor.observers.helpers import get_observer_token_count +from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale +from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["QuantizationModifier"] -class QuantizationModifier(Modifier): +class QuantizationModifier(Modifier, QuantizationMixin): """ Enables post training quantization (PTQ) and quantization aware training (QAT) for a given module or its submodules. After calibration (PTQ) or the start epoch (QAT), @@ -46,6 +21,8 @@ class QuantizationModifier(Modifier): :param config_groups: dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. + :param targets: list of layer names to quantize if a scheme is provided. 
Defaults + to Linear layers :param ignore: optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list. :param scheme: a single quantization scheme to apply to the model. This is a @@ -64,313 +41,87 @@ class QuantizationModifier(Modifier): There is an explicit assumption that the model contains modules with `k_proj` and `v_proj` in their names. If this is not the case and kv_cache_scheme != None, the quantization of kv cache will fail - :param targets: list of layer names to quantize if a scheme is provided. Defaults - to Linear layers - :param disable_quantization_observer_epoch: Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. Leave None to not disable observers during QAT. Default is None - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used """ - config_groups: Optional[Dict[str, QuantizationScheme]] = None - ignore: List[str] = Field(default_factory=list) - targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"]) - scheme: Optional[Union[str, Dict[str, Any]]] = None - kv_cache_scheme: Optional[QuantizationArgs] = None - disable_quantization_observer_epoch: Optional[float] = None - num_calibration_steps: Optional[int] = None - - calibration_dataloader_: Any = None - calibration_function_: Any = None - - @field_validator("targets", mode="before") - def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: - if isinstance(value, str): - return [value] + def on_initialize(self, state: State, **kwargs) -> bool: + """ + Prepare to calibrate activations and weights - return value + According to the quantization config, a quantization scheme is attached to each + targeted module. The module's forward call is also overwritten to perform + quantization to inputs, weights, and outputs. - def on_initialize(self, state: State, **kwargs) -> bool: - if self.end and self.end != -1: + Then, according to the module's quantization scheme, observers and calibration + hooks are added. These hooks are disabled until the modifier starts. + """ + if not QuantizationMixin.has_config(self): raise ValueError( - "end_epoch is disabled for QuantizationModifier and can only be set to" - " -1 or None. 
Given {}".format(self.end) + "QuantizationModifier requires that quantization fields to be specified" ) - self.calibration_dataloader_ = state.data.calib - module = state.model - - # initialize quantization in appropriate modules - config = self._apply_modifier_to_model(module) - module.apply(lambda module: initialize_observer(module, base_name="weight")) + QuantizationMixin.initialize_quantization(self, state.model) - if self.calculate_start() == -1: # one-shot - self._check_calibration_data(config) - module.apply(update_weight_zp_scale) - module.apply(apply_calibration_status) - self._calibrate_if_possible(module) - self._check_token_distribution( - module, threshold=kwargs.get("min_tokens_per_module") - ) - module.apply(freeze_module_quantization) + # FUTURE: modify oneshot lifecycle to trigger on_start for on initialize + if self.calculate_start() == -1: # one shot + self.on_start(state) return True - def on_start(self, state: State, event: Event, **kwargs): - module = state.model - module.apply(update_weight_zp_scale) - - def on_update(self, state: State, event: Event, **kwargs): - if event.type_ == EventType.BATCH_START: - if self.check_should_disable_observer(event): - module = state.model - module.apply(freeze_module_quantization) - - def on_end(self, state: State, event: Event, **kwargs): - module = state.model - module.apply(freeze_module_quantization) - - def create_init_config(self) -> QuantizationConfig: - if self.scheme is not None: - # takes precedence over config_groups - - if isinstance(self.scheme, str) and is_preset_scheme(self.scheme): - # attach targets to scheme - self.scheme = {self.scheme: self.targets} - - self.config_groups = {} - for idx, key in enumerate(self.scheme.keys()): - if is_preset_scheme(key): - scheme = preset_name_to_scheme(key, self.scheme[key]) - else: - scheme = QuantizationScheme.model_validate( - {"targets": self.scheme[key], **self.scheme} - ) - - group_name = f"group_{idx}" - self.config_groups[group_name] = scheme - - if self.config_groups is None or len(self.config_groups) == 0: - default_quant_scheme = QuantizationScheme(targets=self.targets) - self.config_groups = {"group_0": default_quant_scheme} - logger.info( - f"No config groups were provided, using default {self.config_groups}" - ) - - return QuantizationConfig( - config_groups=self.config_groups, - kv_cache_scheme=self.kv_cache_scheme, - quantization_status=QuantizationStatus.INITIALIZED, - ignore=self.ignore, - ) - - def calculate_disable_observer_epoch(self) -> float: + def on_start(self, state: State): """ - Get the epoch at which we want to disable to quantization observer - :return epoch to disable at, or -1 if it is not set - """ - return ( - self.disable_quantization_observer_epoch - if self.disable_quantization_observer_epoch is not None - else -1 - ) - - def check_should_disable_observer(self, event: Event) -> bool: + Begin calibrating activations and weights. 
Calibrate weights only once on start """ - Given the current index, determine if we should disable the observer + QuantizationMixin.start_calibration(self, state.model) - :param event: Event to get index from - :return: True if observer should be disabled, False otherwise - """ - disable_epoch = self.calculate_disable_observer_epoch() - if disable_epoch == -1: - return False - if event.current_index >= disable_epoch: - return True - return False + modules = list(state.model.modules()) + for module in tqdm.tqdm(modules, desc="Calibrating weights"): + update_weight_zp_scale(module) - def _check_calibration_data(self, config: QuantizationConfig): - has_calibration_data = self.calibration_dataloader_ is not None - requires_calibration = config.requires_calibration_data() + # FUTURE: below will be removed after pipeline extraction if self.calculate_start() == -1: # one shot - if requires_calibration and not has_calibration_data: - raise ValueError( - "The provided quantization configuration requires calibration data " - "but none was provided. Calibration data is required for static " - "quantization of input or output activations." - ) - if not requires_calibration and has_calibration_data: - logger.info( - "Skipping QuantizationModifier calibration, it is not required for " - "the provided quantization config." - ) - self.calibration_dataloader_ = None - - def _apply_modifier_to_model(self, model: Module): - modifier_as_config = self.create_init_config() - # Add step to attach kv_cache to the model, if present within the config - apply_quantization_config(model, modifier_as_config) - model.apply(set_unset_kv_cache) - return modifier_as_config - - def _calibrate_if_possible(self, module: Module): - # TODO: @dsikka restructure such that all of calibration isn't happening - # on init - # flake8: noqa - """# noqa: E501 - Run calibration if running input/output activation quantization or kv_cache - quantization. + self._calibrate_if_possible(state) - Calibration Lifecycle for a single torch.nn.Module: - - initialize_observer(): - if input/output activation: - - observer = Observer.load_from_registry(...) 
- - module.register_module(f"{base_name}_observer", observer) + def on_end(self, state: State, event: Event, **kwargs): + """ + Finish calibrating by removing observers and calibration hooks + """ + QuantizationMixin.end_calibration( + self, state.model + ) # keep quantization enabled - register_calibration_hooks(): - if input activation and not dynamic quant (used to call observers before intput QDQ): - - pre_hook := calibrate_input_hook - if output activation and not dynamic quant (used to call observers before output QDQ): - - post_hook := calibrate_kv_cache_output_hook - if kv_cache quantization (used to set kv_cache to QuantizedKVParameterCache and update k_scale/v_scale) - - pre_hook := calibrate_kv_cache_input_hook - - post_hook := calibrate_kv_cache_output_hook + def on_finalize(self, state: State, **kwargs) -> bool: + # TODO: modify lifecycle so modifiers end on finalize + if not self.ended_: + self.on_end(state, None) - self._calibrate(module) # run forward pass through model using calibration data - set_unset_kv_cache() # remove kv_cache objects attached to attention layers - # initially set in _apply_modifier_to_model - remove calibration hooks in self.calibration_hooks_ - remove observers + def _calibrate_if_possible(self, state: State): + model = state.model + calibration_dataloader = state.data.calib + config = QuantizationMixin.resolve_quantization_config(self) - """ - if self.num_calibration_steps == 0 and self.calibration_dataloader_: - logger.warning( - f"num_calibration_steps is {self.num_calibration_steps}." - f"Calibration data loader will not be used." - ) - elif self.num_calibration_steps and not self.calibration_dataloader_: + has_calibration_data = calibration_dataloader is not None + requires_calibration = config.requires_calibration_data() + if requires_calibration and not has_calibration_data: raise ValueError( - f"num_calibration_steps is {self.num_calibration_steps}. " - "Calibration data loader is not set. Pass a " - "calibration_data_loader with initialize(...) method." + "The provided quantization configuration requires calibration data " + "but none was provided. Calibration data is required for static " + "quantization of input or output activations." + ) + if not requires_calibration and has_calibration_data: + logger.info( + "Skipping QuantizationModifier calibration, it is not required for " + "the provided quantization config." ) - - elif not self.calibration_dataloader_: return - module.apply(lambda model: initialize_observer(model, base_name="input")) - module.apply(lambda model: initialize_observer(model, base_name="output")) - module.apply(self.register_calibration_hooks) - self._calibrate(module) - module.apply(set_unset_kv_cache) - self.remove_hooks() - - def register_calibration_hooks(self, module: Module): - """ - Register hooks for input/output activation or kv_cache quantization. 
- """ - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: + if not requires_calibration: return - is_attention_module_ = is_attention_module(module) - input_quant = quantization_scheme.input_activations - output_quant = quantization_scheme.output_activations - - calibrate_inputs = ( - input_quant and not is_attention_module_ and not input_quant.dynamic - ) - - # Calibrate inputs if an input_quant is provided and not running dynamic quant - if calibrate_inputs: - self.register_hook(module, calibrate_input_hook, "forward_pre") - - if output_quant: - # hooks for attn modules if running kv_cache quant - if is_attention_module_: - self.register_hook( - module, - calibrate_kv_cache_input_hook, - "forward_pre", - with_kwargs=True, - ) + self._calibrate(model, calibration_dataloader) - self.register_hook(module, calibrate_kv_cache_output_hook, "forward") - - # hooks for output quant if not running dynamic quant - elif not output_quant.dynamic: - self.register_hook(module, calibrate_output_hook, "forward") - - def _calibrate(self, module: Module): + def _calibrate(self, module: torch.nn.Module, data: torch.utils.data.DataLoader): class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " - f"{len(self.calibration_dataloader_)} samples..." - ) + logger.info(f"Running {class_name} calibration with {len(data)} samples...") with calibration_forward_context(module): - run_calibration_forward( - module, - self.calibration_dataloader_, - self.num_calibration_steps, - self.calibration_function_, - ) - - def _check_token_distribution( - self, model: Module, threshold: Optional[float] = None - ): - """ - A helper function that warns when a module has seen - fewer than threshold % of all the tokens throughout - the calibration process. - Checks are only triggered if threshold is not None. - :param model: the model to validate - :param threshold: the minimum percentage of tokens - (out of all the tokens in a batch) a module should - receive during calibration - """ - - if self.calibration_dataloader_ is None: - logger.debug("Skipping token distribution check. No calibration data.") - return - - if not is_moe_model(model): - logger.debug("Skipping token distribution check. Not a MoE model.") - return - - if threshold is None: - logger.warning( - "Mixture of Experts model detected, but threshold not set. " - "Defaulting token threshold to 1/num_experts." - ) - - if not hasattr(model.config, "num_local_experts"): - logger.warning( - "Mixture of Experts model detected but `num_local_experts` " - "not found in model config. Skipping distribution check." - ) - return - - threshold = 1 / model.config.num_local_experts - logger.debug(f"Setting token threshold to {threshold}.") - - all_tokens = self.calibration_dataloader_.dataset["input_ids"] - total_token_count = sum(len(sample) for sample in all_tokens) - counter = get_observer_token_count(model) - for module_name, token_count in counter.items(): - if token_count is None: - # the module has not been observed - # or its token_count is not being recorded - # by the observer (refer to the observer's - # implementation in the source code) - continue - if token_count / total_token_count < threshold: - logger.warning( - f"The module_name: {module_name} " - f"received less than {int(threshold * 100)}% " - "of calibration batch tokens " - f"({token_count}/{total_token_count} tokens). " - "This could harm the quantization quality." 
- ) + run_calibration_forward(module, data) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py new file mode 100644 index 000000000..75e5e8935 --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -0,0 +1,271 @@ +from typing import Any, Dict, List, Optional, Set, Union + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, + apply_quantization_config, + disable_quantization, + enable_quantization, + is_attention_module, + is_preset_scheme, + preset_name_to_scheme, +) +from pydantic import Field, PrivateAttr, field_validator +from torch.utils.hooks import RemovableHandle + +from llmcompressor.modifiers.quantization.calibration import ( + apply_calibration_status, + calibrate_input_hook, + calibrate_kv_cache_input_hook, + calibrate_kv_cache_output_hook, + calibrate_output_hook, + freeze_module_quantization, + initialize_observer, + initialize_quantized_kv_cache, + reset_quantization_status, +) +from llmcompressor.modifiers.utils.hooks import HooksMixin + +__all__ = ["QuantizationMixin"] + + +class QuantizationMixin(HooksMixin): + """ + Mixin which enables a Modifier to act as a quantization config, attching observers, + calibration hooks, and compression wrappers to modifiers + + Lifecycle: + - on_initialize: QuantizationMixin.initialize_quantization + - Attach schemes to modules + - Attach observers to modules + - Disable quantization until calibration starts/finishes + - on_start: QuantizationMixin.start_calibration + - Attach calibration hooks + - Apply calibration status + - Enable quantization during calibration + - on_end: QuantizationMixin.end_calibration + - Remove calibration hooks + - Apply freeze status + - Keep quantization enabled for future steps + + :param config_groups: dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param targets: list of layer names to quantize if a scheme is provided. Defaults + to Linear layers + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param scheme: a single quantization scheme to apply to the model. This is a + dictionary that supports all keys from QuantizationScheme except targets, which + will be set to the targets parameter set at the modifier level. Can also be set + to a dictionary of the format `preset_scheme_name: targets` for example: + `W8A8: ['Linear']` for weight and activation 8-bit. + :param kv_cache_scheme: optional QuantizationArgs, that specify the + quantization of the kv cache. If None, kv cache is not quantized. + When applying kv cache quantization to transformer AutoModelForCausalLM, + the kv_cache_scheme gets converted into a QuantizationScheme that: + - targets the `q_proj` and `k_proj` modules of the model. The outputs + of those modules are the keys and values that might be cached + - quantizes the outputs of the aformentioned layers, so that + keys and values are compressed before storing them in the cache + There is an explicit assumption that the model contains modules with + `k_proj` and `v_proj` in their names. 
If this is not the case + and kv_cache_scheme != None, the quantization of kv cache will fail + """ + + config_groups: Optional[Dict[str, QuantizationScheme]] = None + targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"]) + ignore: List[str] = Field(default_factory=list) + scheme: Optional[Union[str, Dict[str, Any]]] = None + kv_cache_scheme: Optional[QuantizationArgs] = None + + _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set) + + @field_validator("targets", mode="before") + def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: + if isinstance(value, str): + return [value] + + return value + + @field_validator("scheme", mode="before") + def validate_scheme( + cls, value: Optional[Union[str, Dict[str, Any]]] + ) -> Optional[Union[str, Dict[str, Any]]]: + if isinstance(value, str) and not is_preset_scheme(value): + raise ValueError( + "`scheme` must either be a preset scheme name or a dictionary " + "of preset scheme names" + ) + + if isinstance(value, dict): + for scheme_name in value.keys(): + cls.validate_scheme(scheme_name) + + for key, target in value.items(): + value[key] = cls.validate_targets(target) + + return value + + def initialize_quantization(self, model: torch.nn.Module): + """ + Attach quantization schemes and observers to modules in the model according to + the quantization config specified on this modifier + + :param model: model to attach schemes and observers to + """ + reset_quantization_status(model) # reset any previously applied qconfigs + + # apply scheme and status to model + config = self.resolve_quantization_config() + apply_quantization_config(model, config) + + # apply observers, disable quantization until calibration + model.apply(self._initialize_observers) + model.apply(disable_quantization) + + def start_calibration(self, model: torch.nn.Module): + """ + Register activation calibration hooks (including kv_cache quantization) and + enable quantization as we calibrate + + :param model: model to prepare for calibration + """ + self._calibration_hooks = self._initialize_hooks(model) + model.apply(apply_calibration_status) + model.apply(enable_quantization) # quantize at the same time as calibrate + + def end_calibration(self, model: torch.nn.Module): + """ + Remove calibration hooks and set the model status to frozen. 
Keep quantization + enabled for future operations + + :param model: model to end calibration for + """ + self.remove_hooks(self._calibration_hooks) + model.apply(freeze_module_quantization) # remove observers + model.apply(enable_quantization) # keep quantization enabled + + def has_config(self) -> bool: + """ + Determine if the user has specified a quantization config on this modifier + """ + return not ( + self.config_groups is None + and self.targets == ["Linear"] + and self.ignore == [] + and self.scheme is None + and self.kv_cache_scheme is None + ) + + def resolve_quantization_config(self) -> QuantizationConfig: + """ + Returns the quantization config specified by this modifier + """ + scheme = self.scheme + targets = self.targets + config_groups = self.config_groups + kv_cache_scheme = self.kv_cache_scheme + ignore = self.ignore + + if scheme is not None and config_groups is not None: + raise ValueError("Please specify either `scheme` or `config_groups`") + + if scheme is not None: + # takes precedence over config_groups + + if isinstance(scheme, str) and is_preset_scheme(scheme): + # attach targets to scheme + scheme = {scheme: targets} + + config_groups = {} + for idx, key in enumerate(scheme.keys()): + if is_preset_scheme(key): + group_scheme = preset_name_to_scheme(key, scheme[key]) + else: + group_scheme = QuantizationScheme.model_validate( + {"targets": scheme[key], **scheme} + ) + + group_name = f"group_{idx}" + config_groups[group_name] = group_scheme + + if config_groups is None or len(config_groups) == 0: + default_quant_scheme = QuantizationScheme(targets=targets) + config_groups = {"group_0": default_quant_scheme} + + return QuantizationConfig( + config_groups=config_groups, + kv_cache_scheme=kv_cache_scheme, + quantization_status=QuantizationStatus.INITIALIZED, + ignore=ignore, + ) + + def _initialize_observers(self, module: torch.nn.Module): + if not hasattr(module, "quantization_scheme"): + return + + scheme: QuantizationScheme = module.quantization_scheme + input = scheme.input_activations and not scheme.input_activations.dynamic + weight = scheme.weights is not None + output = scheme.output_activations and not scheme.output_activations.dynamic + is_attention = is_attention_module(module) + + # input activations + if input: + initialize_observer(module, base_name="input") + + # weight observers (used by `update_weight_zp_scale` or child modifier) + if weight: + initialize_observer(module, base_name="weight") + + # kv_cache activations. 
Within `apply_quantization_config`, the config is + # modified to use attention output quantization if a kv_cache_scheme exists + if is_attention and output: + hooks.add( + self.register_hook( + module, + calibrate_kv_cache_input_hook, + "forward_pre", + with_kwargs=True, + ) + ) + hooks.add( + self.register_hook( + module, calibrate_kv_cache_output_hook, "forward" + ) + ) + + # output activations + elif output: + hooks.add(self.register_hook(module, calibrate_output_hook, "forward")) + + return hooks diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py index 25b8468f4..b22e7ec40 100644 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py @@ -25,7 +25,7 @@ calibrate_kv_cache_input_hook, calibrate_kv_cache_output_hook, freeze_module_quantization, - set_unset_kv_cache, + initialize_quantized_kv_cache, ) config = { @@ -75,7 +75,7 @@ def test_kv_cache_quantization(config): config = QuantizationConfig(**config) config.quantization_status = QuantizationStatus.CALIBRATION apply_quantization_config(model, config) - model.apply(set_unset_kv_cache) + model.apply(initialize_quantized_kv_cache) model.apply(_prep_for_calibration) with torch.no_grad(): diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py index 11e630c19..2a8c58ea4 100644 --- a/tests/llmcompressor/modifiers/quantization/test_base.py +++ b/tests/llmcompressor/modifiers/quantization/test_base.py @@ -2,7 +2,6 @@ import pytest -from llmcompressor.core.events import Event from llmcompressor.modifiers.factory import ModifierFactory from llmcompressor.modifiers.quantization import QuantizationModifier from tests.llmcompressor.modifiers.conf import setup_modifier_factory @@ -25,51 +24,3 @@ def test_quantization_registered(self): ) self.assertIsInstance(quant_obj, QuantizationModifier) - - -@pytest.mark.unit -class TestEndEpochs(unittest.TestCase): - def setUp(self): - self.start = 0.0 - self.scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=6, symmetric=False), - ) - - def test_end_epochs(self): - disable_quant_epoch = None - obj_modifier = QuantizationModifier( - start=self.start, - scheme=self.scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - config_groups={}, - ) - - self.assertEqual(obj_modifier.calculate_disable_observer_epoch(), -1) - - for epoch in range(3): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - - disable_quant_epoch = 3.5 - obj_modifier = QuantizationModifier( - start=self.start, - scheme=self.scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - config_groups={}, - ) - - self.assertEqual( - obj_modifier.calculate_disable_observer_epoch(), disable_quant_epoch - ) - - for epoch in range(4): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - - event = Event(steps_per_epoch=1, global_step=4) - assert obj_modifier.check_should_disable_observer(event) - - for epoch in range(5, 8): - event = Event(steps_per_epoch=1, global_step=epoch) - assert obj_modifier.check_should_disable_observer(event) diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index ab63a5414..14fd7dcb8 100644 --- 
a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -1,13 +1,11 @@ import unittest import pytest -from compressed_tensors.quantization import QuantizationScheme +import torch from parameterized import parameterized from llmcompressor.modifiers.obcq import SparseGPTModifier from llmcompressor.modifiers.quantization.gptq import GPTQModifier -from llmcompressor.modifiers.quantization.quantization import QuantizationModifier -from llmcompressor.utils.pytorch.module import qat_active from tests.llmcompressor.modifiers.conf import ( LifecyleTestingHarness, setup_modifier_factory, @@ -62,50 +60,26 @@ def test_successful_layerwise_recipe(self): @pytest.mark.unit -class TestCreateDefaultQuantModifier(unittest.TestCase): +class TestApplyQuantization(unittest.TestCase): def setUp(self): setup_modifier_factory() def test_create_default_quant_modifier(self): - modifier = GPTQModifier(block_size=128) - assert modifier._quantization_modifier is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier._check_build_quant_modifier(testing_harness.get_state().model) - assert modifier.quantize - assert isinstance(modifier._quantization_modifier, QuantizationModifier) - modifier._quantization_modifier.create_init_config() - default_config_group_name = "group_0" - should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ - default_config_group_name - ] - assert should_be_default_quant_scheme.input_activations is None - assert should_be_default_quant_scheme.weights is None - - -@pytest.mark.unit -class TestSetQuantIfModifierAlreadyExists(unittest.TestCase): - def setUp(self): - setup_modifier_factory() + modifier = GPTQModifier(block_size=128, targets=["Linear"], scheme="FP8") - def test_set_quant_if_modifer_already_exists(self): - model = LinearNet() - scheme = QuantizationScheme( - targets=["Linear"], - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ) - - modifier = QuantizationModifier(config_groups={"group_0": scheme}) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - assert not qat_active(testing_harness.get_state().model) + testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) modifier.initialize(testing_harness.get_state()) - assert qat_active(testing_harness.get_state().model) - modifier = GPTQModifier(block_size=128) - assert not modifier._quantization_modifier - assert modifier.quantize + model = testing_harness.state.model + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + assert hasattr(module, "quantization_scheme") + assert hasattr(module, "input_observer") + assert hasattr(module, "weight_observer") + pre_hooks = list(module._forward_pre_hooks.values()) + post_hooks = list(module._forward_hooks.values()) + assert pre_hooks[0].__name__ == "calibrate_input_hook" + assert post_hooks[0].__name__ == "calibrate_module" class TestSetQuantInGPTQ(unittest.TestCase): @@ -131,24 +105,17 @@ def setUp(self): } } } - self.quant_config = {"QuantizationModifier": self.quant_kwargs} def test_set_quant_in_gptq(self): - modifier = GPTQModifier(block_size=128, quantize=self.quant_config) - assert modifier._quantization_modifier is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier._check_build_quant_modifier(testing_harness.get_state().model) - assert modifier.quantize - self.assertIsInstance(modifier._quantization_modifier, 
QuantizationModifier) + modifier = GPTQModifier(block_size=128, **self.quant_kwargs) + config = modifier.resolve_quantization_config() - dict_scheme = dict(modifier._quantization_modifier.config_groups) self._check_config( - dict(dict_scheme["config_group_0"].weights), + dict(config.config_groups["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"], ) self._check_config( - dict(dict_scheme["config_group_0"].input_activations), + dict(config.config_groups["config_group_0"].input_activations), self.quant_kwargs["config_groups"]["config_group_0"]["input_activations"], )
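
Usage sketch (illustrative, not part of the diff above): after this refactor, GPTQModifier carries its quantization config directly through QuantizationMixin, so no separate QuantizationModifier needs to be attached via `quantize=...`. The snippet below mirrors the updated tests; the `ignore=["lm_head"]` entry is an illustrative assumption rather than a value taken from this diff, and the one-shot entrypoint that would consume the modifier is omitted because it is outside the scope of these changes.

from llmcompressor.modifiers.quantization.gptq import GPTQModifier

# Quantization fields (scheme/targets/ignore) now live on the modifier itself.
modifier = GPTQModifier(
    block_size=128,
    scheme="FP8",          # preset scheme name, validated by QuantizationMixin
    targets=["Linear"],    # defaults to ["Linear"] when omitted
    ignore=["lm_head"],    # illustrative: skip modules even if they match a target
)

# The resolved config can be inspected without touching a model,
# as the updated TestSetQuantInGPTQ test does.
config = modifier.resolve_quantization_config()
assert "group_0" in config.config_groups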