 import torch
 from compressed_tensors.quantization import (
     KVCacheScaleType,
+    QuantizationScheme,
     QuantizationStatus,
-    is_attention_module,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme

 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
+from llmcompressor.utils.helpers import getattr_chain

 __all__ = [
     "initialize_observer",
     "calibrate_output_hook",
     "calibrate_kv_cache_input_hook",
     "calibrate_kv_cache_output_hook",
-    "set_unset_kv_cache",
+    "initialize_quantized_kv_cache",
     "freeze_module_quantization",
     "apply_calibration_status",
 ]
@@ -49,10 +50,6 @@ def initialize_observer(
         # no quantization scheme nothing to do
         return

-    # observers have a different lifecycle for kv_cache
-    if is_attention_module(module):
-        return
-
     quantization_args = getattr(quantization_scheme, arg_name, None)
     # dont need observers for dynamic
     if quantization_args is not None and not quantization_args.dynamic:
@@ -102,25 +99,15 @@ def update_weight_zp_scale(module: Module):
     :param quantize_weights_upfront: whether to automatically
         run weight quantization at the start of calibration
     """
-    if not getattr(module, "quantization_scheme", None):
-        # no quantization scheme nothing to do
+    if getattr_chain(module, "quantization_scheme.weights", None) is None:
         return

-    status = getattr(module, "quantization_status", None)
-    if not status:
-        # not set to initialize; no scales/zp to update
-        return
-    if status != QuantizationStatus.INITIALIZED:
+    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
         logger.warning(
-            f"Attempting set module with status {status} to calibration mode. "
-            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
-            "be calibrating an uninitialized module which may fail or attempting "
-            "to re-calibrate a frozen module"
+            "Attempting to calibrate weights of a module not in calibration mode"
         )

-    if module.quantization_scheme.weights is not None:
-        # set weight scale and zero_point up front, calibration data doesn't affect it
-        call_observer(module=module, base_name="weight")
+    call_observer(module=module, base_name="weight")


 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
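Note on the getattr_chain swap above: the helper collapses the old two-step scheme/weights check into a single dotted lookup with a default. Below is a minimal sketch of what such a helper presumably does (illustration only; the real implementation is the one imported from llmcompressor.utils.helpers):

# Illustration only: approximate behavior of a dotted-attribute lookup helper
# such as getattr_chain; not the library implementation.
from typing import Any, Optional


def getattr_chain_sketch(obj: Any, chain: str, default: Optional[Any] = None) -> Any:
    # Walk "a.b.c" one attribute at a time, returning `default` as soon as
    # any link in the chain is missing or None.
    for attr in chain.split("."):
        obj = getattr(obj, attr, None)
        if obj is None:
            return default
    return obj

Used as in the hunk, getattr_chain(module, "quantization_scheme.weights", None) evaluates to None both when the module has no quantization_scheme and when the scheme defines no weight args, so a single check replaces the previous nested getattr logic.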
@@ -200,21 +187,26 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


-def set_unset_kv_cache(module: Module):
+def initialize_quantized_kv_cache(module: Module):
     """
-    Set or unset singleton QuantizedKVParameterCache for each
-    attn module when running kv_cache quantization.
+    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
+    When a config specifying kv_cache quantization is applied to a model, the kv_cache
+    args are redefined as the output_activations targeting attention modules.
+
+    This function should be called on attention modules with output_activations
     """
-    if not hasattr(module, "quantization_scheme"):
+    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
+    existing_kv_cache = getattr(module, "kv_cache", None)
+
+    if (
+        scheme is None
+        or not is_kv_cache_quant_scheme(scheme)
+        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
+    ):
         return

-    if is_kv_cache_quant_scheme(module.quantization_scheme):
-        output_args = module.quantization_scheme.output_activations
-        kv_cache = QuantizedKVParameterCache(output_args)
-        if hasattr(module, "kv_cache"):
-            delattr(module, "kv_cache")
-        else:
-            setattr(module, "kv_cache", kv_cache)
+    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
+    setattr(module, "kv_cache", quantized_kv_cache)


 def apply_calibration_status(module: Module):
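The rewritten initializer is idempotent: the isinstance check on an existing kv_cache makes repeated calls a no-op, replacing the old set/unset toggle. A hypothetical wiring sketch (not part of this diff) of how it might be applied across a model follows; the sweep function name is an assumption, and initialize_quantized_kv_cache is the function defined in the hunk above:

# Hypothetical sketch: attach a quantized kv_cache wherever the applied config
# defines kv_cache (output_activations) quantization. The initializer itself
# no-ops on modules without such a scheme, so calling it on every submodule
# is safe.
import torch.nn as nn


def initialize_kv_caches(model: nn.Module) -> None:
    for submodule in model.modules():
        initialize_quantized_kv_cache(submodule)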
@@ -242,9 +234,15 @@ def freeze_module_quantization(module: Module):
         # nothing to do, already frozen
         return

+    # remove observers
     for name in ("input", "weight", "output"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
             delattr(module, obs_name)

+    # remove quantized kv_cache
+    kv_cache = getattr(module, "kv_cache", None)
+    if isinstance(kv_cache, QuantizedKVParameterCache):
+        delattr(module, "kv_cache")
+
     module.quantization_status = QuantizationStatus.FROZEN
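For context on the cleanup above, a hypothetical end-of-calibration sweep (not part of this diff): freezing deletes the observers and any QuantizedKVParameterCache so only the calibrated scales and zero-points remain attached. The sweep function name is an assumption; freeze_module_quantization is the function defined in the hunk above:

# Hypothetical sketch: finish calibration by freezing every submodule.
# Already-frozen or unquantized modules are skipped by the function itself.
import torch.nn as nn


def finalize_calibration(model: nn.Module) -> None:
    for submodule in model.modules():
        freeze_module_quantization(submodule)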