Commit e4debea

change QuantizationModifier functions
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 3935734

4 files changed: +105, -65 lines
src/llmcompressor/modifiers/modifier.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):
 
         self.initialized_ = self.on_initialize(state=state, **kwargs)
 
-        # trigger starts
+        # trigger start
         fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
         if self.should_start(fake_start_event):
             self.on_start(state, fake_start_event, **kwargs)
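
For orientation, the hunk above sits in the generic Modifier lifecycle that drives the quantization changes below: initialize() fires a synthetic BATCH_START event at global_step 0 so that a modifier whose start condition is already met (e.g. a one-shot quantization modifier) calls on_start() immediately. A minimal, self-contained sketch of that trigger pattern follows; the SketchModifier class and its simplified should_start logic are illustrative assumptions, not the library's implementation.

from dataclasses import dataclass


@dataclass
class FakeStartEvent:
    # stand-in for Event(type_=EventType.BATCH_START, global_step=0)
    type_: str = "BATCH_START"
    global_step: int = 0


class SketchModifier:
    # hypothetical modifier; only the trigger-on-initialize pattern mirrors the diff
    start: float = -1  # assumed convention: -1 means "apply one-shot"

    def should_start(self, event: FakeStartEvent) -> bool:
        # simplified assumption: one-shot modifiers start immediately,
        # otherwise wait until the global step reaches `start`
        return self.start == -1 or event.global_step >= self.start

    def on_start(self) -> None:
        print("calibration started")

    def initialize(self) -> None:
        # mirrors the hunk above: trigger start during initialization if already due
        if self.should_start(FakeStartEvent()):
            self.on_start()


SketchModifier().initialize()  # prints "calibration started"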

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 10 additions & 11 deletions

@@ -3,6 +3,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+from compressed_tensors.quantization import disable_quantization
 from compressed_tensors.utils import (
     align_module_device,
     get_execution_device,
@@ -14,10 +15,6 @@
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.quantization.calibration import (
-    apply_calibration_status,
-    freeze_module_quantization,
-)
 from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
     accumulate_hessian,
     make_empty_hessian,
@@ -139,8 +136,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         """
         # apply config to model and prepare calibration hooks
         if QuantizationMixin.has_config(self):
-            QuantizationMixin.attach_scheme_and_observers(self, state.model)
-            QuantizationMixin.register_calibration_hooks(self, state.model)
+            QuantizationMixin.initialize_quantization(self, state.model)
+
+        # assume quantization has been initialized by this modifier or one before it
+        QuantizationMixin.start_calibration(self, state.model)
+        # Unlike qmod, do not quantize as we calibrate
+        # This choice does not seem to have a meaningful impact on accuracy
+        state.model.apply(disable_quantization)
 
         # prepare module names
         self._module_names = {m: name for name, m in state.model.named_modules()}
@@ -162,9 +164,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
                 "modifier or a modifier preceding it"
             )
 
-        # prepare for calibration
-        state.model.apply(apply_calibration_status)
-
         # infer sequential targets
         if self.sequential_targets is None:
             self.sequential_targets = get_no_split_params(state.model)
@@ -233,8 +232,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         self._hessians = dict()
         self._num_samples = dict()
 
-        state.model.apply(freeze_module_quantization)  # remove observers
-        self.remove_hooks()  # remove hooks
+        QuantizationMixin.end_calibration(self, state.model)
+        self.remove_hooks()  # remove gptq hooks
 
         return True
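
With this change, GPTQModifier delegates scheme/observer setup and calibration-hook management to the shared QuantizationMixin entry points (initialize_quantization, start_calibration, end_calibration) and, unlike QuantizationModifier ("qmod"), keeps quantization disabled while it calibrates, so Hessian statistics are accumulated from unquantized activations; per the inline comment, this choice does not appear to meaningfully affect accuracy. A hedged usage sketch of the path this file touches; the oneshot entrypoint, model, dataset, and argument names follow llm-compressor's published examples rather than this commit and may differ by version.

# Calibration-data PTQ through GPTQModifier (sketch; names assumed from examples)
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# one-shot recipe: 4-bit weight quantization of all Linear layers except lm_head
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)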

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 6 additions & 13 deletions

@@ -1,15 +1,10 @@
 import torch
 import tqdm
-from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 
 from llmcompressor.core import Event, State
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.quantization.calibration import (
-    apply_calibration_status,
-    freeze_module_quantization,
-    update_weight_zp_scale,
-)
+from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.helpers import calibration_forward_context
@@ -64,8 +59,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
                 "QuantizationModifier requires that quantization fields to be specified"
             )
 
-        QuantizationMixin.attach_scheme_and_observers(self, state.model)
-        state.model.apply(disable_quantization)  # disable quantization until start
+        QuantizationMixin.initialize_quantization(self, state.model)
 
         # FUTURE: modify oneshot lifecycle to trigger on_start for on initialize
         if self.calculate_start() == -1:  # one shot
@@ -77,9 +71,7 @@ def on_start(self, state: State):
         """
         Begin calibrating activations and weights. Calibrate weights only once on start
         """
-        QuantizationMixin.register_calibration_hooks(self, state.model)
-        state.model.apply(apply_calibration_status)
-        state.model.apply(enable_quantization)
+        QuantizationMixin.start_calibration(self, state.model)
 
         modules = list(state.model.modules())
         for module in tqdm.tqdm(modules, desc="Calibrating weights"):
@@ -93,8 +85,9 @@ def on_end(self, state: State, event: Event, **kwargs):
         """
         Finish calibrating by removing observers and calibration hooks
        """
-        state.model.apply(freeze_module_quantization)  # remove observers
-        self.remove_hooks()  # remove hooks
+        QuantizationMixin.end_calibration(
+            self, state.model
+        )  # keep quantization enabled
 
     def on_finalize(self, state: State, **kwargs) -> bool:
         # TODO: modify lifecycle so modifiers end on finalize
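
QuantizationModifier's hooks now map one-to-one onto the mixin lifecycle: on_initialize calls initialize_quantization, on_start calls start_calibration (hooks, calibration status, quantization enabled), and on_end calls end_calibration, which removes the hooks, freezes qparams, and keeps quantization enabled. In contrast to the GPTQ sketch above, QuantizationModifier quantizes while it calibrates, and with a dynamic activation scheme it needs no calibration data at all. A hedged usage sketch, again borrowing the entrypoint and model name from llm-compressor's examples rather than from this diff.

# Data-free PTQ through QuantizationModifier (sketch; names assumed from examples)
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

# FP8 weights with dynamic per-token activation scales: no calibration dataset needed
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
)

oneshot(model="meta-llama/Meta-Llama-3-8B-Instruct", recipe=recipe)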

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 88 additions & 40 deletions

@@ -7,17 +7,22 @@
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
+    disable_quantization,
+    enable_quantization,
     is_attention_module,
     is_preset_scheme,
     preset_name_to_scheme,
 )
-from pydantic import Field, field_validator
+from pydantic import Field, PrivateAttr, field_validator
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.calibration import (
+    apply_calibration_status,
     calibrate_input_hook,
     calibrate_kv_cache_input_hook,
     calibrate_kv_cache_output_hook,
     calibrate_output_hook,
+    freeze_module_quantization,
     initialize_observer,
     initialize_quantized_kv_cache,
     reset_quantization_status,
@@ -33,18 +38,18 @@ class QuantizationMixin(HooksMixin):
     calibration hooks, and compression wrappers to modifiers
 
     Lifecycle:
-        - QuantizationMixin.attach_scheme_and_observers(model)
-            - Wraps model forward and attaches quantization scheme and observers
-        - QuantizationMixin.register_calibration_hooks(model)
-            - Registers calibration hooks which utilize observers to calibrate qparams
-        - model.apply(apply_calibration_status)
-        - [ Calibrate model ]
-        - model.apply(freeze_module_quantization)
-            - Remove observers
-        - self.remove_hooks()
+        - on_initialize: QuantizationMixin.initialize_quantization
+            - Attach schemes to modules
+            - Attach observers to modules
+            - Disable quantization until calibration starts/finishes
+        - on_start: QuantizationMixin.start_calibration
+            - Attach calibration hooks
+            - Apply calibration status
+            - Enable quantization during calibration
+        - on_end: QuantizationMixin.end_calibration
             - Remove calibration hooks
-
-    Scheme is left attached to modules after PTQ finishes
+            - Apply freeze status
+            - Keep quantization enabled for future steps
 
     :param config_groups: dictionary specifying quantization schemes to apply to target
         modules. Modules not matching a scheme target will NOT be quantized.
@@ -76,6 +81,8 @@ class QuantizationMixin(HooksMixin):
     scheme: Optional[Union[str, Dict[str, Any]]] = None
     kv_cache_scheme: Optional[QuantizationArgs] = None
 
+    _calibration_hooks: List[RemovableHandle] = PrivateAttr(default_factory=list)
+
     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
         if isinstance(value, str):
@@ -102,25 +109,49 @@ def validate_scheme(
 
         return value
 
-    def attach_scheme_and_observers(self, model: torch.nn.Module):
+    def initialize_quantization(self, model: torch.nn.Module):
         """
-        Apply this modifier as a quantization config to the model. Attach observers
-        according to the schemes attached to each module
+        Attach quantization schemes and observers to modules in the model according to
+        the quantization config specified on this modifier
+
+        :param model: model to attach schemes and observers to
         """
         reset_quantization_status(model)  # reset any previously applied qconfigs
 
+        # apply scheme and status to model
         config = self.resolve_quantization_config()
         apply_quantization_config(model, config)
 
+        # apply observers, disable quantization until calibration
         model.apply(self._initialize_observers)
+        model.apply(disable_quantization)
+
+    def start_calibration(self, model: torch.nn.Module):
+        """
+        Register activation calibration hooks (including kv_cache quantization) and
+        enable quantization as we calibrate
+
+        :param model: model to prepare for calibration
+        """
+        self._calibration_hooks = self._initialize_hooks(model)
+        model.apply(apply_calibration_status)
+        model.apply(enable_quantization)  # quantize at the same time as calibrate
 
-    def register_calibration_hooks(self, model: torch.nn.Module):
+    def end_calibration(self, model: torch.nn.Module):
         """
-        Register activation calibration hooks (including kv_cache quantization)
+        Remove calibration hooks and set the model status to frozen. Keep quantization
+        enabled for future operations
+
+        :param model: model to end calibration for
         """
-        model.apply(self._initialize_hooks)
+        self.remove_hooks(self._calibration_hooks)
+        model.apply(freeze_module_quantization)  # remove observers
+        model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
+        """
+        Determine if the user has specified a quantization config on this modifier
+        """
         return not (
             self.config_groups is None
             and self.targets == ["Linear"]
@@ -199,27 +230,44 @@ def _initialize_observers(self, module: torch.nn.Module):
         elif output:
             initialize_observer(module, base_name="output")
 
-    def _initialize_hooks(self, module: torch.nn.Module):
-        if not hasattr(module, "quantization_scheme"):
-            return
-
-        scheme: QuantizationScheme = module.quantization_scheme
-        input = scheme.input_activations and not scheme.input_activations.dynamic
-        output = scheme.output_activations and not scheme.output_activations.dynamic
-        is_attention = is_attention_module(module)
-
-        # input activations
-        if input:
-            self.register_hook(module, calibrate_input_hook, "forward_pre")
+    def _initialize_hooks(self, model: torch.nn.Module) -> List[RemovableHandle]:
+        hooks = []
+        for module in model.modules():
+            if not hasattr(module, "quantization_scheme"):
+                continue
+
+            scheme: QuantizationScheme = module.quantization_scheme
+            input = scheme.input_activations and not scheme.input_activations.dynamic
+            output = scheme.output_activations and not scheme.output_activations.dynamic
+            is_attention = is_attention_module(module)
+
+            # input activations
+            if input:
+                hooks.append(
+                    self.register_hook(module, calibrate_input_hook, "forward_pre")
+                )
+
+            # kv_cache activations. Within `apply_quantization_config`, the config is
+            # modified to use attention output quantization if a kv_cache_scheme exists
+            if is_attention and output:
+                hooks.append(
+                    self.register_hook(
+                        module,
+                        calibrate_kv_cache_input_hook,
+                        "forward_pre",
+                        with_kwargs=True,
+                    )
+                )
+                hooks.append(
+                    self.register_hook(
+                        module, calibrate_kv_cache_output_hook, "forward"
                    )
+                )
 
-        # kv_cache activations. Within `apply_quantization_config`, the config is
-        # modified to use attention output quantization if a kv_cache_scheme exists
-        if is_attention and output:
-            self.register_hook(
-                module, calibrate_kv_cache_input_hook, "forward_pre", with_kwargs=True
-            )
-            self.register_hook(module, calibrate_kv_cache_output_hook, "forward")
+            # output activations
+            elif output:
+                hooks.append(
+                    self.register_hook(module, calibrate_output_hook, "forward")
                )
 
-        # output activations
-        elif output:
-            self.register_hook(module, calibrate_output_hook, "forward")
+        return hooks
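
Taken together, the mixin now exposes a three-step calibration API, and _initialize_hooks returns the RemovableHandle list that end_calibration later removes, so a modifier's own hooks (such as GPTQ's Hessian hooks) are left untouched. A minimal sketch of how a modifier is expected to drive the renamed lifecycle follows; the class and method signatures below are illustrative assumptions, while the three mixin calls and their ordering come straight from the diff.

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin


class SketchQuantizingModifier(Modifier, QuantizationMixin):
    # hypothetical modifier showing the intended call order of the new mixin API

    def on_initialize(self, state: State, **kwargs) -> bool:
        # attach schemes and observers; quantization stays disabled until calibration
        if QuantizationMixin.has_config(self):
            QuantizationMixin.initialize_quantization(self, state.model)
        return True

    def on_start(self, state: State, *args, **kwargs):
        # register calibration hooks, apply calibration status, enable quantization
        QuantizationMixin.start_calibration(self, state.model)

    def on_end(self, state: State, *args, **kwargs):
        # remove the mixin's hooks, freeze qparams, keep quantization enabled
        QuantizationMixin.end_calibration(self, state.model)

Storing the handles in a pydantic PrivateAttr (_calibration_hooks) and passing them to remove_hooks is what lets end_calibration target only the calibration hooks rather than every hook registered through HooksMixin.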
