45 commits
351841a
model loading
kylesayrs Mar 15, 2025
75d7d1e
datasets
kylesayrs Mar 15, 2025
4a22b90
post train works
kylesayrs Mar 15, 2025
48a1c40
Merge remote-tracking branch 'origin' into kylesayrs/llm-compressor
kylesayrs Mar 15, 2025
4761351
pipeline resolution
kylesayrs Mar 15, 2025
7c4dee4
style
kylesayrs Mar 15, 2025
63252ad
implement train skeleton
kylesayrs Mar 15, 2025
e32e1c4
cleanup
kylesayrs Mar 15, 2025
51fb047
extract data pipelines
kylesayrs Mar 15, 2025
710fe24
extract data pipeline events, integrate smoothquant, begin independen…
kylesayrs Mar 16, 2025
0a3f8f2
model saving
kylesayrs Mar 16, 2025
7f59359
add calibration data check
kylesayrs Mar 16, 2025
abf1818
add save path
kylesayrs Mar 17, 2025
e33793e
only send after start and before end
kylesayrs Mar 17, 2025
e8a2fe9
move initialize and finalize into pipelines
kylesayrs Mar 18, 2025
058ccf6
WIP: implement get_modifiers_from_recipe
kylesayrs Mar 24, 2025
81c60f1
merge with extract pipelines, remove event dependency for current_index
kylesayrs Mar 26, 2025
244ae34
merge in layerwise performance
kylesayrs Mar 27, 2025
43708af
trainer integration, remove pipeline from quantization modifier, remo…
kylesayrs Mar 27, 2025
bb2def2
add entrypoints
kylesayrs Mar 27, 2025
6274601
remove custom data classes
kylesayrs Mar 27, 2025
0281234
remove some no-longer-relevant tests
kylesayrs Mar 28, 2025
02834f4
simplify data args
kylesayrs Mar 28, 2025
ebb0410
reduce import path length
kylesayrs Mar 28, 2025
34b88f8
remove llmcompressor folder
kylesayrs Mar 28, 2025
eccddaa
remove unused file
kylesayrs Mar 28, 2025
02d81e9
move out resolve_modifier_quantization_config
kylesayrs Mar 28, 2025
fffd20a
rename file
kylesayrs Mar 28, 2025
e5c66b7
reduce core import dependency on modifiers
kylesayrs Mar 28, 2025
c3ba7ca
validated training
kylesayrs Mar 28, 2025
d4552eb
training with distillation works
kylesayrs Mar 29, 2025
4c3e70d
cleanup
kylesayrs Mar 29, 2025
bd9ca1f
remove typehinting
kylesayrs Mar 29, 2025
a525a3c
enable quantization during calibration
kylesayrs Mar 29, 2025
30c7169
update script
kylesayrs Mar 31, 2025
2c3e39b
break out register_calibration_hooks
kylesayrs Mar 31, 2025
da62925
WIP
kylesayrs Apr 1, 2025
1e88239
clean up calibration, allow shapes to be iterated during tracing
kylesayrs Apr 1, 2025
65f7912
comment
kylesayrs Apr 1, 2025
2da0916
confirm whisper
kylesayrs Apr 1, 2025
6c7dad7
WIP
kylesayrs Apr 2, 2025
4e5fb5c
use calibration_epoch_end in basic pipeline
kylesayrs Apr 2, 2025
d0f6790
qmod
kylesayrs Apr 2, 2025
12eb66f
handle no-data
kylesayrs Apr 2, 2025
b00ca59
skip
kylesayrs Apr 5, 2025
5 changes: 5 additions & 0 deletions src/llmcompressor/core/events/event.py
@@ -44,6 +44,7 @@ class EventType(Enum):
BATCH_START = "batch_start"
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"
SEQUENTIAL_BATCH_END = "sequential_batch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
@@ -82,6 +83,10 @@ class Event:
global_step: int = 0
global_batch: int = 0

def __init__(self, **kwargs):
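        # set only the fields passed as keywords; unset fields keep defaults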
for key, value in kwargs.items():
setattr(self, key, value)

@property
def epoch_based(self) -> bool:
"""
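A minimal sketch of the two additions in this hunk: the new `SEQUENTIAL_BATCH_END` event type and the kwargs-based `Event.__init__`, which sets only the fields passed as keywords and leaves the rest at their class-level defaults (such as the `global_batch = 0` shown at the top of this hunk).

```python
from llmcompressor.core.events import Event, EventType

# only the passed fields are set; unset fields fall back to class defaults
event = Event(type_=EventType.SEQUENTIAL_BATCH_END, global_step=10)
assert event.global_step == 10
assert event.global_batch == 0  # default from the class body is preserved
```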
121 changes: 121 additions & 0 deletions src/llmcompressor/core/llmcompressor/event_lifecycle.py
@@ -0,0 +1,121 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, List, Optional

from loguru import logger

from llmcompressor.core.events import Event, EventType
from llmcompressor.utils.singleton import SingletonMixin

if TYPE_CHECKING:
from llmcompressor.core.llmcompressor.events_mixin import EventsMixin


class EventsLifecycle(SingletonMixin):
auto_step: Optional[bool] = None
event_order: List[EventType] = [
EventType.BATCH_START,
EventType.LOSS_CALCULATED,
EventType.OPTIM_PRE_STEP,
EventType.OPTIM_POST_STEP,
EventType.BATCH_END,
]
last_event_type: Optional[EventType] = EventType.BATCH_END
initialized: bool = False
finalized: bool = False

@classmethod
def initialize(cls, fn: Callable[[Any], Any]):
def validator(self: "EventsMixin", **kwargs):
if cls.initialized:
raise ValueError("Cannot initialize twice")
cls.initialized = True
cls.finalized = False

return cls._wrap_with_validation(fn, validator)

@classmethod
def finalize(cls, fn: Callable[[Any], Any]):
def validator(self: "EventsMixin", **kwargs):
if not cls.initialized:
raise ValueError("Cannot finalize before initializing")
if cls.finalized:
raise ValueError("Cannot finalize twice")
cls.finalized = True
cls.initialized = False

return cls._wrap_with_validation(fn, validator)

@classmethod
def global_step(cls, fn: Callable[[Any], Any]):
def validator(self: "EventsMixin", global_step: Optional[int] = None, **kwargs):
# configure auto step
if cls.auto_step is None:
if global_step is None:
logger.info(
"No global_step was passed to batch_start event, "
"auto-stepping based on batches"
)
cls.auto_step = True
else:
cls.auto_step = False

# auto step
if global_step is None:
if not cls.auto_step:
raise ValueError(
"Cannot auto-step batches if global_step was "
"previously passed to batch_start event"
)
global_step = self.state.current_index + 1
else:
if cls.auto_step:
                    raise ValueError(
                        "Cannot pass global_step to batch_start event "
                        "after auto-stepping has already begun"
                    )

# validate order
if global_step <= self.state.current_index:
raise ValueError("global_step must be greater than the current index")

self.state.current_index = global_step

return cls._wrap_with_validation(fn, validator)

@classmethod
def event(cls, fn: Callable[[Any], Any]):
def validator(self: "EventsMixin", event: Event):
event_type = event.type_

# ignore unhandled events
if event_type not in cls.event_order:
return

# validate
if event_type == EventType.BATCH_START:
valid = cls.last_event_type != EventType.BATCH_START
else:
last_event_index = cls.event_order.index(cls.last_event_type)
curr_event_index = cls.event_order.index(event_type)
valid = last_event_index <= curr_event_index

if not valid:
raise ValueError(
f"Lifecycle events must appear in order: {cls.event_order}. "
f"Instead, {cls.last_event_type} was called before {event_type}"
)

cls.last_event_type = event_type

return cls._wrap_with_validation(fn, validator)

@classmethod
def _wrap_with_validation(
cls, fn: Callable[[Any], Any], validator: Callable[[Any], Any]
) -> Callable:
@wraps(fn)
def wrapped(*args, **kwargs):
validator(*args, **kwargs)
return fn(*args, **kwargs)

return wrapped
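A sketch of the ordering that `EventsLifecycle.event` enforces, using a hypothetical `FakeCompressor` stand-in rather than a real `EventsMixin` subclass; because `EventsLifecycle` tracks `last_event_type` at the class level, any decorated handler participates in the same lifecycle.

```python
from llmcompressor.core.events import Event, EventType
from llmcompressor.core.llmcompressor.event_lifecycle import EventsLifecycle

class FakeCompressor:  # hypothetical stand-in for an EventsMixin subclass
    @EventsLifecycle.event
    def _handle_event(self, event):
        pass

c = FakeCompressor()
c._handle_event(Event(type_=EventType.BATCH_START))      # ok
c._handle_event(Event(type_=EventType.OPTIM_PRE_STEP))   # ok: later in event_order
c._handle_event(Event(type_=EventType.LOSS_CALCULATED))  # raises ValueError:
# LOSS_CALCULATED appears before OPTIM_PRE_STEP in event_order
```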
74 changes: 74 additions & 0 deletions src/llmcompressor/core/llmcompressor/events_mixin.py
@@ -0,0 +1,74 @@
from abc import ABC
from typing import List

import torch

from llmcompressor.core import Event, EventType, State
from llmcompressor.core.llmcompressor.event_lifecycle import EventsLifecycle
from llmcompressor.modifiers import Modifier
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)


class EventsMixin(ABC):
state: State
modifiers: List[Modifier]

@EventsLifecycle.initialize
def initialize(self):
for modifier in self.modifiers:
modifier.on_initialize(self.state)

@EventsLifecycle.finalize
def finalize(self):
for modifier in self.modifiers:
modifier.on_finalize(self.state)

# TODO: log info stating that save_pretrained has been modified
# TODO: make sure wrapped function can access new recipe and processor
modify_save_pretrained(self.state.model)

def update_state(self, **kwargs):
self.state.update(**kwargs)
# if future modifiers require update, do that update here

@EventsLifecycle.global_step
def batch_start(self, **kwargs):
# modifiers can only start on batch_start
for modifier in self.modifiers:
if modifier.should_start(self.state):
modifier.on_start(self.state)

event = Event(type_=EventType.BATCH_START, **kwargs)
self._handle_event(event)

def pre_optim(self, **kwargs):
event = Event(type_=EventType.OPTIM_PRE_STEP, **kwargs)
self._handle_event(event)

def post_optim(self, **kwargs):
event = Event(type_=EventType.OPTIM_POST_STEP, **kwargs)
self._handle_event(event)

def update_loss(self, loss: torch.Tensor, **kwargs):
event = Event(type_=EventType.LOSS_CALCULATED, loss=loss, **kwargs)
self._handle_event(event)

def sequential_batch_end(self, **kwargs):
event = Event(type_=EventType.SEQUENTIAL_BATCH_END, **kwargs)
self._handle_event(event)

def batch_end(self, **kwargs):
# modifiers can only end on batch_end
for modifier in self.modifiers:
if modifier.should_end(self.state):
modifier.on_end(self.state)

event = Event(type_=EventType.BATCH_END, **kwargs)
self._handle_event(event)

@EventsLifecycle.event
def _handle_event(self, event: Event):
for modifier in self.modifiers:
modifier.on_event(self.state, event)
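For reviewers, a sketch of the call pattern a training loop would use to drive these hooks. All names below are hypothetical stand-ins for the caller's own objects, and `global_step` is assumed to be 1-indexed so each step exceeds the previous `state.current_index`.

```python
def run_training_loop(compressor, model, optimizer, loader) -> None:
    """Hypothetical loop showing the intended order of EventsMixin hooks."""
    compressor.initialize()
    for step, batch in enumerate(loader):
        compressor.batch_start(global_step=step + 1)  # modifiers may start here
        loss = model(**batch).loss
        compressor.update_loss(loss)
        compressor.pre_optim()
        optimizer.step()
        compressor.post_optim()
        compressor.batch_end()  # modifiers may end here
    compressor.finalize()
```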
25 changes: 25 additions & 0 deletions src/llmcompressor/core/llmcompressor/globals.py
@@ -0,0 +1,25 @@
from typing import TYPE_CHECKING

from transformers import PreTrainedModel

if TYPE_CHECKING:
from llmcompressor.core import State
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor


def get_compressor() -> "LLMCompressor":
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

return LLMCompressor.instance()


def get_state() -> "State":
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

return LLMCompressor.instance().state


def get_model() -> PreTrainedModel:
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

return LLMCompressor.instance().state.model
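These helpers defer their imports to avoid a circular dependency and resolve the active singleton lazily, so modifier code can reach the global model and state without importing `LLMCompressor` at module load time. A brief usage sketch, assuming a compressor has already been constructed:

```python
from llmcompressor.core.llmcompressor.globals import (
    get_compressor,
    get_model,
    get_state,
)

compressor = get_compressor()  # the active LLMCompressor singleton
state = get_state()            # shorthand for get_compressor().state
model = get_model()            # shorthand for get_state().model
```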
60 changes: 60 additions & 0 deletions src/llmcompressor/core/llmcompressor/llmcompressor.py
@@ -0,0 +1,60 @@
from typing import List, Optional, Union

from torch.utils.data import DataLoader

from llmcompressor.args.model_arguments import ModelArguments
from llmcompressor.core import State
from llmcompressor.core.llmcompressor.events_mixin import EventsMixin
from llmcompressor.core.llmcompressor.train import HFSFTMixin
from llmcompressor.core.llmcompressor.utils import (
LCDatasetArguments,
check_for_calibration_data,
get_modifiers_from_recipe,
parse_args,
prepare_models,
resolve_calibration_pipeline,
)
from llmcompressor.datasets.utils import get_calibration_dataloader
from llmcompressor.modifiers import Modifier
from llmcompressor.pytorch.model_load.helpers import save_checkpoint
from llmcompressor.recipe import RecipeInput
from llmcompressor.typing import DatasetType, ModelInput
from llmcompressor.utils.singleton import SingletonMixin


class LLMCompressor(SingletonMixin, EventsMixin, HFSFTMixin):
state: State
modifiers: List[Modifier]
calibration_loader: Optional[DataLoader] = None

def __init__(self, model: ModelInput, recipe: RecipeInput, **kwargs):
model_args = parse_args(ModelArguments, model=model, **kwargs)

self.modifiers = get_modifiers_from_recipe(recipe)

model, teacher, processor = prepare_models(model_args)
self.state = State(model=model, teacher_model=teacher, processor=processor)

def set_calibration_dataset(self, dataset: Union[str, DatasetType], **kwargs):
dataset_args = parse_args(LCDatasetArguments, dataset=dataset, **kwargs)

# temporary hack to support better interface
if dataset_args.split is not None:
dataset_args.splits = {"calibration": dataset_args.split}

self.calibration_loader = get_calibration_dataloader(
dataset_args, self.state.processor
)

def post_train(self, pipeline: Optional[str] = None, save_path: Optional[str] = None):
check_for_calibration_data(self.modifiers, self.calibration_loader)
pipeline_fn, pipeline_kwargs = resolve_calibration_pipeline(
pipeline, self.modifiers
)

self.initialize()
pipeline_fn(self.state.model, self.calibration_loader, **pipeline_kwargs)
self.finalize()

if save_path is not None:
save_checkpoint(save_path, self.state.model, self.state.processor)
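An end-to-end post-training sketch of the new interface; the model id, recipe path, dataset name, and save path below are placeholders, not values taken from this PR.

```python
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

# placeholders: substitute a real model, recipe, and calibration dataset
compressor = LLMCompressor(model="my-org/my-model", recipe="recipe.yaml")
compressor.set_calibration_dataset("my_calibration_set", split="train[:512]")
compressor.post_train(save_path="./compressed-model")
```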
62 changes: 62 additions & 0 deletions src/llmcompressor/core/llmcompressor/train.py
@@ -0,0 +1,62 @@
import math
from typing import TYPE_CHECKING, Optional, Union

from llmcompressor.args.training_arguments import TrainingArguments
from llmcompressor.core import State
from llmcompressor.core.llmcompressor.utils import LCDatasetArguments, parse_args
from llmcompressor.datasets.utils import get_processed_dataset
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.typing import DatasetType

if TYPE_CHECKING:
from transformers.data.data_collator import DataCollator


class HFSFTMixin:
state: State
train_dataset: Optional[DatasetType] = None
train_data_collator: Optional["DataCollator"] = None

def set_train_dataset(self, dataset: Union[str, DatasetType], **kwargs):
dataset_args = parse_args(LCDatasetArguments, dataset=dataset, **kwargs)

processed_dataset = get_processed_dataset(
dataset_args=dataset_args,
processor=self.state.processor,
)
self.train_dataset = processed_dataset.get("train")

def train(self, **kwargs):
raise NotImplementedError(
"Implementing LLMCompressor.train would require "
"changes which break existing training pathways"
)
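        # NOTE: the code below is currently unreachable; it sketches the
        # intended implementation once trainer integration no longer breaks
        # existing training pathways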

training_args = parse_args(TrainingArguments, **kwargs)

trainer = Trainer(
model=self.state.model,
teacher=self.state.teacher_model,
# recipe=recipe_args.recipe,
# recipe_args=recipe_args.recipe_args,
args=training_args,
# model_args=model_args,
# dataset_args=dataset_args,
train_dataset=self.train_dataset,
processing_class=self.state.processor,
data_collator=self.train_data_collator,
)

# run training
checkpoint = training_args.resume_from_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)

# save metrics
metrics = train_result.metrics
metrics["train_samples"] = len(self.train_dataset)
metrics["perplexity"] = math.exp(metrics["train_loss"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

# save model
trainer.save_model(output_dir=training_args.output_dir)