Merged
67 changes: 36 additions & 31 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -3,8 +3,8 @@
import torch
from compressed_tensors.quantization import (
    KVCacheScaleType,
    QuantizationScheme,
    QuantizationStatus,
    is_attention_module,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -14,6 +14,7 @@

from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
from llmcompressor.observers import Observer
from llmcompressor.utils.helpers import getattr_chain
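
# A small worked example (not part of this diff) of the newly imported getattr_chain
# helper, whose behavior is inferred from its usage in update_weight_zp_scale below:
# it walks a dotted attribute path and falls back to the default if any link is missing.
#
#     weights_args = getattr_chain(module, "quantization_scheme.weights", None)
#     # roughly equivalent to:
#     scheme = getattr(module, "quantization_scheme", None)
#     weights_args = getattr(scheme, "weights", None) if scheme is not None else None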

__all__ = [
    "initialize_observer",
@@ -22,9 +23,10 @@
    "calibrate_output_hook",
    "calibrate_kv_cache_input_hook",
    "calibrate_kv_cache_output_hook",
    "set_unset_kv_cache",
    "initialize_quantized_kv_cache",
    "freeze_module_quantization",
    "apply_calibration_status",
    "reset_quantization_status",
]


@@ -49,10 +51,6 @@ def initialize_observer(
        # no quantization scheme, nothing to do
        return

    # observers have a different lifecycle for kv_cache
    if is_attention_module(module):
        return

    quantization_args = getattr(quantization_scheme, arg_name, None)
    # don't need observers for dynamic quantization
    if quantization_args is not None and not quantization_args.dynamic:
@@ -102,25 +100,15 @@ def update_weight_zp_scale(module: Module):
    :param quantize_weights_upfront: whether to automatically
        run weight quantization at the start of calibration
    """
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme, nothing to do
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return

    status = getattr(module, "quantization_status", None)
    if not status:
        # not set to initialize; no scales/zp to update
        return
    if status != QuantizationStatus.INITIALIZED:
    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
        logger.warning(
            f"Attempting set module with status {status} to calibration mode. "
            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
            "be calibrating an uninitialized module which may fail or attempting "
            "to re-calibrate a frozen module"
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    if module.quantization_scheme.weights is not None:
        # set weight scale and zero_point up front, calibration data doesn't affect it
        call_observer(module=module, base_name="weight")
    call_observer(module=module, base_name="weight")


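# Usage sketch (illustrative only, not part of this diff): how weight calibration could
# be driven with the helpers above. `calibrate_weights_sketch` is a hypothetical wrapper;
# the initialize_observer(module, base_name=...) signature and the behavior of
# apply_calibration_status (putting the module into calibration mode) are assumptions
# inferred from this file, not guarantees.
def calibrate_weights_sketch(model: Module) -> None:
    for module in model.modules():
        # attach a weight observer where the module's scheme defines weight quantization
        initialize_observer(module, base_name="weight")
        # mark the module as calibrating so update_weight_zp_scale does not warn
        apply_calibration_status(module)
        # weight scales/zero-points are data-independent, so they are set up front
        update_weight_zp_scale(module)

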
def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
@@ -200,21 +188,26 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


def set_unset_kv_cache(module: Module):
def initialize_quantized_kv_cache(module: Module):
    """
    Set or unset singleton QuantizedKVParameterCache for each
    attn module when running kv_cache quantization.
    Initialize a quantized kv_cache on a module (analogous to initializing an observer).
    When a config specifying kv_cache quantization is applied to a model, the kv_cache
    args are redefined as the output_activations targeting attention modules.

    This function should be called on attention modules with output_activations.
    """
    if not hasattr(module, "quantization_scheme"):
    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
    existing_kv_cache = getattr(module, "kv_cache", None)

    if (
        scheme is None
        or not is_kv_cache_quant_scheme(scheme)
        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
    ):
        return

    if is_kv_cache_quant_scheme(module.quantization_scheme):
        output_args = module.quantization_scheme.output_activations
        kv_cache = QuantizedKVParameterCache(output_args)
        if hasattr(module, "kv_cache"):
            delattr(module, "kv_cache")
        else:
            setattr(module, "kv_cache", kv_cache)
    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
    setattr(module, "kv_cache", quantized_kv_cache)


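# Usage sketch (illustrative only, not part of this diff): wiring kv_cache calibration
# onto an attention module. `attach_kv_cache_calibration` is hypothetical; the hook
# registration below assumes calibrate_kv_cache_input_hook follows torch's
# forward-pre-hook-with-kwargs convention and calibrate_kv_cache_output_hook follows
# the standard forward-hook convention, matching the signature shown in this file.
def attach_kv_cache_calibration(attn_module: Module) -> None:
    # no-op unless the module's scheme actually quantizes the kv cache
    initialize_quantized_kv_cache(attn_module)
    # swap the quantized cache into the attention forward pass ...
    attn_module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)
    # ... and write the observed k/v scales back onto the module after each forward
    attn_module.register_forward_hook(calibrate_kv_cache_output_hook)

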
def apply_calibration_status(module: Module):
@@ -242,9 +235,21 @@ def freeze_module_quantization(module: Module):
        # nothing to do, already frozen
        return

    # remove observers
    for name in ("input", "weight", "output"):
        obs_name = f"{name}_observer"
        if hasattr(module, obs_name):
            delattr(module, obs_name)

    # remove quantized kv_cache
    kv_cache = getattr(module, "kv_cache", None)
    if isinstance(kv_cache, QuantizedKVParameterCache):
        delattr(module, "kv_cache")

    module.quantization_status = QuantizationStatus.FROZEN


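# Usage sketch (illustrative only, not part of this diff): once calibration batches have
# been run through the model, freezing drops the observers and the quantized kv_cache and
# marks each module FROZEN so its scales and zero-points stop changing.
# `finish_calibration_sketch` is a hypothetical wrapper.
def finish_calibration_sketch(model: Module) -> None:
    for module in model.modules():
        freeze_module_quantization(module)

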
def reset_quantization_status(model: Module):
    for module in model.modules():
        if hasattr(module, "quantization_status"):
            delattr(module, "quantization_status")
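
# A short example (not part of this diff) of the new reset_quantization_status helper:
# it strips the quantization_status attribute from every submodule, which could be used,
# for instance, before applying a fresh quantization config to the same model object.
# `model` here is a hypothetical torch Module.
#
#     reset_quantization_status(model)
#     assert not any(hasattr(m, "quantization_status") for m in model.modules())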