[WIP] [Design] LLMCompressor Class #1256
Status: Draft. kylesayrs wants to merge 45 commits into `main` from `kylesayrs/llm-compressor`.
Commits (45, all authored by kylesayrs). The diff shown below reflects changes from 13 of the 45 commits:
- `351841a` model loading
- `75d7d1e` datasets
- `4a22b90` post train works
- `48a1c40` Merge remote-tracking branch 'origin' into kylesayrs/llm-compressor
- `4761351` pipeline resolution
- `7c4dee4` style
- `63252ad` implement train skeleton
- `e32e1c4` cleanup
- `51fb047` extract data pipelines
- `710fe24` extract data pipeline events, integrate smoothquant, begin independen…
- `0a3f8f2` model saving
- `7f59359` add calibration data check
- `abf1818` add save path
- `e33793e` only send after start and before end
- `e8a2fe9` move initialize and finalize into pipelines
- `058ccf6` WIP: implement get_modifiers_from_recipe
- `81c60f1` merge with extract pipelines, remove event dependency for current_index
- `244ae34` merge in layerwise performance
- `43708af` trainer integration, remove pipeline from quantization modifier, remo…
- `bb2def2` add entrypoints
- `6274601` remove custom data classes
- `0281234` remove some no-longer-relevant tests
- `02834f4` simplify data args
- `ebb0410` reduce import path length
- `34b88f8` remove llmcompressor folder
- `eccddaa` remove unused file
- `02d81e9` move out resolve_modifier_quantization_config
- `fffd20a` rename file
- `e5c66b7` reduce core import dependency on modifiers
- `c3ba7ca` validated training
- `d4552eb` training with distillation works
- `4c3e70d` cleanup
- `bd9ca1f` remove typehinting
- `a525a3c` enable quantization during calibration
- `30c7169` update script
- `2c3e39b` break out register_calibration_hooks
- `da62925` WIP
- `1e88239` clean up calibration, allow shapes to be iterated during tracing
- `65f7912` comment
- `2da0916` confirm whisper
- `6c7dad7` WIP
- `4e5fb5c` use calibration_epoch_end in basic pipeline
- `d0f6790` qmod
- `12eb66f` handle no-data
- `b00ca59` skip
`src/llmcompressor/core/llmcompressor/event_lifecycle.py` (new file, 121 additions):

```python
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, List, Optional

from loguru import logger

from llmcompressor.core.events import Event, EventType
from llmcompressor.utils.singleton import SingletonMixin

if TYPE_CHECKING:
    from llmcompressor.core.llmcompressor.events_mixin import EventsMixin


class EventsLifecycle(SingletonMixin):
    auto_step: Optional[bool] = None
    event_order: List[EventType] = [
        EventType.BATCH_START,
        EventType.LOSS_CALCULATED,
        EventType.OPTIM_PRE_STEP,
        EventType.OPTIM_POST_STEP,
        EventType.BATCH_END,
    ]
    last_event_type: Optional[EventType] = EventType.BATCH_END
    initialized: bool = False
    finalized: bool = False

    @classmethod
    def initialize(cls, fn: Callable[[Any], Any]):
        def validator(self: "EventsMixin", **kwargs):
            if cls.initialized:
                raise ValueError("Cannot initialize twice")
            cls.initialized = True
            cls.finalized = False

        return cls._wrap_with_validation(fn, validator)

    @classmethod
    def finalize(cls, fn: Callable[[Any], Any]):
        def validator(self: "EventsMixin", **kwargs):
            if not cls.initialized:
                raise ValueError("Cannot finalize before initializing")
            if cls.finalized:
                raise ValueError("Cannot finalize twice")
            cls.finalized = True
            cls.initialized = False

        return cls._wrap_with_validation(fn, validator)

    @classmethod
    def global_step(cls, fn: Callable[[Any], Any]):
        def validator(self: "EventsMixin", global_step: Optional[int] = None, **kwargs):
            # configure auto step
            if cls.auto_step is None:
                if global_step is None:
                    logger.info(
                        "No global_step was passed to batch_start event, "
                        "auto-stepping based on batches"
                    )
                    cls.auto_step = True
                else:
                    cls.auto_step = False

            # auto step
            if global_step is None:
                if not cls.auto_step:
                    raise ValueError(
                        "Cannot auto-step batches if global_step was "
                        "previously passed to batch_start event"
                    )
                global_step = self.state.current_index + 1
            else:
                if cls.auto_step:
                    raise ValueError(
                        "Cannot auto-step batches if global_step "
                        "was passed to batch_start event"
                    )

            # validate order
            if global_step <= self.state.current_index:
                raise ValueError("global_step must be greater than the current index")

            self.state.current_index = global_step

        return cls._wrap_with_validation(fn, validator)

    @classmethod
    def event(cls, fn: Callable[[Any], Any]):
        def validator(self: "EventsMixin", event: Event):
            event_type = event.type_

            # ignore unhandled events
            if event_type not in cls.event_order:
                return

            # validate
            if event_type == EventType.BATCH_START:
                valid = cls.last_event_type != EventType.BATCH_START
            else:
                last_event_index = cls.event_order.index(cls.last_event_type)
                curr_event_index = cls.event_order.index(event_type)
                valid = last_event_index <= curr_event_index

            if not valid:
                raise ValueError(
                    f"Lifecycle events must appear in order: {cls.event_order}. "
                    f"Instead, {cls.last_event_type} was called before {event_type}"
                )

            cls.last_event_type = event_type

        return cls._wrap_with_validation(fn, validator)

    @classmethod
    def _wrap_with_validation(
        cls, fn: Callable[[Any], Any], validator: Callable[[Any], Any]
    ) -> Callable:
        @wraps(fn)
        def wrapped(*args, **kwargs):
            validator(*args, **kwargs)
            return fn(*args, **kwargs)

        return wrapped
```
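To make the guard behavior concrete, here is a minimal sketch (not part of the diff; `ToyCompressor` is a hypothetical stand-in that exists only to exercise the validators) showing how the decorators reject out-of-order calls:

```python
# Minimal sketch, assuming this PR's module layout is installed.
from llmcompressor.core.llmcompressor.event_lifecycle import EventsLifecycle


class ToyCompressor:
    @EventsLifecycle.initialize
    def initialize(self):
        print("initialized")

    @EventsLifecycle.finalize
    def finalize(self):
        print("finalized")


toy = ToyCompressor()
toy.initialize()        # first call flips the singleton's `initialized` flag
try:
    toy.initialize()    # out-of-order: initializing twice is rejected
except ValueError as err:
    print(err)          # -> "Cannot initialize twice"
toy.finalize()          # valid: initialize() ran first
```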
`src/llmcompressor/core/llmcompressor/events_mixin.py` (new file, 74 additions; path inferred from the import in `event_lifecycle.py`):

```python
from abc import ABC
from typing import List

import torch

from llmcompressor.core import Event, EventType, State
from llmcompressor.core.llmcompressor.event_lifecycle import EventsLifecycle
from llmcompressor.modifiers import Modifier
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
)


class EventsMixin(ABC):
    state: State
    modifiers: List[Modifier]

    @EventsLifecycle.initialize
    def initialize(self):
        for modifier in self.modifiers:
            modifier.on_initialize(self.state)

    @EventsLifecycle.finalize
    def finalize(self):
        for modifier in self.modifiers:
            modifier.on_finalize(self.state)

        # TODO: log info stating that save_pretrained has been modified
        # TODO: make sure wrapped function can access new recipe and processor
        modify_save_pretrained(self.state.model)

    def update_state(self, **kwargs):
        self.state.update(**kwargs)
        # if future modifiers require update, do that update here

    @EventsLifecycle.global_step
    def batch_start(self, **kwargs):
        # modifiers can only start on batch_start
        for modifier in self.modifiers:
            if modifier.should_start(self.state):
                modifier.on_start(self.state)

        event = Event(type_=EventType.BATCH_START, **kwargs)
        self._handle_event(event)

    def pre_optim(self, **kwargs):
        event = Event(type_=EventType.OPTIM_PRE_STEP, **kwargs)
        self._handle_event(event)

    def post_optim(self, **kwargs):
        event = Event(type_=EventType.OPTIM_POST_STEP, **kwargs)
        self._handle_event(event)

    def update_loss(self, loss: torch.Tensor, **kwargs):
        event = Event(type_=EventType.LOSS_CALCULATED, loss=loss, **kwargs)
        self._handle_event(event)

    def sequential_batch_end(self, **kwargs):
        event = Event(type_=EventType.SEQUENTIAL_BATCH_END, **kwargs)
        self._handle_event(event)

    def batch_end(self, **kwargs):
        # modifiers can only end on batch_end
        for modifier in self.modifiers:
            if modifier.should_end(self.state):
                modifier.on_end(self.state)

        event = Event(type_=EventType.BATCH_END, **kwargs)
        self._handle_event(event)

    @EventsLifecycle.event
    def _handle_event(self, event: Event):
        for modifier in self.modifiers:
            modifier.on_event(self.state, event)
```
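The mixin exposes one hook per lifecycle event, which implies a driver loop shaped roughly like the sketch below. This is hypothetical, not part of the diff: `compressor`, `model`, `dataloader`, and `optimizer` are assumed to exist, and no `global_step` is passed so `EventsLifecycle` auto-steps on batches.

```python
# Hypothetical training loop wiring, following the event order enforced
# by EventsLifecycle: BATCH_START -> LOSS_CALCULATED -> OPTIM_PRE_STEP
# -> OPTIM_POST_STEP -> BATCH_END.
compressor.initialize()
for batch in dataloader:
    compressor.batch_start()      # modifiers may start here; auto-steps the index
    loss = model(**batch).loss
    loss.backward()
    compressor.update_loss(loss)  # emits LOSS_CALCULATED
    compressor.pre_optim()        # emits OPTIM_PRE_STEP
    optimizer.step()
    compressor.post_optim()       # emits OPTIM_POST_STEP
    optimizer.zero_grad()
    compressor.batch_end()        # modifiers may end here
compressor.finalize()
```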
New helper module (25 additions; file path not shown in this view) exposing global accessors for the singleton:

```python
from typing import TYPE_CHECKING

from transformers import PreTrainedModel

if TYPE_CHECKING:
    from llmcompressor.core import State
    from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor


def get_compressor() -> "LLMCompressor":
    from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

    return LLMCompressor.instance()


def get_state() -> "State":
    from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

    return LLMCompressor.instance().state


def get_model() -> PreTrainedModel:
    from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

    return LLMCompressor.instance().state.model
```
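Each accessor imports `LLMCompressor` lazily inside the function body, which avoids a circular import between this helper module and `llmcompressor.py`. Usage would presumably look like the following sketch (imports omitted since the module path is not shown; `instance()` is provided by `SingletonMixin`):

```python
# Hypothetical usage; assumes an LLMCompressor singleton was constructed
# earlier in the process.
compressor = get_compressor()            # the process-wide LLMCompressor
model = get_model()                      # shorthand for get_state().model
assert model is compressor.state.model
```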
`src/llmcompressor/core/llmcompressor/llmcompressor.py` (new file, 60 additions; path inferred from the imports above):

```python
from typing import List, Optional, Union

from torch.utils.data import DataLoader

from llmcompressor.args.model_arguments import ModelArguments
from llmcompressor.core import State
from llmcompressor.core.llmcompressor.events_mixin import EventsMixin
from llmcompressor.core.llmcompressor.train import HFSFTMixin
from llmcompressor.core.llmcompressor.utils import (
    LCDatasetArguments,
    check_for_calibration_data,
    get_modifiers_from_recipe,
    parse_args,
    prepare_models,
    resolve_calibration_pipeline,
)
from llmcompressor.datasets.utils import get_calibration_dataloader
from llmcompressor.modifiers import Modifier
from llmcompressor.pytorch.model_load.helpers import save_checkpoint
from llmcompressor.recipe import RecipeInput
from llmcompressor.typing import DatasetType, ModelInput
from llmcompressor.utils.singleton import SingletonMixin


class LLMCompressor(SingletonMixin, EventsMixin, HFSFTMixin):
    state: State
    modifiers: List[Modifier]
    calibration_loader: Optional[DataLoader] = None

    def __init__(self, model: ModelInput, recipe: RecipeInput, **kwargs):
        model_args = parse_args(ModelArguments, model=model, **kwargs)

        self.modifiers = get_modifiers_from_recipe(recipe)

        model, teacher, processor = prepare_models(model_args)
        self.state = State(model=model, teacher_model=teacher, processor=processor)

    def set_calibration_dataset(self, dataset: Union[str, DatasetType], **kwargs):
        dataset_args = parse_args(LCDatasetArguments, dataset=dataset, **kwargs)

        # temporary hack to support better interface
        if dataset_args.split is not None:
            dataset_args.splits = {"calibration": dataset_args.split}

        self.calibration_loader = get_calibration_dataloader(
            dataset_args, self.state.processor
        )

    def post_train(self, pipeline: Optional[str] = None, save_path: Optional[str] = None):
        check_for_calibration_data(self.modifiers, self.calibration_loader)
        pipeline_fn, pipeline_kwargs = resolve_calibration_pipeline(
            pipeline, self.modifiers
        )

        self.initialize()
        pipeline_fn(self.state.model, self.calibration_loader, **pipeline_kwargs)
        self.finalize()

        if save_path is not None:
            save_checkpoint(save_path, self.state.model, self.state.processor)
```
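Put together, the intended calibration flow looks roughly like this sketch (the model id, recipe path, dataset name, and split are illustrative placeholders, not values from this PR):

```python
from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor

# All identifiers below are placeholders for illustration.
compressor = LLMCompressor(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical model id
    recipe="recipe.yaml",                         # hypothetical recipe path
)
compressor.set_calibration_dataset("ultrachat_200k", split="train_sft")
compressor.post_train(save_path="./compressed-model")  # run pipeline, then save
```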
`src/llmcompressor/core/llmcompressor/train.py` (new file, 62 additions; path inferred from the imports above):

```python
import math
from typing import TYPE_CHECKING, Optional, Union

from llmcompressor.args.training_arguments import TrainingArguments
from llmcompressor.core import State
from llmcompressor.core.llmcompressor.utils import LCDatasetArguments, parse_args
from llmcompressor.datasets.utils import get_processed_dataset
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.typing import DatasetType

if TYPE_CHECKING:
    from transformers.data.data_collator import DataCollator


class HFSFTMixin:
    state: State
    train_dataset: Optional[DatasetType] = None
    train_data_collator: Optional["DataCollator"] = None

    def set_train_dataset(self, dataset: Union[str, DatasetType], **kwargs):
        dataset_args = parse_args(LCDatasetArguments, dataset=dataset, **kwargs)

        processed_dataset = get_processed_dataset(
            dataset_args=dataset_args,
            processor=self.state.processor,
        )
        self.train_dataset = processed_dataset.get("train")

    def train(self, **kwargs):
        raise NotImplementedError(
            "Implementing LLMCompressor.train would require "
            "changes which break existing training pathways"
        )

        training_args = parse_args(TrainingArguments, **kwargs)

        trainer = Trainer(
            model=self.state.model,
            teacher=self.state.teacher_model,
            # recipe=recipe_args.recipe,
            # recipe_args=recipe_args.recipe_args,
            args=training_args,
            # model_args=model_args,
            # dataset_args=dataset_args,
            train_dataset=self.train_dataset,
            processing_class=self.state.processor,
            data_collator=self.train_data_collator,
        )

        # run training
        checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        # save metrics
        metrics = train_result.metrics
        metrics["train_samples"] = len(self.train_dataset)
        metrics["perplexity"] = math.exp(metrics["train_loss"])
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        # save model
        trainer.save_model(output_dir=training_args.output_dir)
```
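Since `train()` raises before any of the `Trainer` code beneath the `raise` can execute, only dataset registration is exercisable in this draft. A hedged sketch of the current behavior (dataset name and split are placeholders):

```python
# Hypothetical usage; train() is intentionally stubbed in this draft.
compressor.set_train_dataset("open_platypus", split="train")
try:
    compressor.train(output_dir="./finetuned")
except NotImplementedError as err:
    print(err)  # explains that train() would break existing training pathways
```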