[WIP] [Design] LLMCompressor Class #1256


Draft · wants to merge 45 commits into main
Conversation

@kylesayrs (Collaborator) commented Mar 15, 2025

LLMCompressor Class

from llmcompressor.core.llmcompressor.llmcompressor import LLMCompressor
from llmcompressor.modifiers.quantization.gptq import GPTQModifier
from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier

model_id = "meta-llama/Llama-3.1-8B-Instruct"
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])
]

compressor = LLMCompressor(model_id, recipe)
compressor.set_calibration_dataset("ultrachat_200k", split="train_sft[:512]")
compressor.post_train(save_path="save_path")

Status

  • All core functionality has been implemented except recipe args
  • All functionality needs rigorous testing and regression evaluation

Purpose

The primary purpose of this design is to simplify the core logic of LLM Compressor.

Simplified Features

  • Separate functions for model loading, recipe loading, dataset loading, and oneshot/train
  • Variables are attached directly to the LLMCompressor class instance, rather than being hidden behind session functions
  • Recipes are handled by one function that returns a list of modifiers, rather than passing through multiple layers of recipe abstractions (see the sketch below)
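
As a rough illustration, a minimal sketch of what that single recipe-resolution function could look like; the function name, import path, and accepted input types are assumptions for illustration, not the PR's actual implementation.

from typing import Iterable, List, Union

from llmcompressor.modifiers import Modifier  # import path assumed

RecipeInput = Union[Modifier, Iterable[Modifier]]

def resolve_recipe(recipe: RecipeInput) -> List[Modifier]:
    """Flatten any accepted recipe input into a flat list of modifiers."""
    if isinstance(recipe, Modifier):
        return [recipe]
    # a string/YAML recipe would be parsed into modifier objects here
    return list(recipe)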

Maintained Functionality

  • PTQ and Training integration interface
    • Global access to the compressor through which the integrator triggers events
    • Global steps can either be passed explicitly by the integrator or handled by EventLifeCycle auto-stepping
  • Event lifecycle validation
    • Now handled by EventLifeCycle, which implements minimally invasive decorators
    • EventLifeCycle handles auto-stepping and event-order validation (see the sketch after this list)
  • Dataset processing
    • Calibration datasets and training datasets are decoupled
  • SFT Training pathway
    • Implemented through a training mixin
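
Below is a minimal sketch of how a decorator-based EventLifeCycle could enforce event ordering and auto-step a global step counter; all names and the exact validation rules are assumptions for illustration, not this PR's implementation.

import functools

class EventLifeCycle:
    """Sketch: validate batch_start/batch_end ordering and auto-step a counter."""

    def __init__(self):
        self.in_batch = False
        self.global_step = 0

    def event(self, name):
        """Decorator factory that validates ordering before calling the wrapped hook."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, global_step=None, **kwargs):
                if name == "batch_start":
                    if self.in_batch:
                        raise RuntimeError("batch_start called twice without batch_end")
                    self.in_batch = True
                    # auto-step if the integrator did not pass a global step
                    self.global_step = global_step if global_step is not None else self.global_step + 1
                elif name == "batch_end":
                    if not self.in_batch:
                        raise RuntimeError("batch_end called before batch_start")
                    self.in_batch = False
                return fn(*args, **kwargs)
            return wrapper
        return decorator

In this sketch the lifecycle object wraps the hooks wherever the compressor constructs them (e.g. self.batch_start = lifecycle.event("batch_start")(self.batch_start)), so the hook bodies themselves stay free of validation logic.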

Removed Classes/Abstractions

  • Recipe classes
  • StageModifiers
  • Stage Runner
  • LifecycleCallbacks
  • Session/Session Globals
  • Lifecycle
  • Event class is greatly simplified
  • ModifiedState

Questions

  • Is there any recipe metadata outside of modifiers worth saving/recording?
  • What is required in order to support recipe args?
  • Is there any case where a modifier would want to start/stop on events besides batch_start/end?

Integration Examples

PTQ

compressor = LLMCompressor(model, recipe)

global_step = 0
compressor.initialize()
for batch in calibration_data:
    # trigger modifier hooks around each calibration forward pass
    compressor.batch_start(global_step=global_step)
    outputs = model(**batch)
    compressor.batch_end()
    global_step += 1

compressor.finalize()
model.save_pretrained(...)

Training

compressor = LLMCompressor(model, recipe)

compressor.initialize()
for epoch in range(num_epochs):
    for batch in training_data:
        compressor.batch_start(global_step=epoch)
        outputs = model(**batch)

        loss = loss_fn(labels, outputs)
        # modifiers may transform the loss (e.g. add a distillation term)
        loss = compressor.update_loss(loss)
        loss.backward()

        compressor.pre_optim()
        optimizer.step()
        compressor.post_optim()

        compressor.batch_end()

    if save_checkpoint:
        model.save_pretrained(...)

compressor.finalize()
model.save_pretrained(...)
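
For context on update_loss, here is a hedged sketch of the kind of hook a distillation-style modifier might expose for the compressor to dispatch to; the class name, hook signature, and extra arguments are assumptions, not part of this PR.

import torch

class DistillationLossHookSketch:
    """Illustrative only: blend the task loss with a KL term against teacher logits."""

    def __init__(self, teacher, alpha=0.5, temperature=2.0):
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature

    def update_loss(self, loss, batch, student_logits):
        # hypothetical extra arguments; the PR example only passes the loss itself
        with torch.no_grad():
            teacher_logits = self.teacher(**batch).logits
        t = self.temperature
        kd = torch.nn.functional.kl_div(
            torch.log_softmax(student_logits / t, dim=-1),
            torch.softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
        ) * (t * t)
        return (1 - self.alpha) * loss + self.alpha * kd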

Future Extensions

Teacher/Delayed-State Training Integration

compressor = LLMCompressor(model, recipe)

global_step = 0
compressor.initialize()
compressor.update_state(teacher=teacher)
for epoch in range(num_epochs):
    for batch in training_data:
        ...


Multi-round PTQ

Add a finalized_modifiers attribute: when modifiers finalize, move them from the modifiers list to the finalized_modifiers list (a minimal sketch follows the example below).

compressor = LLMCompressor(model, pruning_recipe)
# round 1: sparsification
compressor.set_calibration_dataset(dataset_one)
compressor.compress(calibration_pipeline="basic")

# round 2: quantization
compressor.append_recipe(quantization_recipe)
compressor.set_calibration_dataset(dataset_two)
compressor.compress(calibration_pipeline="sequential")

model.save_pretrained(...)
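
A minimal sketch of the finalized_modifiers bookkeeping described above, with assumed attribute and hook names; it only shows the list management, not the full compression flow.

class MultiRoundSketch:
    """Illustrative only: retire finalized modifiers so new recipes can be appended."""

    def __init__(self, modifiers):
        self.modifiers = list(modifiers)
        self.finalized_modifiers = []

    def append_recipe(self, recipe):
        # modifiers from a newly appended recipe join the active list
        self.modifiers.extend(recipe)

    def finalize(self, state):
        for modifier in self.modifiers:
            modifier.finalize(state)  # hook name assumed
        # move everything that just finalized out of the active list
        self.finalized_modifiers.extend(self.modifiers)
        self.modifiers = []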

Recipe-Tailored Custom Device Map

compressor = LLMCompressor(model_stub, recipe, device_map="auto")
compressor.set_calibration_dataset(dataset)
compressor.post_train()

compressor.model.save_pretrained(...)

Alternating Oneshot/SFT

This relies on the same finalized_modifiers mechanism sketched under Multi-round PTQ above: when modifiers finalize, they are moved from the modifiers list to the finalized_modifiers list.

compressor = LLMCompressor(model, training_recipe)
# round 1: training
compressor.set_train_dataset(dataset_one)
compressor.train(**training_kwargs)

# round 2: compression
compressor.append_recipe(quantization_recipe)
compressor.set_calibration_dataset(dataset_two)
compressor.post_train()

model.save_pretrained(...)
