@@ -7,17 +7,22 @@
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
+    disable_quantization,
+    enable_quantization,
     is_attention_module,
     is_preset_scheme,
     preset_name_to_scheme,
 )
-from pydantic import Field, field_validator
+from pydantic import Field, PrivateAttr, field_validator
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.calibration import (
+    apply_calibration_status,
     calibrate_input_hook,
     calibrate_kv_cache_input_hook,
     calibrate_kv_cache_output_hook,
     calibrate_output_hook,
+    freeze_module_quantization,
     initialize_observer,
     initialize_quantized_kv_cache,
     reset_quantization_status,
@@ -33,18 +38,18 @@ class QuantizationMixin(HooksMixin):
     calibration hooks, and compression wrappers to modifiers
 
     Lifecycle:
-        - QuantizationMixin.attach_scheme_and_observers(model)
-            - Wraps model forward and attaches quantization scheme and observers
-        - QuantizationMixin.register_calibration_hooks(model)
-            - Registers calibration hooks which utilize observers to calibrate qparams
-        - model.apply(apply_calibration_status)
-        - [ Calibrate model ]
-        - model.apply(freeze_module_quantization)
-            - Remove observers
-        - self.remove_hooks()
+        - on_initialize: QuantizationMixin.initialize_quantization
+            - Attach schemes to modules
+            - Attach observers to modules
+            - Disable quantization until calibration starts/finishes
+        - on_start: QuantizationMixin.start_calibration
+            - Attach calibration hooks
+            - Apply calibration status
+            - Enable quantization during calibration
+        - on_end: QuantizationMixin.end_calibration
             - Remove calibration hooks
-
-    Scheme is left attached to modules after PTQ finishes
+            - Apply freeze status
+            - Keep quantization enabled for future steps
 
     :param config_groups: dictionary specifying quantization schemes to apply to target
         modules. Modules not matching a scheme target will NOT be quantized.
@@ -76,6 +81,8 @@ class QuantizationMixin(HooksMixin):
     scheme: Optional[Union[str, Dict[str, Any]]] = None
     kv_cache_scheme: Optional[QuantizationArgs] = None
 
+    _calibration_hooks: List[RemovableHandle] = PrivateAttr(default_factory=list)
+
     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
         if isinstance(value, str):
@@ -102,25 +109,49 @@ def validate_scheme(
 
         return value
 
-    def attach_scheme_and_observers(self, model: torch.nn.Module):
+    def initialize_quantization(self, model: torch.nn.Module):
         """
-        Apply this modifier as a quantization config to the model. Attach observers
-        according to the schemes attached to each module
+        Attach quantization schemes and observers to modules in the model according to
+        the quantization config specified on this modifier
+
+        :param model: model to attach schemes and observers to
         """
         reset_quantization_status(model)  # reset any previously applied qconfigs
 
+        # apply scheme and status to model
         config = self.resolve_quantization_config()
         apply_quantization_config(model, config)
 
+        # apply observers, disable quantization until calibration
         model.apply(self._initialize_observers)
+        model.apply(disable_quantization)
+
+    def start_calibration(self, model: torch.nn.Module):
+        """
+        Register activation calibration hooks (including kv_cache quantization) and
+        enable quantization as we calibrate
+
+        :param model: model to prepare for calibration
+        """
+        self._calibration_hooks = self._initialize_hooks(model)
+        model.apply(apply_calibration_status)
+        model.apply(enable_quantization)  # quantize at the same time as calibrate
 
-    def register_calibration_hooks(self, model: torch.nn.Module):
+    def end_calibration(self, model: torch.nn.Module):
         """
-        Register activation calibration hooks (including kv_cache quantization)
+        Remove calibration hooks and set the model status to frozen. Keep quantization
+        enabled for future operations
+
+        :param model: model to end calibration for
         """
-        model.apply(self._initialize_hooks)
+        self.remove_hooks(self._calibration_hooks)
+        model.apply(freeze_module_quantization)  # remove observers
+        model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
+        """
+        Determine if the user has specified a quantization config on this modifier
+        """
         return not (
             self.config_groups is None
             and self.targets == ["Linear"]
@@ -199,27 +230,44 @@ def _initialize_observers(self, module: torch.nn.Module):
         elif output:
             initialize_observer(module, base_name="output")
 
-    def _initialize_hooks(self, module: torch.nn.Module):
-        if not hasattr(module, "quantization_scheme"):
-            return
-
-        scheme: QuantizationScheme = module.quantization_scheme
-        input = scheme.input_activations and not scheme.input_activations.dynamic
-        output = scheme.output_activations and not scheme.output_activations.dynamic
-        is_attention = is_attention_module(module)
-
-        # input activations
-        if input:
-            self.register_hook(module, calibrate_input_hook, "forward_pre")
+    def _initialize_hooks(self, model: torch.nn.Module) -> List[RemovableHandle]:
+        hooks = []
+        for module in model.modules():
+            if not hasattr(module, "quantization_scheme"):
+                continue
+
+            scheme: QuantizationScheme = module.quantization_scheme
+            input = scheme.input_activations and not scheme.input_activations.dynamic
+            output = scheme.output_activations and not scheme.output_activations.dynamic
+            is_attention = is_attention_module(module)
+
+            # input activations
+            if input:
+                hooks.append(
+                    self.register_hook(module, calibrate_input_hook, "forward_pre")
+                )
+
+            # kv_cache activations. Within `apply_quantization_config`, the config is
+            # modified to use attention output quantization if a kv_cache_scheme exists
+            if is_attention and output:
+                hooks.append(
+                    self.register_hook(
+                        module,
+                        calibrate_kv_cache_input_hook,
+                        "forward_pre",
+                        with_kwargs=True,
+                    )
+                )
+                hooks.append(
+                    self.register_hook(
+                        module, calibrate_kv_cache_output_hook, "forward"
+                    )
+                )
 
-        # kv_cache activations. Within `apply_quantization_config`, the config is
-        # modified to use attention output quantization if a kv_cache_scheme exists
-        if is_attention and output:
-            self.register_hook(
-                module, calibrate_kv_cache_input_hook, "forward_pre", with_kwargs=True
-            )
-            self.register_hook(module, calibrate_kv_cache_output_hook, "forward")
+            # output activations
+            elif output:
+                hooks.append(
+                    self.register_hook(module, calibrate_output_hook, "forward")
+                )
 
-        # output activations
-        elif output:
-            self.register_hook(module, calibrate_output_hook, "forward")
+        return hooks
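For reference, a minimal sketch of the three-step lifecycle described in the updated docstring, calling the mixin methods directly on a toy model. The `ToyQuantConfig` subclass, the exact import path, and the `"W8A8"` preset name are assumptions for illustration; only `initialize_quantization`, `start_calibration`, `end_calibration`, and the `targets`/`scheme` fields come from this diff.

```python
# Hypothetical usage sketch (not part of this PR): drive the new lifecycle
# directly. The subclass, import path, and "W8A8" preset are assumptions.
import torch

from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin


class ToyQuantConfig(QuantizationMixin):
    """Bare mixin subclass used only to exercise the calibration lifecycle."""


model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
mixin = ToyQuantConfig(targets=["Linear"], scheme="W8A8")

# on_initialize: attach schemes and observers, quantization stays disabled
mixin.initialize_quantization(model)

# on_start: register calibration hooks, apply calibration status, enable quant
mixin.start_calibration(model)

with torch.no_grad():  # [ calibrate ]
    for _ in range(8):
        model(torch.randn(4, 16))

# on_end: remove calibration hooks, freeze observers, keep quantization enabled
mixin.end_calibration(model)
```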