Commit ce56bac

kylesayrs and shanjiaz authored and committed
Pipeline Extraction (#1279)
## Purpose ##
* Extract data pipelines from modifiers to enable multiple modifiers to be active at the same time
  * This enables faster compression of larger models
  * This enables more memory efficient compression of larger models (not limited to just GPTQ/SGPT)

## Prerequisites ##
* #1351
* #1298

## Callback Changes ##
* Implement `calibration_epoch_start`
  * This callback should be called at the start of every calibration pipeline
  * This callback causes modifiers to attach hooks
* Implement `sequential_epoch_end`
  * This callback should be called after one sequential layer has been calibrated with one epoch
  * This callback triggers compression and replaces passing a `callback_modifier`
* Implement `calibration_epoch_end`
  * This callback triggers at the end of a calibration epoch, and is used to *trigger compression* in between pipelines composed using the independent pipeline and *remove hooks* in between independent pipelines

## Lifecycle Changes ##
* Oneshot modifiers implement `on_end`, which removes hooks when calibration finishes
* In the future, `calibration_epoch_start` is treated like `batch_start`, where it is an opportunity for modifiers to start
* In the future, `calibration_epoch_end` is treated like `batch_end`, where it is an opportunity for modifiers to end
* In the future, `finalize` is treated like `batch_end`, where it is an opportunity for modifiers to end
* Right now, these opportunities are implemented manually on each oneshot modifier, rather than being a lifecycle rule

## Data Pipeline Changes ##
* Implement data pipeline registry
  * Inferred pipeline is selected using modifiers and can be overridden by user
* Implement independent pipeline
  * This pipeline treats each modifier as a separate stage and assigns a pipeline to each modifier
  * Meant to replicate current LC behavior
  * Originally, these compression events were triggered by reaching the end of each module’s initialize function. Now a separate event is required
* Implement `session.get_modifiers`
  * In order to perform data pipeline inference and other sequential pipeline inference, these functions must get the list of active modifiers before they initialize
  * This function gets all the active modifiers across all `ModifierStages`
* Prepare smoothquant for pipeline extraction
  * Trigger `_apply_smoothing` on the `sequential_epoch_end` and `calibration_epoch_end`
  * Add a [guard](https://github.com/vllm-project/llm-compressor/pull/1244/files#diff-90bb5efcbf5f23ba1db62664a91f6b2d6492a909c387cd82c1589f45d5e8615cR285) which allows the `_apply_smoothing` function to be called multiple times per session (as is required by sequential pipeline)

## Testing ##
* Quantized llama3-8b using both the independent (basic + sequential) and sequential pipelines
  * There was no accuracy regression from using a shared pipeline, although we keep the `independent` pipeline as the default for now
* Transformers tests pass
  * https://github.com/neuralmagic/llm-compressor-testing/actions/runs/14622080074

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: shanjiaz <[email protected]>
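The callback ordering described under "Callback Changes" can be illustrated with a small, self-contained sketch. It does not use llm-compressor; `run_sequential_sketch` and the string event names are hypothetical stand-ins, and the real pipelines live in `src/llmcompressor/pipelines/`.

```python
from typing import Callable, List


def run_sequential_sketch(num_layers: int, emit: Callable[[str], None]) -> None:
    """Emit calibration events in the order described in the commit message."""
    emit("calibration_epoch_start")  # modifiers attach their hooks here
    for _ in range(num_layers):
        # ... forward the calibration data through one sequential layer ...
        emit("sequential_epoch_end")  # modifiers compress the calibrated layer
    emit("calibration_epoch_end")  # remaining compression, then hooks are removed


events: List[str] = []
run_sequential_sketch(num_layers=2, emit=events.append)
print(events)
```

In this sketch each layer produces exactly one `sequential_epoch_end`, bracketed by a single start and end event for the whole calibration epoch.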
1 parent ecb8b80 commit ce56bac

32 files changed: +737 -473 lines changed

.github/workflows/test-check-transformers.yaml

Lines changed: 1 addition & 1 deletion
@@ -103,4 +103,4 @@ jobs:
       - name: Running KV Cache Tests
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
-          pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr"
+          pytest -v tests/llmcompressor/transformers/kv_cache

src/llmcompressor/args/dataset_arguments.py

Lines changed: 8 additions & 0 deletions
@@ -171,3 +171,11 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    pipeline: Optional[str] = field(
+        default="independent",
+        metadata={
+            "help": "Calibration pipeline used to calibrate model"
+            "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
+            "independent]"
+        },
+    )
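As a rough illustration of how the new `pipeline` option behaves, the sketch below mirrors the field in a standalone dataclass. `DatasetArgumentsSketch` and `PIPELINE_OPTIONS` are hypothetical names, and the eager validation in `__post_init__` is an assumption added for illustration; the diff above does not show the real class validating the value.

```python
from dataclasses import dataclass, field
from typing import Optional

# Option names copied from the help text in the diff
PIPELINE_OPTIONS = (
    "basic",
    "datafree",
    "sequential",
    "layer_sequential",
    "independent",
)


@dataclass
class DatasetArgumentsSketch:
    pipeline: Optional[str] = field(
        default="independent",
        metadata={"help": "Calibration pipeline used to calibrate model"},
    )

    def __post_init__(self) -> None:
        # Assumption: validation added here for illustration only
        if self.pipeline is not None and self.pipeline not in PIPELINE_OPTIONS:
            raise ValueError(
                f"Unknown pipeline {self.pipeline!r}; expected one of {PIPELINE_OPTIONS}"
            )


default_args = DatasetArgumentsSketch()
sequential_args = DatasetArgumentsSketch(pipeline="sequential")
print(default_args.pipeline, sequential_args.pipeline)
```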

src/llmcompressor/core/events/event.py

Lines changed: 9 additions & 0 deletions
@@ -32,6 +32,10 @@ class EventType(Enum):
     :param BATCH_START: Event type for the start of a batch.
     :param LOSS_CALCULATED: Event type for when loss is calculated.
     :param BATCH_END: Event type for the end of a batch.
+    :param CALIBRATION_EPOCH_START: Event type for the start of a calibration epoch.
+    :param SEQUENTIAL_EPOCH_END: Event type for the end of a layer calibration epoch,
+        specifically used by `src/llmcompressor/pipelines/sequential/pipeline.py`
+    :param CALIBRATION_EPOCH_END: Event type for the end of a calibration epoch.
     :param OPTIM_PRE_STEP: Event type for pre-optimization step.
     :param OPTIM_POST_STEP: Event type for post-optimization step.
     """
@@ -45,6 +49,11 @@ class EventType(Enum):
     LOSS_CALCULATED = "loss_calculated"
     BATCH_END = "batch_end"
 
+    # calibration lifecycle
+    CALIBRATION_EPOCH_START = "calibration_epoch_start"
+    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
+    CALIBRATION_EPOCH_END = "calibration_epoch_end"
+
     # step lifecycle
     OPTIM_PRE_STEP = "optim_pre_step"
     OPTIM_POST_STEP = "optim_post_step"

src/llmcompressor/core/session.py

Lines changed: 11 additions & 0 deletions
@@ -220,6 +220,17 @@ def get_serialized_recipe(self) -> Optional[str]:
 
         logger.warning("Recipe not found in session - it may have been reset")
 
+    def get_modifiers(self):
+        """
+        Get all modifiers across all stages
+        """
+        stage_modifiers = self.lifecycle.modifiers
+        return [
+            modifier
+            for stage_modifier in stage_modifiers
+            for modifier in stage_modifier.modifiers
+        ]  # noqa: E127
+
     def _log_model_info(self):
         # Log model level logs if cadence reached
         current_index = self._lifecycle.global_step
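The nested comprehension in `get_modifiers` flattens a list of stages into one list of modifiers, preserving stage order. A minimal standalone sketch of the same pattern follows; `StageSketch` and the modifier names used as strings are hypothetical placeholders, not the library's stage type.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StageSketch:
    """Hypothetical stand-in for a stage that owns a list of modifiers."""

    modifiers: List[str] = field(default_factory=list)


def get_modifiers_sketch(stages: List[StageSketch]) -> List[str]:
    # Outer loop walks the stages, inner loop walks each stage's modifiers
    return [modifier for stage in stages for modifier in stage.modifiers]


stages = [
    StageSketch(["SmoothQuantModifier"]),
    StageSketch(["GPTQModifier", "SparseGPTModifier"]),
]
flat = get_modifiers_sketch(stages)
print(flat)
```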

src/llmcompressor/core/session_functions.py

Lines changed: 33 additions & 2 deletions
@@ -1,6 +1,6 @@
 import threading
 from contextlib import contextmanager
-from typing import Any, Optional
+from typing import Any, Generator, Optional
 
 from llmcompressor.core.events import EventType
 from llmcompressor.core.session import CompressionSession
@@ -21,7 +21,7 @@
 
 
 @contextmanager
-def create_session() -> CompressionSession:
+def create_session() -> Generator[CompressionSession, None, None]:
     """
     Context manager to create and yield a new session for sparsification.
     This will set the active session to the new session for the duration
@@ -136,5 +136,36 @@ def batch_end(cls, **kwargs) -> ModifiedState:
         active_session()._log_model_info()
         return cls.event(EventType.BATCH_END, **kwargs)
 
+    @classmethod
+    def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a epoch start event for the active session during calibration. This event
+        should be called before calibration starts for one epoch
+
+        see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
+        """
+        return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)
+
+    @classmethod
+    def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a sequential epoch end event for the active session. This event should be
+        called after one sequential layer has been calibrated/trained for one epoch
+
+        This is called after a sequential layer has been calibrated with one batch, see
+        `src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
+        """
+        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
+
+    @classmethod
+    def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a epoch end event for the active session during calibration. This event
+        should be called after the model has been calibrated for one epoch
+
+        see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
+        """
+        return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)
+
 
 callbacks = LifecycleCallbacks
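Each new callback is a thin classmethod wrapper that forwards a fixed event type to a shared `event` dispatcher. The sketch below reproduces that shape without importing llm-compressor; `EventTypeSketch` and `CallbacksSketch` are hypothetical, and the dispatcher here only records what it receives rather than talking to an active session.

```python
from enum import Enum
from typing import Any, Dict, List, Tuple


class EventTypeSketch(Enum):
    CALIBRATION_EPOCH_START = "calibration_epoch_start"
    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
    CALIBRATION_EPOCH_END = "calibration_epoch_end"


class CallbacksSketch:
    # Records every dispatched event; a real session would notify modifiers instead
    received: List[Tuple[EventTypeSketch, Dict[str, Any]]] = []

    @classmethod
    def event(cls, event_type: EventTypeSketch, **kwargs: Any) -> EventTypeSketch:
        cls.received.append((event_type, kwargs))
        return event_type

    @classmethod
    def calibration_epoch_start(cls, **kwargs: Any) -> EventTypeSketch:
        return cls.event(EventTypeSketch.CALIBRATION_EPOCH_START, **kwargs)

    @classmethod
    def sequential_epoch_end(cls, **kwargs: Any) -> EventTypeSketch:
        return cls.event(EventTypeSketch.SEQUENTIAL_EPOCH_END, **kwargs)

    @classmethod
    def calibration_epoch_end(cls, **kwargs: Any) -> EventTypeSketch:
        return cls.event(EventTypeSketch.CALIBRATION_EPOCH_END, **kwargs)
```

Keeping the wrappers this thin means a pipeline only needs to know the callback names, while the dispatcher decides how events reach modifiers.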

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 14 additions & 9 deletions
@@ -7,6 +7,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
+from llmcompressor.pipelines.registry import CalibrationPipeline
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -157,21 +158,25 @@ def apply_recipe_modifiers(
         """
 
         session = active_session()
+        session.reset()
 
-        session_kwargs = dict(
+        # (Helen INFERENG-661): validate recipe modifiers before intialization
+        session.initialize(
             model=self.model,
+            start=-1,
             recipe=self.recipe,
-            recipe_args=self.recipe_args.recipe_args,
-            calib_data=calibration_dataloader,
-            start=-1,  # oneshot-specific argument
-            copy_data=False,
-            min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
             recipe_stage=recipe_stage,
+            recipe_args=self.recipe_args.recipe_args,
+            calib_data=calibration_dataloader,  # only used by AWQModifier, remove once
+            # AWQModifier supports calibration pipelines
         )
 
-        session.reset()
-        session.initialize(**session_kwargs)
-        session.finalize(**session_kwargs)
+        user_pipeline = self.dataset_args.pipeline
+        modifiers = session.get_modifiers()
+        pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
+        pipeline(self.model, calibration_dataloader, self.dataset_args)
+
+        session.finalize()
 
 
 def oneshot(**kwargs) -> PreTrainedModel:
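The commit message states that the inferred pipeline "can be overridden by user". The diff does not show how `CalibrationPipeline.from_modifiers` infers a pipeline, so the sketch below covers only the override step and takes the inferred name as an input. `REGISTRY_SKETCH`, `_make_pipeline`, and `select_pipeline_sketch` are hypothetical names, and the registry entries are toy callables rather than real pipelines.

```python
from typing import Callable, Dict, Optional

PipelineFn = Callable[[str], str]


def _make_pipeline(name: str) -> PipelineFn:
    # Toy pipeline: returns a string describing which pipeline ran on which model
    def run(model: str) -> str:
        return f"{name}({model})"

    return run


REGISTRY_SKETCH: Dict[str, PipelineFn] = {
    name: _make_pipeline(name)
    for name in ("basic", "datafree", "sequential", "layer_sequential", "independent")
}


def select_pipeline_sketch(inferred: str, user: Optional[str] = None) -> PipelineFn:
    """Prefer the user's choice when given, otherwise fall back to the inferred name."""
    name = user if user is not None else inferred
    if name not in REGISTRY_SKETCH:
        raise KeyError(f"Unknown pipeline {name!r}")
    return REGISTRY_SKETCH[name]


print(select_pipeline_sketch("basic")("model"))
print(select_pipeline_sketch("basic", user="sequential")("model"))
```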

src/llmcompressor/modifiers/modifier.py

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):
 
         self.initialized_ = self.on_initialize(state=state, **kwargs)
 
-        # trigger start
+        # trigger starts
         fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
         if self.should_start(fake_start_event):
             self.on_start(state, fake_start_event, **kwargs)
@@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
         :param state: The current state of the model
         :param kwargs: Additional arguments for finalizing the modifier
         """
-        if self.finalized_ or not self.initialized_:
-            return
+        if self.finalized_:
+            raise RuntimeError("cannot finalize a modifier twice")
 
         if not self.initialized_:
             raise RuntimeError("cannot finalize an uninitialized modifier")
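The change to `finalize` replaces a silent early return with explicit errors. A minimal sketch of the resulting guard logic, using a hypothetical `ModifierSketch` class that keeps only the two flags involved:

```python
class ModifierSketch:
    """Hypothetical stand-in that keeps only the flags relevant to finalize."""

    def __init__(self) -> None:
        self.initialized_ = False
        self.finalized_ = False

    def initialize(self) -> None:
        self.initialized_ = True

    def finalize(self) -> None:
        # Both misuse cases now raise instead of returning silently
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")
        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")
        self.finalized_ = True


modifier = ModifierSketch()
modifier.initialize()
modifier.finalize()
print(modifier.finalized_)
```

Callers that previously relied on a repeated `finalize` being a no-op would now see a `RuntimeError`.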

src/llmcompressor/modifiers/obcq/base.py

Lines changed: 17 additions & 7 deletions
@@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
     Lifecycle:
         - on_initialize
             - register_hook(module, calibrate_module, "forward")
-            - run_sequential / run_layer_sequential / run_basic
-                - make_empty_hessian
-                - accumulate_hessian
         - on_sequential_batch_end
             - sparsify_weight
         - on_finalize
@@ -90,6 +87,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the hessian of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            cannonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -108,10 +113,9 @@ def calibrate_module(
                 self._num_samples[module],
             )
 
-    def on_sequential_batch_end(self):
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
@@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
             self._hessians[module] = self._hessians[module].to(device="cpu")
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        self.remove_hooks()
+        # TODO: modify lifecycle to end on finalize
+        if not self.ended_:
+            self.on_end(state, None)  # remove hooks
+
+        if len(self._num_samples) > 0:
+            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
         self._hessians = dict()
         self._num_samples = dict()
         self._module_names = dict()
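The new check in `on_finalize` treats any entry left in `_num_samples` as a module that was calibrated but never compressed. The sketch below shows that bookkeeping in isolation. `SparsifierSketch` is hypothetical, and the assumption that `compress_modules` removes each entry once it has been processed is inferred from the `list(self._num_samples.keys())` iteration and the finalize check, not shown directly in the diff.

```python
from typing import Dict, List


class SparsifierSketch:
    def __init__(self) -> None:
        self._num_samples: Dict[str, int] = {}
        self.compressed: List[str] = []

    def calibrate_module(self, name: str, batch_size: int) -> None:
        # Stand-in for hessian accumulation: only the sample count is tracked
        self._num_samples[name] = self._num_samples.get(name, 0) + batch_size

    def compress_modules(self) -> None:
        # Iterate over a copy of the keys so entries can be deleted in the loop
        for name in list(self._num_samples.keys()):
            self.compressed.append(name)
            del self._num_samples[name]

    def on_finalize(self) -> bool:
        if len(self._num_samples) > 0:
            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
        return True


sparsifier = SparsifierSketch()
sparsifier.calibrate_module("layers.0.q_proj", batch_size=4)
sparsifier.compress_modules()
print(sparsifier.on_finalize(), sparsifier.compressed)
```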

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 38 additions & 53 deletions
@@ -9,13 +9,9 @@
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.pipelines.basic import run_pipeline as run_basic
-from llmcompressor.pipelines.layer_sequential import (
-    run_pipeline as run_layer_sequential,
-)
-from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
@@ -24,7 +20,7 @@
 )
 
 
-class SparsityModifierMixin(HooksMixin):
+class SparsityModifierMixin(Modifier):
     # modifier arguments
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
@@ -42,6 +38,7 @@ class SparsityModifierMixin(HooksMixin):
     _prune_n: Optional[int] = PrivateAttr(default=None)
     _prune_m: Optional[int] = PrivateAttr(default=None)
     _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
     _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
 
     @field_validator("sequential_update", mode="before")
@@ -97,6 +94,10 @@ def calibrate_module(
     ):
         raise NotImplementedError()
 
+    @abstractmethod
+    def compress_modules(self):
+        raise NotImplementedError()
+
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
@@ -109,7 +110,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         # infer module and sequential targets
         self.sequential_targets = self._infer_sequential_targets(model)
         layers = get_layers(self.sequential_targets, model)
-        target_layers = get_layers(self.targets, model)  # layers containing targets
+        self._target_layers = get_layers(
+            self.targets, model
+        )  # layers containing targets
 
         # infer layer sparsities
         if self.sparsity_profile == "owl":
@@ -120,16 +123,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)
 
         # get layers and validate sparsity
-        if isinstance(self.sparsity, (list, dict)) and len(target_layers) != len(
+        if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
             self.sparsity
         ):
             raise ValueError(
                 f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
                 f"sparsities values, but model has {len(layers)} target layers"
             )
 
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
         # register hooks
-        for index, (layer_name, layer) in enumerate(target_layers.items()):
+        for index, (layer_name, layer) in enumerate(self._target_layers.items()):
             if isinstance(self.sparsity, dict):
                 layer_sparsity = self.sparsity[layer_name]
             elif isinstance(self.sparsity, list):
@@ -160,48 +168,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                     self._module_sparsities[module] = layer_sparsity
                     self.register_hook(module, self.calibrate_module, "forward")
 
-        # infer and run pipeline
-        model_name = state.model.__class__.__name__
-        input_names = dataloader.dataset.column_names
-        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
-        try:
-            run_sequential(
-                state.model,
-                state.data.calib,
-                self.sequential_targets,
-                self.ignore,
-                self,
-            )
-            return True
-
-        except Exception as exception:
-            if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
-            if isinstance(exception, unfixable_errors):
-                raise exception
-
-            warnings.warn("Falling back to layer_sequential pipeline")
-            try:
-                run_layer_sequential(
-                    state.model,
-                    state.data.calib,
-                    self.sequential_targets,
-                    self,
-                )
-                return True
-
-            except Exception as exception:
-                if isinstance(exception, TypeError):
-                    warnings.warn(f"{model_name} fails layer-wise assumptions")
-                if isinstance(exception, unfixable_errors):
-                    raise exception
-
-                warnings.warn(
-                    "Falling back to basic pipeline, which requires extra memory and "
-                    "may result in decreased accuracy"
-                )
-                run_basic(state.model, state.data.calib, self)
-                return True
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+        if event.type_ == EventType.CALIBRATION_EPOCH_END:
+            self.compress_modules()
+
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True
+        self.remove_hooks()
 
     def _infer_sequential_targets(
         self, model: torch.nn.Module
@@ -261,6 +244,8 @@ def _infer_owl_layer_sparsity(
         return sparsities
 
     def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
+        from llmcompressor.pipelines.basic import run_calibration
+
         acts = defaultdict(int)
 
         def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
@@ -275,7 +260,7 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
             if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
         )
         with HooksMixin.disable_hooks(keep=hooks):
-            run_basic(model, dataloader)
+            run_calibration(model, dataloader)
         self.remove_hooks(hooks)
 
         return acts
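The event handling added to `SparsityModifierMixin` in this file can be traced with a standalone sketch. `EventModifierSketch` is a hypothetical class that uses plain strings for event types and logs actions instead of registering real hooks; it mirrors only the branching of `on_event`, `on_start`, and `on_end`.

```python
from typing import List


class EventModifierSketch:
    def __init__(self) -> None:
        self.started_ = False
        self.ended_ = False
        self.log: List[str] = []

    def on_start(self) -> None:
        self.started_ = True
        self.log.append("attach_hooks")

    def compress_modules(self) -> None:
        self.log.append("compress")

    def on_end(self) -> None:
        self.ended_ = True
        self.log.append("remove_hooks")

    def on_event(self, event_type: str) -> None:
        if event_type == "calibration_epoch_start":
            if not self.started_:
                self.on_start()

        if event_type == "sequential_epoch_end":
            self.compress_modules()

        if event_type == "calibration_epoch_end":
            self.compress_modules()

            if not self.ended_:
                self.on_end()


sketch = EventModifierSketch()
for name in (
    "calibration_epoch_start",
    "calibration_epoch_start",  # repeated start does not attach hooks twice
    "sequential_epoch_end",
    "calibration_epoch_end",
):
    sketch.on_event(name)
print(sketch.log)
```

The `started_` and `ended_` flags make the start and end transitions idempotent, which matters when several pipelines run back to back in the same session.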