Commit 7611d99

use pipeline registry

Signed-off-by: Kyle Sayers <[email protected]>
1 parent aad4e06

File tree

16 files changed: +343 -319 lines changed
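The new `src/llmcompressor/pipelines/registry.py` module is not among the hunks shown below, but the call sites imply its surface: a `CalibrationPipeline` base class exposing a `register` decorator and a `from_modifiers` constructor. The following is a minimal sketch of that pattern, assuming a plain dict-backed registry (the real module may instead reuse a `RegistryMixin` as elsewhere in the codebase); the `_registry` attribute and the inference heuristic are illustrative placeholders, not this commit's code:

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type


class CalibrationPipeline(ABC):
    # name -> pipeline class; populated as a side effect of @register
    _registry: Dict[str, Type["CalibrationPipeline"]] = {}

    @classmethod
    def register(cls, name: str):
        """Class decorator, used as @CalibrationPipeline.register("basic")."""

        def decorator(pipeline_cls: Type["CalibrationPipeline"]):
            cls._registry[name] = pipeline_cls
            return pipeline_cls

        return decorator

    @staticmethod
    @abstractmethod
    def __call__(model, dataloader, dataset_args):
        """Subclasses implement __call__ as a staticmethod (see diffs below)."""
        raise NotImplementedError

    @classmethod
    def from_modifiers(
        cls, modifiers: List, user: Optional[str] = None
    ) -> "CalibrationPipeline":
        # an explicit user choice takes precedence over inference
        name = user if user is not None else cls._infer_pipeline(modifiers)
        return cls._registry[name]()

    @classmethod
    def _infer_pipeline(cls, modifiers: List) -> str:
        # placeholder heuristic; the actual per-modifier inference logic
        # lives in registry.py and is not part of this excerpt
        return "sequential" if modifiers else "datafree"
```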

src/llmcompressor/args/dataset_arguments.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -3,8 +3,6 @@
 
 from transformers import DefaultDataCollator
 
-from llmcompressor.pipelines.registry import PIPELINES
-
 
 @dataclass
 class DVCDatasetArguments:
@@ -176,7 +174,8 @@ class DatasetArguments(CustomDatasetArguments):
     pipeline: Optional[str] = field(
         default="independent",
         metadata={
             "help": "Calibration pipeline used to calibrate model. "
-            f"Options: {PIPELINES.keys()}"
+            "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
+            "'independent']"
         },
     )
```
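For context, the `pipeline` field above is how callers select a registered pipeline by name. A hedged usage sketch (the model id, dataset name, and recipe path are placeholders, not from this commit):

```python
from llmcompressor import oneshot

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id
    dataset="open_platypus",                     # placeholder dataset name
    recipe="recipe.yaml",                        # placeholder recipe path
    pipeline="basic",  # any registered name; defaults to "independent"
)
```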

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -7,7 +7,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
-from llmcompressor.pipelines.registry import get_pipeline_fn
+from llmcompressor.pipelines.registry import CalibrationPipeline
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -168,9 +168,10 @@ def apply_recipe_modifiers(
             recipe_args=self.recipe_args.recipe_args,
         )
 
+        user_pipeline = self.dataset_args.pipeline
        modifiers = session.get_modifiers()
-        _, pipeline_fn = get_pipeline_fn(self.dataset_args.pipeline, modifiers)
-        pipeline_fn(self.model, calibration_dataloader, self.dataset_args)
+        pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
+        pipeline(self.model, calibration_dataloader, self.dataset_args)
 
         session.finalize()
```

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -12,7 +12,6 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.pipelines.basic import run_pipeline as run_basic
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
@@ -247,7 +246,7 @@ def _infer_owl_layer_sparsity(
         return sparsities
 
     def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
-        from llmcompressor.args import DatasetArguments
+        from llmcompressor.pipelines.basic import run_calibration
 
         acts = defaultdict(int)
 
@@ -263,7 +262,7 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
             if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
         )
         with HooksMixin.disable_hooks(keep=hooks):
-            run_basic(model, dataloader, DatasetArguments())
+            run_calibration(model, dataloader)
         self.remove_hooks(hooks)
 
         return acts
```
src/llmcompressor/pipelines/__init__.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -0,0 +1,8 @@
+# flake8: noqa
+# populate registry
+from .basic import *
+from .data_free import *
+from .independent import *
+from .layer_sequential import *
+from .registry import *
+from .sequential import *
```
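These star imports exist for their side effects: importing each submodule executes its `@CalibrationPipeline.register(...)` decorator, so merely importing the package fills the registry. Under the dict-backed sketch above (where `_registry` is an assumed name, not confirmed by this commit), the effect would look like:

```python
# importing the package triggers every @register decorator
import llmcompressor.pipelines  # noqa: F401

from llmcompressor.pipelines.registry import CalibrationPipeline

# hypothetical introspection, assuming the sketched `_registry` dict
print(sorted(CalibrationPipeline._registry))
# ['basic', 'datafree', 'independent', 'layer_sequential', 'sequential']
```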
src/llmcompressor/pipelines/basic/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
 # flake8: noqa
-from .pipeline import run_pipeline
+from .pipeline import *
```
src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 34 additions & 25 deletions

```diff
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 import torch
 import tqdm
@@ -7,40 +7,49 @@
 
 from llmcompressor.core import LifecycleCallbacks
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
+from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils.helpers import calibration_forward_context
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
-__all__ = ["run_pipeline"]
+__all__ = ["BasicPipeline", "run_calibration"]
 
 
-def run_pipeline(
-    model: torch.nn.Module,
-    dataloader: DataLoader,
-    dataset_args: "DatasetArguments",
-):
-    """
-    Run a basic data pipeline.
+@CalibrationPipeline.register("basic")
+class BasicPipeline(CalibrationPipeline):
+    @staticmethod
+    def __call__(
+        model: torch.nn.Module,
+        dataloader: DataLoader,
+        dataset_args: Union["DatasetArguments", None],
+    ):
+        """
+        Run a basic data pipeline.
 
-    Batches are fetched from the data loader and are used to perform forward passes
-    through the model. This pipeline is typically used for basic model calibration
-    and, unlike the sequential pipelines, does not propagate compression error when
-    used to calibrate model compression
+        Batches are fetched from the data loader and are used to perform forward passes
+        through the model. This pipeline is typically used for basic model calibration
+        and, unlike the sequential pipelines, does not propagate compression error when
+        used to calibrate model compression
 
-    :param model: model being calibrated
-    :param dataloader: loads data for calibration
-    :param dataset_args: dataset arguments relevant to pipelines
-    """
-    model_device = get_execution_device(model)
+        :param model: model being calibrated
+        :param dataloader: loads data for calibration
+        :param dataset_args: dataset arguments relevant to pipelines
+        """
+        model_device = get_execution_device(model)
 
-    LifecycleCallbacks.calibration_epoch_start()
+        LifecycleCallbacks.calibration_epoch_start()
 
-    with calibration_forward_context(model):
-        for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
-            batch = apply_pad_mask_to_batch(batch)
-            batch = tensors_to_device(batch, model_device)
-            model(**batch)
+        with calibration_forward_context(model):
+            for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
+                batch = apply_pad_mask_to_batch(batch)
+                batch = tensors_to_device(batch, model_device)
+                model(**batch)
 
-    LifecycleCallbacks.calibration_epoch_end()
+        LifecycleCallbacks.calibration_epoch_end()
+
+
+def run_calibration(model: torch.nn.Module, dataloader: DataLoader):
+    pipeline = BasicPipeline()
+    pipeline(model, dataloader, None)
```
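One consequence of the registry pattern worth noting: new pipelines can be added without touching the dispatch code in oneshot.py. Below is a hypothetical third-party pipeline following the same decorator-plus-static-`__call__` convention shown above; the "noop_debug" name and the body are illustrative only, not part of this commit:

```python
import torch
from torch.utils.data.dataloader import DataLoader

from llmcompressor.pipelines.registry import CalibrationPipeline


@CalibrationPipeline.register("noop_debug")  # hypothetical name
class NoopDebugPipeline(CalibrationPipeline):
    @staticmethod
    def __call__(model: torch.nn.Module, dataloader: DataLoader, dataset_args):
        # illustrative only: count calibration batches without running the model
        print(f"would calibrate on {len(dataloader)} batches")
```

Once imported, it could be selected like any built-in pipeline, e.g. `pipeline="noop_debug"`.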
src/llmcompressor/pipelines/data_free/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
 # flake8: noqa
-from .pipeline import run_pipeline
+from .pipeline import *
```
src/llmcompressor/pipelines/data_free/pipeline.py

Lines changed: 19 additions & 15 deletions

```diff
@@ -1,27 +1,31 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.core.session_functions import LifecycleCallbacks
+from llmcompressor.pipelines.registry import CalibrationPipeline
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
-__all__ = ["run_pipeline"]
+__all__ = ["DataFreePipeline"]
 
 
-def run_pipeline(
-    model: torch.nn.Module,
-    dataloader: DataLoader,
-    dataset_args: "DatasetArguments",
-):
-    """
-    A pipeline for data-free calibration
+@CalibrationPipeline.register("datafree")
+class DataFreePipeline(CalibrationPipeline):
+    @staticmethod
+    def __call__(
+        model: torch.nn.Module,
+        dataloader: Optional[DataLoader],
+        dataset_args: "DatasetArguments",
+    ):
+        """
+        A pipeline for data-free calibration
 
-    :param model: model being calibrated
-    :param dataloader: loads data for calibration
-    :param dataset_args: dataset arguments relevant to pipelines
-    """
-    LifecycleCallbacks.calibration_epoch_start()
-    LifecycleCallbacks.calibration_epoch_end()
+        :param model: model being calibrated
+        :param dataloader: loads data for calibration
+        :param dataset_args: dataset arguments relevant to pipelines
+        """
+        LifecycleCallbacks.calibration_epoch_start()
+        LifecycleCallbacks.calibration_epoch_end()
```
src/llmcompressor/pipelines/independent/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
 # flake8: noqa
-from .pipeline import run_pipeline
+from .pipeline import *
```

src/llmcompressor/pipelines/independent/pipeline.py

Lines changed: 38 additions & 35 deletions
```diff
@@ -6,43 +6,46 @@
 
 from llmcompressor.core import active_session
 from llmcompressor.modifiers.stage import StageModifiers
+from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.utils.helpers import patch_attr
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
-__all__ = ["run_pipeline"]
-
-
-def run_pipeline(
-    model: torch.nn.Module,
-    dataloader: DataLoader,
-    dataset_args: "DatasetArguments",
-):
-    """
-    Data pipeline where each modifier is assigned its own calibration epoch and data
-    pipeline
-
-    :param model: model being calibrated
-    :param dataloader: loads data for calibration
-    :param dataset_args: dataset arguments relevant to pipelines
-    """
-    # avoid circular import
-    from llmcompressor.pipelines.registry import get_pipeline_fn
-
-    session = active_session()
-
-    modifiers = session.get_modifiers()
-    with patch_attr(session.lifecycle, "modifiers", None):
-        for index, modifier in enumerate(modifiers):
-            mod_type = str(type(modifier).__name__)
-            session.lifecycle.modifiers = [
-                StageModifiers(modifiers=[modifier], group=mod_type, index=index)
-            ]
-
-            pipeline, pipeline_fn = get_pipeline_fn(user=None, modifiers=[modifier])
-            logger.info(f"Inferred `{pipeline}` calibration pipeline for `{mod_type}`")
-
-            pipeline_fn(model, dataloader, dataset_args)
-
-    # restore modifiers on exit for proper model compression inference from recipe
+__all__ = ["IndependentPipeline"]
+
+
+@CalibrationPipeline.register("independent")
+class IndependentPipeline(CalibrationPipeline):
+    @staticmethod
+    def __call__(
+        model: torch.nn.Module,
+        dataloader: DataLoader,
+        dataset_args: "DatasetArguments",
+    ):
+        """
+        Data pipeline where each modifier is assigned its own calibration epoch and data
+        pipeline
+
+        :param model: model being calibrated
+        :param dataloader: loads data for calibration
+        :param dataset_args: dataset arguments relevant to pipelines
+        """
+        _logger = logger.patch(lambda r: r.update(function="IndependentPipeline"))
+
+        session = active_session()
+        modifiers = session.get_modifiers()
+        with patch_attr(session.lifecycle, "modifiers", None):
+            for index, modifier in enumerate(modifiers):
+                mod_type = str(type(modifier).__name__)
+                session.lifecycle.modifiers = [
+                    StageModifiers(modifiers=[modifier], group=mod_type, index=index)
+                ]
+
+                pipeline = CalibrationPipeline.from_modifiers([modifier])
+                pipeline_name = pipeline.__class__.__name__
+                _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`")
+
+                pipeline(model, dataloader, dataset_args)
+
+        # restore modifiers on exit so model can be compressed based on recipe
```
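Taken together with the oneshot.py change above: since "independent" is the default value of `pipeline`, the usual flow is a two-level dispatch. oneshot resolves the user string to `IndependentPipeline`, which then wraps each modifier in its own `StageModifiers` stage and calls `CalibrationPipeline.from_modifiers([modifier])` with no `user` override, so the registry infers a pipeline per modifier and each modifier runs in its own calibration epoch.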
