Commit 46f6811

Pipeline extraction
Signed-off-by: Kyle Sayers <[email protected]>
1 parent ebd6ae9 commit 46f6811

Showing 23 changed files with 374 additions and 291 deletions.

src/llmcompressor/args/dataset_arguments.py

Lines changed: 9 additions & 0 deletions
@@ -3,6 +3,8 @@
 
 from transformers import DefaultDataCollator
 
+from llmcompressor.pipelines.registry import PIPELINES
+
 
 @dataclass
 class DVCDatasetArguments:
@@ -171,3 +173,10 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    pipeline: Optional[str] = field(
+        default="independent",
+        metadata={
+            "help": "Calibration pipeline used to calibrate model. "
+            f"Options: {PIPELINES.keys()}"
+        },
+    )
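The new `pipeline` field surfaces calibration-pipeline selection through `DatasetArguments`, defaulting to `"independent"`. A minimal sketch of how a caller might override it through the oneshot entrypoint shown later in this commit; the model id, dataset name, and recipe path are placeholders, and forwarding `pipeline` as a keyword is assumed from the way `oneshot(**kwargs)` parses its arguments into `DatasetArguments`:

    from llmcompressor.entrypoints.oneshot import oneshot

    # Hypothetical invocation: placeholder model/dataset/recipe values
    oneshot(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        dataset="open_platypus",
        recipe="recipe.yaml",
        pipeline="sequential",  # overrides the "independent" default
    )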

src/llmcompressor/core/events/event.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ class EventType(Enum):
     BATCH_START = "batch_start"
     LOSS_CALCULATED = "loss_calculated"
     BATCH_END = "batch_end"
+    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
+    CALIBRATION_EPOCH_END = "calibration_epoch_end"
 
     # step lifecycle
     OPTIM_PRE_STEP = "optim_pre_step"

src/llmcompressor/core/lifecycle.py

Lines changed: 11 additions & 4 deletions
@@ -77,6 +77,16 @@ def reset(self):
         self.__init__()
         logger.info("Compression lifecycle reset")
 
+    def initialize_recipe(
+        self,
+        recipe: Optional[RecipeInput] = None,
+        recipe_stage: Optional[RecipeStageInput] = None,
+        recipe_args: Optional[RecipeArgsInput] = None,
+    ):
+        if len(self.modifiers) <= 0:
+            self.recipe_container.append(recipe, recipe_stage, recipe_args)
+            self.modifiers = self.recipe_container.get_modifiers()
+
     def initialize(
         self,
         recipe: Optional[RecipeInput] = None,
@@ -92,12 +102,9 @@ def initialize(
         :rtype: List[Any]
         """
         self.state.update(**kwargs)
-        if self.initialized_:  # TODO: do not initialize twice
-            return
 
         logger.debug("Initializing compression lifecycle")
-        self.recipe_container.append(recipe, recipe_stage, recipe_args)
-        self.modifiers = self.recipe_container.get_modifiers()
+        self.initialize_recipe(recipe, recipe_stage, recipe_args)
         self._set_model_layer_prefix()
 
         mod_data = []
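Extracting `initialize_recipe` and guarding it on `self.modifiers` makes recipe initialization idempotent: the recipe is appended only on the first call, so an entrypoint can build the modifier list before `initialize()` runs without appending the recipe twice. A rough sketch of the intended call order, with a placeholder recipe path and the session wiring assumed from the oneshot changes below:

    from llmcompressor.core.session_functions import active_session

    session = active_session()
    session.reset()

    # First call appends the recipe and materializes the modifier list
    session.lifecycle.initialize_recipe(recipe="recipe.yaml")  # placeholder path

    # A later call (e.g. from lifecycle.initialize()) finds modifiers already
    # populated and leaves the recipe container untouched
    session.lifecycle.initialize_recipe(recipe="recipe.yaml")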

src/llmcompressor/core/session.py

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from functools import reduce
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from loguru import logger
@@ -220,6 +221,10 @@ def get_serialized_recipe(self) -> Optional[str]:
 
         logger.warning("Recipe not found in session - it may have been reset")
 
+    def get_modifiers(self):
+        stage_modifiers = self.lifecycle.modifiers
+        return list(reduce(sum, (mod.modifiers for mod in stage_modifiers)))
+
     def _log_model_info(self):
         # Log model level logs if cadence reached
         current_index = self._lifecycle.global_step
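`get_modifiers` flattens the per-stage modifier containers held by the lifecycle into one flat list for pipeline dispatch. Assuming each entry of `lifecycle.modifiers` is a stage container exposing a `.modifiers` list (as the generator expression implies), the intent of the `reduce(sum, ...)` expression matches this nested comprehension, shown only to spell out the flattening:

    def get_modifiers_flattened(lifecycle):
        """Collect every modifier from every recipe stage into one flat list."""
        return [
            modifier
            for stage in lifecycle.modifiers   # one container per recipe stage
            for modifier in stage.modifiers    # modifiers within that stage
        ]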

src/llmcompressor/core/session_functions.py

Lines changed: 23 additions & 2 deletions
@@ -1,6 +1,6 @@
 import threading
 from contextlib import contextmanager
-from typing import Any, Optional
+from typing import Any, Generator, Optional
 
 from llmcompressor.core.events import EventType
 from llmcompressor.core.session import CompressionSession
@@ -21,7 +21,7 @@
 
 
 @contextmanager
-def create_session() -> CompressionSession:
+def create_session() -> Generator[CompressionSession, None, None]:
     """
     Context manager to create and yield a new session for sparsification.
     This will set the active session to the new session for the duration
@@ -136,5 +136,26 @@ def batch_end(cls, **kwargs) -> ModifiedState:
         active_session()._log_model_info()
         return cls.event(EventType.BATCH_END, **kwargs)
 
+    @classmethod
+    def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a sequential epoch end event for the active session. This event should be
+        called after one sequential layer has been calibrated/trained for one epoch
+
+        This is called after a sequential layer has been calibrated with one batch, see
+        `src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
+        """
+        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
+
+    @classmethod
+    def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke an epoch end event for the active session during calibration. This event
+        should be called after the model has been calibrated for one epoch
+
+        see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
+        """
+        return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)
+
 
 callbacks = LifecycleCallbacks
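The two new callbacks give calibration pipelines well-defined points at which modifiers can compress what they have accumulated. The docstrings point at `pipelines/sequential/pipeline.py` and `pipelines/basic/pipeline.py`, which are not part of this excerpt; a schematic of how such a pipeline might fire them, with the loop structure, argument shapes, and helper name being illustrative rather than taken from the commit:

    import torch

    from llmcompressor.core.session_functions import callbacks

    def run_calibration_sketch(layers, dataloader):
        """Illustrative loop showing where the new lifecycle events would fire."""
        with torch.no_grad():
            for layer in layers:  # one "sequential layer" (e.g. a decoder block) at a time
                for intermediates in dataloader:
                    layer(**intermediates)  # propagate cached activations through the layer
                # every calibration sample has now passed through this layer, so
                # modifiers listening for SEQUENTIAL_EPOCH_END may compress its modules
                callbacks.sequential_epoch_end()

        # the whole model has seen one epoch of calibration data
        callbacks.calibration_epoch_end()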

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 10 additions & 10 deletions
@@ -7,6 +7,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
+from llmcompressor.pipelines.registry import get_pipeline_fn
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -157,21 +158,20 @@ def apply_recipe_modifiers(
         """
 
         session = active_session()
+        session.reset()
 
-        session_kwargs = dict(
-            model=self.model,
+        session.lifecycle.state.update(model=self.model, start=-1)
+        session.lifecycle.initialize_recipe(
             recipe=self.recipe,
-            recipe_args=self.recipe_args.recipe_args,
-            calib_data=calibration_dataloader,
-            start=-1,  # oneshot-specific argument
-            copy_data=False,
-            min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
             recipe_stage=recipe_stage,
+            recipe_args=self.recipe_args.recipe_args,
         )
 
-        session.reset()
-        session.initialize(**session_kwargs)
-        session.finalize(**session_kwargs)
+        modifiers = session.get_modifiers()
+        _, pipeline_fn = get_pipeline_fn(self.dataset_args.pipeline, modifiers)
+        pipeline_fn(self.model, calibration_dataloader)
+
+        session.finalize()
 
 
 def oneshot(**kwargs) -> PreTrainedModel:
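`apply_recipe_modifiers` now delegates data handling to a pipeline resolved by name: `get_pipeline_fn(self.dataset_args.pipeline, modifiers)` returns the callable that drives calibration, and the modifiers react to the epoch-end events it emits. `pipelines/registry.py` is among the 23 changed files but is not shown in this excerpt, so the following is purely an assumed sketch of a name-keyed registry with the call shape used above, for orientation only:

    from typing import Callable, Dict, List, Tuple

    # Hypothetical registry shape: pipeline name -> callable taking (model, dataloader),
    # matching how pipeline_fn is invoked in oneshot.py. Real entries live in registry.py.
    PIPELINES: Dict[str, Callable] = {
        "basic": lambda model, dataloader: None,        # placeholder bodies
        "sequential": lambda model, dataloader: None,
        "independent": lambda model, dataloader: None,
    }

    def get_pipeline_fn(user_pipeline: str, modifiers: List) -> Tuple[str, Callable]:
        """Resolve a pipeline by name, returning (name, fn) as unpacked in oneshot.

        The real implementation presumably also inspects `modifiers` to infer or
        validate the pipeline choice; this sketch only does the name lookup.
        """
        if user_pipeline not in PIPELINES:
            raise ValueError(f"Unknown pipeline {user_pipeline!r}")
        return user_pipeline, PIPELINES[user_pipeline]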

src/llmcompressor/modifiers/obcq/base.py

Lines changed: 21 additions & 4 deletions
@@ -10,7 +10,7 @@
 from loguru import logger
 from pydantic import PrivateAttr
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
 from llmcompressor.modifiers.obcq.sgpt_sparsify import (
@@ -90,6 +90,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the hessian of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -108,10 +116,16 @@
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ in (
+            EventType.SEQUENTIAL_EPOCH_END,
+            EventType.CALIBRATION_EPOCH_END,
+        ):
+            self.compress_modules()
+
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
@@ -152,6 +166,9 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
            self._hessians[module] = self._hessians[module].to(device="cpu")
 
     def on_finalize(self, state: State, **kwargs) -> bool:
+        if len(self._num_samples) > 0:
+            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
         self.remove_hooks()
         self._hessians = dict()
         self._num_samples = dict()
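The new docstring describes `calibrate_module` as accumulating the Hessian of the module input. In SparseGPT-style pruning this is a running estimate of H ≈ (2/N) Σᵢ xᵢ xᵢᵀ over calibration samples; a minimal sketch of such an update, assuming the `accumulate_hessian` helper in `sgpt_sparsify` behaves along these lines (the actual helper is not shown in this excerpt):

    from typing import Tuple

    import torch

    def accumulate_hessian_sketch(
        inp: torch.Tensor, H: torch.Tensor, num_samples: int
    ) -> Tuple[torch.Tensor, int]:
        """Running estimate of H = (2 / N) * sum_i x_i x_i^T over calibration inputs."""
        inp = inp.reshape(-1, inp.shape[-1]).to(dtype=torch.float32)  # (n, hidden_dim)
        num_new = inp.shape[0]
        total = num_samples + num_new

        H = H * (num_samples / total)            # rescale the existing estimate
        H = H + (2.0 / total) * (inp.t() @ inp)  # add the new outer products
        return H, total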

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 0 additions & 47 deletions
@@ -12,10 +12,6 @@
 from llmcompressor.core import State
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.basic import run_pipeline as run_basic
-from llmcompressor.pipelines.layer_sequential import (
-    run_pipeline as run_layer_sequential,
-)
-from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
@@ -160,49 +156,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             self._module_sparsities[module] = layer_sparsity
             self.register_hook(module, self.calibrate_module, "forward")
 
-        # infer and run pipeline
-        model_name = state.model.__class__.__name__
-        input_names = dataloader.dataset.column_names
-        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
-        try:
-            run_sequential(
-                state.model,
-                state.data.calib,
-                self.sequential_targets,
-                self.ignore,
-                self,
-            )
-            return True
-
-        except Exception as exception:
-            if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
-            if isinstance(exception, unfixable_errors):
-                raise exception
-
-            warnings.warn("Falling back to layer_sequential pipeline")
-            try:
-                run_layer_sequential(
-                    state.model,
-                    state.data.calib,
-                    self.sequential_targets,
-                    self,
-                )
-                return True
-
-            except Exception as exception:
-                if isinstance(exception, TypeError):
-                    warnings.warn(f"{model_name} fails layer-wise assumptions")
-                if isinstance(exception, unfixable_errors):
-                    raise exception
-
-                warnings.warn(
-                    "Falling back to basic pipeline, which requires extra memory and "
-                    "may result in decreased accuracy"
-                )
-                run_basic(state.model, state.data.calib, self)
-                return True
-
     def _infer_sequential_targets(
         self, model: torch.nn.Module
     ) -> Union[str, List[str]]:
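The try/except cascade removed here (tracing-based sequential, then layer-sequential, then basic) embodied the fallback policy that the commit title suggests is being extracted into the pipeline layer, presumably behind the `"independent"` option and the registry rather than inside each sparsity modifier. A condensed restatement of that policy as a standalone dispatcher, with the function's home and signature being assumptions and the pipeline signatures taken from the removed calls:

    import warnings

    import torch

    def run_with_fallback(model, dataloader, sequential_targets, ignore, modifier):
        """Sketch of the extracted fallback policy: sequential, then layer-wise, then basic."""
        from llmcompressor.pipelines.basic import run_pipeline as run_basic
        from llmcompressor.pipelines.layer_sequential import (
            run_pipeline as run_layer_sequential,
        )
        from llmcompressor.pipelines.sequential import run_pipeline as run_sequential

        unfixable = (torch.OutOfMemoryError, torch._C._LinAlgError)
        try:
            return run_sequential(model, dataloader, sequential_targets, ignore, modifier)
        except Exception as exc:
            if isinstance(exc, unfixable):
                raise
            warnings.warn("Falling back to layer_sequential pipeline")

        try:
            return run_layer_sequential(model, dataloader, sequential_targets, modifier)
        except Exception as exc:
            if isinstance(exc, unfixable):
                raise
            warnings.warn("Falling back to basic pipeline")

        return run_basic(model, dataloader, modifier)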

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 21 additions & 5 deletions
@@ -9,7 +9,7 @@
 from loguru import logger
 from pydantic import PrivateAttr
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
 from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import (
@@ -74,6 +74,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the row scalars of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -91,12 +99,17 @@
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ in (
+            EventType.SEQUENTIAL_EPOCH_END,
+            EventType.CALIBRATION_EPOCH_END,
+        ):
+            self.compress_modules()
+
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
-
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
             sparsity = self._module_sparsities[module]
@@ -120,6 +133,9 @@ def on_sequential_batch_end(self):
             del self._num_samples[module]
 
     def on_finalize(self, state: State, **kwargs) -> bool:
+        if len(self._num_samples) > 0:
+            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
         self.remove_hooks()
         self._row_scalars = dict()
         self._num_samples = dict()
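Here `compress_modules` consumes the accumulated row scalars to sparsify each calibrated module. In the usual Wanda formulation the pruning score for weight W[i, j] is |W[i, j]| · ‖X_j‖₂, and the lowest-scoring fraction of each output row is zeroed; a compact sketch of that criterion, assuming the row scalars hold accumulated squared input-column norms (the real `wanda_sparsify` helper is not shown here and may differ, e.g. for N:M patterns):

    import torch

    def wanda_prune_sketch(
        weight: torch.Tensor, row_scalars: torch.Tensor, sparsity: float
    ) -> torch.Tensor:
        """Zero the lowest-scoring weights using the Wanda metric |W| * ||X||_2."""
        # row_scalars: accumulated squared input norms, one per input column
        scores = weight.abs() * row_scalars.sqrt().unsqueeze(0)  # (out, in)

        # per output row, mask the `sparsity` fraction with the smallest scores
        num_prune = int(weight.shape[1] * sparsity)
        pruned = weight.clone()
        if num_prune > 0:
            _, prune_idx = torch.topk(scores, num_prune, dim=1, largest=False)
            pruned.scatter_(1, prune_idx, 0.0)
        return pruned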
