
Commit 3fdbb8d

Extract pipelines
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9aedc8b commit 3fdbb8d

File tree

33 files changed: +778 −827 lines changed

.github/workflows/test-check-transformers.yaml

Lines changed: 0 additions & 4 deletions
@@ -95,7 +95,3 @@ jobs:
       if: (success() || failure()) && steps.install.outcome == 'success'
       run: |
         pytest -v tests/llmcompressor/transformers/obcq
-    - name: Running KV Cache Tests
-      if: (success() || failure()) && steps.install.outcome == 'success'
-      run: |
-        pytest -v tests/llmcompressor/transformers/kv_cache

src/llmcompressor/args/dataset_arguments.py

Lines changed: 9 additions & 0 deletions
@@ -3,6 +3,8 @@
 
 from transformers import DefaultDataCollator
 
+from llmcompressor.pipelines.registry import PIPELINES
+
 
 @dataclass
 class DVCDatasetArguments:
@@ -171,3 +173,10 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    pipeline: Optional[str] = field(
+        default="independent",
+        metadata={
+            "help": "Calibration pipeline used to calibrate model. "
+            f"Options: {PIPELINES.keys()}"
+        },
+    )
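
Downstream, this argument becomes user-selectable at the entrypoint. A hypothetical invocation (the model, dataset, and recipe names are placeholders, not from this diff):

    from llmcompressor import oneshot  # entrypoint defined in entrypoints/oneshot.py

    # "independent" is the new default pipeline name added in this commit
    oneshot(
        model="example/model-stub",   # placeholder model id
        dataset="example_dataset",    # placeholder calibration dataset
        recipe="recipe.yaml",         # placeholder recipe
        pipeline="independent",
    )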

src/llmcompressor/core/events/event.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ class EventType(Enum):
     BATCH_START = "batch_start"
     LOSS_CALCULATED = "loss_calculated"
     BATCH_END = "batch_end"
+    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
+    CALIBRATION_EPOCH_END = "calibration_epoch_end"
 
     # step lifecycle
     OPTIM_PRE_STEP = "optim_pre_step"

src/llmcompressor/core/lifecycle.py

Lines changed: 11 additions & 4 deletions
@@ -77,6 +77,15 @@ def reset(self):
         self.__init__()
         logger.info("Compression lifecycle reset")
 
+    def initialize_recipe(
+        self,
+        recipe: Optional[RecipeInput] = None,
+        recipe_stage: Optional[RecipeStageInput] = None,
+        recipe_args: Optional[RecipeArgsInput] = None,
+    ):
+        self.recipe_container.append(recipe, recipe_stage, recipe_args)
+        self.modifiers = self.recipe_container.get_modifiers()
+
     def initialize(
         self,
         recipe: Optional[RecipeInput] = None,
@@ -92,12 +101,10 @@ def initialize(
         :rtype: List[Any]
         """
         self.state.update(**kwargs)
-        if self.initialized_:  # TODO: do not initialize twice
-            return
 
         logger.debug("Initializing compression lifecycle")
-        self.recipe_container.append(recipe, recipe_stage, recipe_args)
-        self.modifiers = self.recipe_container.get_modifiers()
+        if not (recipe is recipe_stage is recipe_args is None):
+            self.initialize_recipe(recipe, recipe_stage, recipe_args)
         self._set_model_layer_prefix()
 
         mod_data = []
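
This split lets callers seed the recipe without running the full initialize flow. A minimal sketch of the assumed call order, mirroring the oneshot entrypoint change below (the recipe path is a placeholder):

    from llmcompressor.core.session_functions import active_session

    session = active_session()
    session.reset()
    # recipe/modifier setup only; no model state is required at this point
    session.lifecycle.initialize_recipe(recipe="recipe.yaml")  # placeholder recipe
    modifiers = session.get_modifiers()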

src/llmcompressor/core/session.py

Lines changed: 8 additions & 0 deletions
@@ -220,6 +220,14 @@ def get_serialized_recipe(self) -> Optional[str]:
 
         logger.warning("Recipe not found in session - it may have been reset")
 
+    def get_modifiers(self):
+        stage_modifiers = self.lifecycle.modifiers
+        return [
+            modifier
+            for stage_modifier in stage_modifiers
+            for modifier in stage_modifier.modifiers
+        ]  # noqa: E127
+
     def _log_model_info(self):
         # Log model level logs if cadence reached
         current_index = self._lifecycle.global_step

src/llmcompressor/core/session_functions.py

Lines changed: 23 additions & 2 deletions
@@ -1,6 +1,6 @@
 import threading
 from contextlib import contextmanager
-from typing import Any, Optional
+from typing import Any, Generator, Optional
 
 from llmcompressor.core.events import EventType
 from llmcompressor.core.session import CompressionSession
@@ -21,7 +21,7 @@
 
 
 @contextmanager
-def create_session() -> CompressionSession:
+def create_session() -> Generator[CompressionSession, None, None]:
     """
     Context manager to create and yield a new session for sparsification.
     This will set the active session to the new session for the duration
@@ -136,5 +136,26 @@ def batch_end(cls, **kwargs) -> ModifiedState:
         active_session()._log_model_info()
         return cls.event(EventType.BATCH_END, **kwargs)
 
+    @classmethod
+    def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a sequential epoch end event for the active session. This event
+        should be called after one sequential layer has been calibrated/trained
+        for one epoch.
+
+        This is called after a sequential layer has been calibrated with one
+        batch; see `src/llmcompressor/pipelines/sequential/pipeline.py` for a
+        usage example.
+        """
+        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
+
+    @classmethod
+    def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke an epoch end event for the active session during calibration.
+        This event should be called after the model has been calibrated for one
+        epoch; see `src/llmcompressor/pipelines/basic/pipeline.py` for a usage
+        example.
+        """
+        return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)
+
 
 callbacks = LifecycleCallbacks
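
The new docstrings point to the pipeline modules for usage. A minimal sketch of what a basic calibration pipeline could look like under those assumptions (batch handling simplified; not the actual pipeline implementation):

    import torch

    from llmcompressor.core.session_functions import callbacks

    def run_pipeline(model: torch.nn.Module, dataloader) -> None:
        # one calibration epoch: forward passes only, registered hooks do the work
        with torch.no_grad():
            for batch in dataloader:
                model(**batch)
        # let modifiers compress everything calibrated during this epoch
        callbacks.calibration_epoch_end()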

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 10 additions & 10 deletions
@@ -7,6 +7,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
+from llmcompressor.pipelines.registry import get_pipeline_fn
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -157,21 +158,20 @@ def apply_recipe_modifiers(
         """
 
         session = active_session()
+        session.reset()
 
-        session_kwargs = dict(
-            model=self.model,
+        session.lifecycle.state.update(model=self.model, start=-1)
+        session.lifecycle.initialize_recipe(
             recipe=self.recipe,
-            recipe_args=self.recipe_args.recipe_args,
-            calib_data=calibration_dataloader,
-            start=-1,  # oneshot-specific argument
-            copy_data=False,
-            min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
             recipe_stage=recipe_stage,
+            recipe_args=self.recipe_args.recipe_args,
         )
 
-        session.reset()
-        session.initialize(**session_kwargs)
-        session.finalize(**session_kwargs)
+        modifiers = session.get_modifiers()
+        _, pipeline_fn = get_pipeline_fn(self.dataset_args.pipeline, modifiers)
+        pipeline_fn(self.model, calibration_dataloader)
+
+        session.finalize()
 
 
 def oneshot(**kwargs) -> PreTrainedModel:
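
The registry module itself is not shown in this commit excerpt. A minimal sketch of how a name-based `get_pipeline_fn` could resolve a pipeline (the internals below are assumptions, not taken from the diff):

    from typing import Callable, Dict, List, Tuple

    # assumed registry shape; the real mapping lives in
    # src/llmcompressor/pipelines/registry.py (not shown here)
    PIPELINES: Dict[str, Callable] = {}

    def get_pipeline_fn(user: str, modifiers: List) -> Tuple[str, Callable]:
        # "independent" presumably inspects `modifiers` to choose a concrete
        # pipeline per modifier; that inference is elided in this sketch
        if user not in PIPELINES:
            raise ValueError(
                f"unknown pipeline {user!r}, options: {list(PIPELINES)}"
            )
        return user, PIPELINES[user]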

src/llmcompressor/modifiers/modifier.py

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):
 
         self.initialized_ = self.on_initialize(state=state, **kwargs)
 
-        # trigger start
+        # trigger starts
        fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
         if self.should_start(fake_start_event):
             self.on_start(state, fake_start_event, **kwargs)
@@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
         :param state: The current state of the model
         :param kwargs: Additional arguments for finalizing the modifier
         """
-        if self.finalized_ or not self.initialized_:
-            return
+        if self.finalized_:
+            raise RuntimeError("cannot finalize a modifier twice")
 
         if not self.initialized_:
             raise RuntimeError("cannot finalize an uninitialized modifier")

src/llmcompressor/modifiers/obcq/base.py

Lines changed: 17 additions & 7 deletions
@@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
     Lifecycle:
         - on_initialize
             - register_hook(module, calibrate_module, "forward")
-            - run_sequential / run_layer_sequential / run_basic
-                - make_empty_hessian
-                - accumulate_hessian
         - on_sequential_batch_end
             - sparsify_weight
         - on_finalize
@@ -90,6 +87,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the hessian of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -108,10 +113,9 @@ def calibrate_module(
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
@@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
             self._hessians[module] = self._hessians[module].to(device="cpu")
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        self.remove_hooks()
+        # TODO: modify lifecycle to end on finalize
+        if not self.ended_:
+            self.on_end(state, None)  # remove hooks
+
+        if len(self._num_samples) > 0:
+            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
         self._hessians = dict()
         self._num_samples = dict()
         self._module_names = dict()

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 23 additions & 48 deletions
@@ -9,13 +9,10 @@
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.basic import run_pipeline as run_basic
-from llmcompressor.pipelines.layer_sequential import (
-    run_pipeline as run_layer_sequential,
-)
-from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
@@ -24,7 +21,7 @@
 )
 
 
-class SparsityModifierMixin(HooksMixin):
+class SparsityModifierMixin(Modifier):
     # modifier arguments
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
@@ -97,6 +94,10 @@ def calibrate_module(
     ):
         raise NotImplementedError()
 
+    @abstractmethod
+    def compress_modules(self):
+        raise NotImplementedError()
+
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
@@ -160,48 +161,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             self._module_sparsities[module] = layer_sparsity
             self.register_hook(module, self.calibrate_module, "forward")
 
-        # infer and run pipeline
-        model_name = state.model.__class__.__name__
-        input_names = dataloader.dataset.column_names
-        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
-        try:
-            run_sequential(
-                state.model,
-                state.data.calib,
-                self.sequential_targets,
-                self.ignore,
-                self,
-            )
-            return True
-
-        except Exception as exception:
-            if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
-            if isinstance(exception, unfixable_errors):
-                raise exception
-
-            warnings.warn("Falling back to layer_sequential pipeline")
-            try:
-                run_layer_sequential(
-                    state.model,
-                    state.data.calib,
-                    self.sequential_targets,
-                    self,
-                )
-                return True
-
-            except Exception as exception:
-                if isinstance(exception, TypeError):
-                    warnings.warn(f"{model_name} fails layer-wise assumptions")
-                if isinstance(exception, unfixable_errors):
-                    raise exception
-
-                warnings.warn(
-                    "Falling back to basic pipeline, which requires extra memory and "
-                    "may result in decreased accuracy"
-                )
-                run_basic(state.model, state.data.calib, self)
-                return True
+        return True
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+        if event.type_ == EventType.CALIBRATION_EPOCH_END:
+            self.compress_modules()
+
+            # TODO: modify lifecycle to end on calibration epoch end
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True  # TODO: move to super call
+        self.remove_hooks()
 
     def _infer_sequential_targets(
         self, model: torch.nn.Module
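
With the fallback logic removed from the modifier, calibration becomes event-driven: pipelines fire epoch-end events and the modifier compresses whatever its hooks calibrated. A minimal sketch of a sequential pipeline under these assumptions (`subgraphs` and its call shape are hypothetical placeholders for the real tracing machinery):

    import torch

    from llmcompressor.core.session_functions import callbacks

    def run_sequential_pipeline(model, dataloader, subgraphs):
        # subgraphs: hypothetical list of per-layer callables produced by tracing
        batches = [batch for batch in dataloader]
        with torch.no_grad():
            for subgraph in subgraphs:
                # forward hooks (e.g. SparseGPTModifier.calibrate_module)
                # accumulate hessians during these passes
                batches = [subgraph(**inputs) for inputs in batches]
                # triggers SEQUENTIAL_EPOCH_END -> compress_modules()
                callbacks.sequential_epoch_end()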
