
Commit fa75986

Implement QuantizationMixin

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 72c87aa, commit fa75986

9 files changed, +384 -579 lines changed

src/llmcompressor/modifiers/modifier.py
Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):
 
         self.initialized_ = self.on_initialize(state=state, **kwargs)
 
-        # trigger start
+        # trigger starts
         fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
         if self.should_start(fake_start_event):
             self.on_start(state, fake_start_event, **kwargs)
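With this change, a modifier can fire its on_start logic already during initialize when its start boundary is satisfied at step 0. Below is a minimal, self-contained sketch of that pattern; ToyModifier and the simplified Event/EventType definitions are illustrative stand-ins, not the library's actual classes.

from dataclasses import dataclass
from enum import Enum


class EventType(Enum):
    BATCH_START = "batch_start"


@dataclass
class Event:
    type_: EventType
    global_step: int


class ToyModifier:
    start: int = 0  # begin at the first step

    def should_start(self, event: Event) -> bool:
        # start once the current step has reached the configured boundary
        return event.global_step >= self.start

    def on_start(self, event: Event) -> None:
        print(f"started at step {event.global_step}")

    def initialize(self) -> None:
        # mirrors the diff above: fabricate a BATCH_START event at step 0 so a
        # modifier whose start boundary is already satisfied begins immediately
        fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
        if self.should_start(fake_start_event):
            self.on_start(fake_start_event)


ToyModifier().initialize()  # prints "started at step 0"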

src/llmcompressor/modifiers/quantization/calibration.py
Lines changed: 29 additions & 31 deletions

@@ -3,8 +3,8 @@
 import torch
 from compressed_tensors.quantization import (
     KVCacheScaleType,
+    QuantizationScheme,
     QuantizationStatus,
-    is_attention_module,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -14,6 +14,7 @@
 
 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
+from llmcompressor.utils.helpers import getattr_chain
 
 __all__ = [
     "initialize_observer",
@@ -22,7 +23,7 @@
     "calibrate_output_hook",
     "calibrate_kv_cache_input_hook",
     "calibrate_kv_cache_output_hook",
-    "set_unset_kv_cache",
+    "initialize_quantized_kv_cache",
     "freeze_module_quantization",
     "apply_calibration_status",
 ]
@@ -49,10 +50,6 @@ def initialize_observer(
         # no quantization scheme nothing to do
         return
 
-    # observers have a different lifecycle for kv_cache
-    if is_attention_module(module):
-        return
-
     quantization_args = getattr(quantization_scheme, arg_name, None)
     # dont need observers for dynamic
     if quantization_args is not None and not quantization_args.dynamic:
@@ -102,25 +99,15 @@ def update_weight_zp_scale(module: Module):
     :param quantize_weights_upfront: whether to automatically
        run weight quantization at the start of calibration
     """
-    if not getattr(module, "quantization_scheme", None):
-        # no quantization scheme nothing to do
+    if getattr_chain(module, "quantization_scheme.weights", None) is None:
         return
 
-    status = getattr(module, "quantization_status", None)
-    if not status:
-        # not set to initialize; no scales/zp to update
-        return
-    if status != QuantizationStatus.INITIALIZED:
+    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
         logger.warning(
-            f"Attempting set module with status {status} to calibration mode. "
-            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
-            "be calibrating an uninitialized module which may fail or attempting "
-            "to re-calibrate a frozen module"
+            "Attempting to calibrate weights of a module not in calibration mode"
        )
 
-    if module.quantization_scheme.weights is not None:
-        # set weight scale and zero_point up front, calibration data doesn't affect it
-        call_observer(module=module, base_name="weight")
+    call_observer(module=module, base_name="weight")
 
 
 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
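update_weight_zp_scale now leans on getattr_chain from llmcompressor.utils.helpers to resolve the dotted path "quantization_scheme.weights" with a fallback. The following is an assumption-level re-implementation of such a helper, shown only to illustrate the early-return check above; the real helper's signature and error handling may differ.

from typing import Any


def getattr_chain(obj: Any, attr_path: str, default: Any = None) -> Any:
    # walk "a.b.c" one attribute at a time, bailing out to `default`
    # as soon as any link in the chain is missing or None
    current = obj
    for name in attr_path.split("."):
        current = getattr(current, name, None)
        if current is None:
            return default
    return current


class _Scheme:
    weights = None  # no weight quantization configured for this toy scheme


class _Module:
    quantization_scheme = _Scheme()


# mirrors the early-return check in update_weight_zp_scale above:
# "quantization_scheme.weights" resolves to None, so weight calibration is skipped
print(getattr_chain(_Module(), "quantization_scheme.weights", None))  # -> None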
@@ -200,21 +187,26 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
 
 
-def set_unset_kv_cache(module: Module):
+def initialize_quantized_kv_cache(module: Module):
     """
-    Set or unset singleton QuantizedKVParameterCache for each
-    attn module when running kv_cache quantization.
+    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
+    When a config specifying kv_cache quantization is applied to a model, the kv_cache
+    args are redefined as the output_activations targeting attention modules.
+
+    This function should be called on attention modules with output_activations
     """
-    if not hasattr(module, "quantization_scheme"):
+    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
+    existing_kv_cache = getattr(module, "kv_cache", None)
+
+    if (
+        scheme is None
+        or not is_kv_cache_quant_scheme(scheme)
+        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
+    ):
         return
 
-    if is_kv_cache_quant_scheme(module.quantization_scheme):
-        output_args = module.quantization_scheme.output_activations
-        kv_cache = QuantizedKVParameterCache(output_args)
-        if hasattr(module, "kv_cache"):
-            delattr(module, "kv_cache")
-        else:
-            setattr(module, "kv_cache", kv_cache)
+    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
+    setattr(module, "kv_cache", quantized_kv_cache)
 
 
 def apply_calibration_status(module: Module):
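Because the rewritten helper guards on a missing scheme, on schemes that are not kv_cache quantization, and on an already-attached QuantizedKVParameterCache, it is a safe, idempotent no-op everywhere else. A hedged usage sketch follows; the surrounding QuantizationMixin presumably drives this in practice, and attach_kv_caches is an illustrative name only.

import torch

from llmcompressor.modifiers.quantization.calibration import (
    initialize_quantized_kv_cache,
)


def attach_kv_caches(model: torch.nn.Module) -> None:
    # no-op for modules without a kv_cache quantization scheme, so it can be
    # applied to every submodule rather than only to located attention layers
    for module in model.modules():
        initialize_quantized_kv_cache(module)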
@@ -242,9 +234,15 @@ def freeze_module_quantization(module: Module):
         # nothing to do, already frozen
         return
 
+    # remove observers
     for name in ("input", "weight", "output"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
             delattr(module, obs_name)
 
+    # remove quantized kv_cache
+    kv_cache = getattr(module, "kv_cache", None)
+    if isinstance(kv_cache, QuantizedKVParameterCache):
+        delattr(module, "kv_cache")
+
     module.quantization_status = QuantizationStatus.FROZEN
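Taken together, these helpers describe a per-module lifecycle: attach observers and, for attention modules, a quantized kv_cache; calibrate; then freeze. The sketch below is an assumption-level outline only; the actual orchestration, hook registration, and exact signatures (e.g. initialize_observer's arguments) live in the QuantizationMixin and may differ.

import torch

from llmcompressor.modifiers.quantization.calibration import (
    apply_calibration_status,
    freeze_module_quantization,
    initialize_observer,
    initialize_quantized_kv_cache,
    update_weight_zp_scale,
)


def run_calibration(model: torch.nn.Module) -> None:
    for module in model.modules():
        initialize_observer(module, base_name="weight")  # assumed signature
        initialize_quantized_kv_cache(module)            # no-op off attention modules
        apply_calibration_status(module)
        update_weight_zp_scale(module)                   # weights need no calibration data

    # ... run calibration forward passes with activation hooks attached ...

    for module in model.modules():
        freeze_module_quantization(module)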
