|
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator

-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.basic import run_pipeline as run_basic
-from llmcompressor.pipelines.layer_sequential import (
-    run_pipeline as run_layer_sequential,
-)
-from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
|
|
 )


-class SparsityModifierMixin(HooksMixin):
+class SparsityModifierMixin(Modifier):
     # modifier arguments
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
|
@@ -97,6 +94,10 @@ def calibrate_module(
     ):
         raise NotImplementedError()

+    @abstractmethod
+    def compress_modules(self):
+        raise NotImplementedError()
+
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
|
@@ -160,48 +161,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 self._module_sparsities[module] = layer_sparsity
                 self.register_hook(module, self.calibrate_module, "forward")

-        # infer and run pipeline
-        model_name = state.model.__class__.__name__
-        input_names = dataloader.dataset.column_names
-        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
-        try:
-            run_sequential(
-                state.model,
-                state.data.calib,
-                self.sequential_targets,
-                self.ignore,
-                self,
-            )
-            return True
-
-        except Exception as exception:
-            if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
-            if isinstance(exception, unfixable_errors):
-                raise exception
-
-            warnings.warn("Falling back to layer_sequential pipeline")
-            try:
-                run_layer_sequential(
-                    state.model,
-                    state.data.calib,
-                    self.sequential_targets,
-                    self,
-                )
-                return True
-
-            except Exception as exception:
-                if isinstance(exception, TypeError):
-                    warnings.warn(f"{model_name} fails layer-wise assumptions")
-                if isinstance(exception, unfixable_errors):
-                    raise exception
-
-                warnings.warn(
-                    "Falling back to basic pipeline, which requires extra memory and "
-                    "may result in decreased accuracy"
-                )
-                run_basic(state.model, state.data.calib, self)
-                return True
+        return True
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+        if event.type_ == EventType.CALIBRATION_EPOCH_END:
+            self.compress_modules()
+
+            # TODO: modify lifecycle to end on calibration epoch end
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True  # TODO: move to super call
+        self.remove_hooks()

     def _infer_sequential_targets(
         self, model: torch.nn.Module
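
After this change the mixin no longer invokes the sequential / layer_sequential / basic pipelines itself; compression is triggered by lifecycle events (SEQUENTIAL_EPOCH_END, CALIBRATION_EPOCH_END) emitted by whichever pipeline the session runs. A minimal sketch of a concrete subclass under these assumptions is below; `MagnitudeSparsityExample` and its magnitude-pruning logic are illustrative placeholders, not part of this PR or of llmcompressor, and the forward-hook signature for `calibrate_module` is assumed.

```python
# Sketch only (not part of this PR): assumes SparsityModifierMixin is importable
# from the module shown in this diff and that `calibrate_module` is called as a
# standard PyTorch forward hook (module, args, output).
class MagnitudeSparsityExample(SparsityModifierMixin):
    def calibrate_module(self, module, args, output):
        # magnitude pruning needs no activation statistics; a SparseGPT/Wanda-style
        # subclass would accumulate per-module statistics here instead
        pass

    def compress_modules(self):
        # runs when on_event sees SEQUENTIAL_EPOCH_END / CALIBRATION_EPOCH_END:
        # apply the per-module sparsity recorded during on_initialize
        for module, sparsity in self._module_sparsities.items():
            weight = module.weight.data
            k = int(weight.numel() * sparsity)
            if k == 0:
                continue
            # zero out (roughly) the k smallest-magnitude weights
            threshold = weight.abs().flatten().kthvalue(k).values
            weight.mul_((weight.abs() > threshold).to(weight.dtype))
```

The point of the sketch is only that subclasses now express the what (`calibrate_module`, `compress_modules`) while the pipeline's event stream decides the when.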
|
|