|
| 1 | +# Checkpoint Loading |
| 2 | + |
| 3 | +The PyTorch backend provides a flexible and extensible infrastructure for loading model checkpoints from different sources and formats, such as HuggingFace (HF) or custom formats, by implementing required components like the checkpoint's weight loader, mapper, and configuration parser. |
| 4 | + |
| 5 | +## Table of Contents |
| 6 | +1. [Overview](#overview) |
| 7 | +2. [Core Components](#core-components) |
| 8 | +3. [Built-in Checkpoint Formats](#built-in-checkpoint-formats) |
| 9 | +4. [Using Checkpoint Loaders](#using-checkpoint-loaders) |
| 10 | +5. [Creating Custom Checkpoint Loaders](#creating-custom-checkpoint-loaders) |
| 11 | + |
| 12 | +## Overview |
| 13 | + |
| 14 | +The checkpoint loading design is built around a plugin-like architecture that is separated into four distinct components: |
| 15 | + |
| 16 | +- **Checkpoint Loaders**: Orchestrates the loading process for specific formats. |
| 17 | +- **Config Loaders**: Handles model configuration parsing and validation. |
| 18 | +- **Weight Loaders**: Manages the actual loading of model weights from storage into memory. |
| 19 | +- **Weight Mappers**: Maps and transforms loaded weights to the TRTLLM model's definition. |
| 20 | + |
| 21 | +This modular design allows for easy extension to support new checkpoint formats while maintaining backward compatibility and performance optimizations. By separating checkpoint loading into four subcomponents, users can leverage existing implementations and introduce custom, checkpoint-specific components. |
| 22 | + |
| 23 | +To support a new checkpoint format, you must implement all four components. |
| 24 | +If the format shares components with an existing framework (such as HF), you only need to implement the components that differ. |
| 25 | + |
| 26 | +## Core Components |
| 27 | + |
| 28 | +### BaseCheckpointLoader |
| 29 | + |
| 30 | +The `BaseCheckpointLoader` is the central interface for all checkpoint loading operations. It provides a unified API regardless of the underlying checkpoint format. This interface is responsible for holding and exposing all objects required for the loading and parsing process. |
| 31 | + |
| 32 | +**Key Methods:** |
| 33 | +- `load_config(checkpoint_dir, **kwargs)`: Loads and returns a `ModelConfig` object |
| 34 | +- `load_weights(checkpoint_dir, **kwargs)`: Loads and returns a dictionary of weights |
| 35 | +- `get_initialized_weight_mapper(model, config)`: Returns a weight mapper initialized at runtime for the model |
| 36 | +- `cleanup()`: Releases resources and cleans up internal state |
| 37 | + |
| 38 | +### BaseConfigLoader |
| 39 | + |
| 40 | +Loads model configurations from checkpoint directories and parses them into a TRTLLM `ModelConfig`: |
| 41 | + |
| 42 | +```python |
| 43 | +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader |
| 44 | + |
| 45 | +class CustomConfigLoader(BaseConfigLoader): |
| 46 | + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: |
| 47 | + # Load and parse configuration from your custom format |
| 48 | + pretrained_config = self._get_pretrained_config(checkpoint_dir, **kwargs) |
| 49 | + |
| 50 | + return ModelConfig(pretrained_config=pretrained_config, |
| 51 | + ...) |
| 52 | + |
| 53 | + def _get_pretrained_config(self, checkpoint_dir, **kwargs): |
| 54 | + ... |
| 55 | + |
| 56 | +``` |
| 57 | + |
| 58 | +### BaseWeightLoader |
| 59 | + |
| 60 | +Handles the loading of model weights from storage: |
| 61 | + |
| 62 | +```python |
| 63 | +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader |
| 64 | + |
| 65 | +class CustomWeightLoader(BaseWeightLoader): |
| 66 | + def load_weights(self, checkpoint_dir: str) -> dict[str, Any]: |
| 67 | + # Load weights from your custom format |
| 68 | + # Return a dictionary mapping parameter names to tensors |
| 69 | + return weights_dict |
| 70 | +``` |
| 71 | + |
| 72 | +### BaseWeightMapper |
| 73 | + |
| 74 | +Transforms weights between different naming conventions and applies model-specific transformations to the TRTLLM model object. |
| 75 | + |
| 76 | +## Built-in Checkpoint Formats |
| 77 | + |
| 78 | +### HuggingFace Format |
| 79 | + |
| 80 | +Currently, the HF checkpoint loader is the primary built-in format and supports: |
| 81 | + |
| 82 | +- **Weights loading** (`.safetensors, .bin, .pth`): Load HF-compatible weights from disk |
| 83 | +- **Configuration parser** - Parse configuration information stored by HF into a TRTLLM `ModelConfig` object |
| 84 | +- **Weights Mapping** - Convert HF weights into a TRTLLM-compatible representation |
| 85 | + |
| 86 | +## Using Checkpoint Loaders |
| 87 | + |
| 88 | +### Basic Usage |
| 89 | + |
| 90 | +There are two main approaches for using checkpoint loading objects |
| 91 | + |
| 92 | +The first approach is through the llm-api, as shown in the following example: |
| 93 | + |
| 94 | +```python |
| 95 | +from tensorrt_llm import LLM |
| 96 | + |
| 97 | +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" |
| 98 | + |
| 99 | +llm = LLM(model=hf_model_dir) |
| 100 | +``` |
| 101 | + |
| 102 | +In this example, the `HfCheckpointLoader` is selected by default. |
| 103 | + |
| 104 | +To explicitly set the checkpoint loader, specify the required checkpoint-specific loader: |
| 105 | + |
| 106 | +```python |
| 107 | +from tensorrt_llm import LLM |
| 108 | +from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader |
| 109 | + |
| 110 | +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" |
| 111 | + |
| 112 | +llm = LLM(model=hf_model_dir, |
| 113 | + checkpoint_loader=HfCheckpointLoader()) |
| 114 | +``` |
| 115 | + |
| 116 | +Similarly, to use a basic checkpoint loader with a specific subcomponent, provide the desired subcomponent as needed: |
| 117 | + |
| 118 | +```python |
| 119 | +from tensorrt_llm import LLM |
| 120 | +from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader |
| 121 | + |
| 122 | +hf_model_dir = "llama-models-v2/llama-v2-13b-hf" |
| 123 | + |
| 124 | +llm = LLM(model=hf_model_dir, |
| 125 | + checkpoint_loader=HfCheckpointLoader(weight_loader=MyCustomWeightLoader())) |
| 126 | +``` |
| 127 | + |
| 128 | +In the second approach, you can directly use the individual checkpoint loading components: |
| 129 | + |
| 130 | +```python |
| 131 | +from tensorrt_llm._torch.models.checkpoints.hf.gemma3_weight_mapper import \ |
| 132 | + Gemma3HfWeightMapper |
| 133 | +from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM |
| 134 | + |
| 135 | +gemma3 = Gemma3ForCausalLM(model_config) |
| 136 | +weight_mapper = Gemma3HfWeightMapper() |
| 137 | +weight_mapper.init_model_and_config(gemma3, model_config) |
| 138 | +gemma3.load_weights(hf_gemma3.state_dict(), weight_mapper) |
| 139 | +``` |
| 140 | +## Creating Custom Checkpoint Loaders |
| 141 | + |
| 142 | +To support a new checkpoint format, implement all four components. This section provides minimal templates for each. |
| 143 | + |
| 144 | +### When to Create Custom Components |
| 145 | + |
| 146 | +- **Complete New Format**: Implement all four components to support a new checkpoint format |
| 147 | +- **Custom Weight Storage**: Implement only a custom weight loader if you have a unique weight storage format (such as a custom binary format or database storage) |
| 148 | +- **Custom Configuration**: Implement only a custom config loader if your configuration format cannot be parsed by existing loaders |
| 149 | +- **Custom Weight Mapping**: Implement only a custom weight mapper if your model has unique weight naming or transformation requirements that are checkpoint-specific |
| 150 | + |
| 151 | +### Step 1: Create the Checkpoint Loader |
| 152 | + |
| 153 | +```python |
| 154 | +from typing import Optional |
| 155 | +from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader |
| 156 | +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader |
| 157 | +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader |
| 158 | +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper |
| 159 | +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader |
| 160 | + |
| 161 | +@register_checkpoint_loader("CUSTOM_FORMAT") |
| 162 | +class CustomCheckpointLoader(BaseCheckpointLoader): |
| 163 | + def __init__(self, |
| 164 | + *, |
| 165 | + weight_loader: Optional[BaseWeightLoader] = None, |
| 166 | + weight_mapper: Optional[BaseWeightMapper] = None, |
| 167 | + config_loader: Optional[BaseConfigLoader] = None): |
| 168 | + self._weight_loader = weight_loader or self.get_default_weight_loader() |
| 169 | + self._config_loader = config_loader or self.get_default_config_loader() |
| 170 | + self._weight_mapper = weight_mapper |
| 171 | + self._checkpoint_format = "CUSTOM_FORMAT" # Set the checkpoint format name |
| 172 | + |
| 173 | + def get_default_weight_loader(self) -> BaseWeightLoader: |
| 174 | + return CustomWeightLoader() |
| 175 | + |
| 176 | + def get_default_config_loader(self) -> BaseConfigLoader: |
| 177 | + return CustomConfigLoader() |
| 178 | +``` |
| 179 | + |
| 180 | +### Step 2: Create the Checkpoint Weight Loader |
| 181 | + |
| 182 | +```python |
| 183 | +from typing import Any |
| 184 | +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader |
| 185 | +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader |
| 186 | + |
| 187 | +@register_checkpoint_weight_loader("CUSTOM_FORMAT") |
| 188 | +class CustomWeightLoader(BaseWeightLoader): |
| 189 | + def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: |
| 190 | + """ |
| 191 | + Load weights from your custom format. |
| 192 | +
|
| 193 | + Args: |
| 194 | + checkpoint_dir: Directory containing checkpoint files |
| 195 | + **kwargs: Additional loading parameters |
| 196 | +
|
| 197 | + Returns: |
| 198 | + Dictionary mapping parameter names to tensors |
| 199 | + """ |
| 200 | + weights = {} # Implement your custom weight loading logic here |
| 201 | + |
| 202 | + # Examples: |
| 203 | + # - Load from custom binary files |
| 204 | + # - Load from databases |
| 205 | + # - Load from compressed archives |
| 206 | + # - Apply custom preprocessing |
| 207 | + |
| 208 | + return weights |
| 209 | +``` |
| 210 | + |
| 211 | +### Step 3: Create the Checkpoint Config Loader |
| 212 | + |
| 213 | +```python |
| 214 | +from tensorrt_llm._torch.model_config import ModelConfig |
| 215 | +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader |
| 216 | +from tensorrt_llm._torch.models.modeling_utils import register_config_loader |
| 217 | + |
| 218 | +@register_config_loader("CUSTOM_FORMAT") |
| 219 | +class CustomConfigLoader(BaseConfigLoader): |
| 220 | + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: |
| 221 | + """ |
| 222 | + Load and parse configuration from your custom format. |
| 223 | +
|
| 224 | + Args: |
| 225 | + checkpoint_dir: Directory containing configuration files |
| 226 | + **kwargs: Additional loading parameters |
| 227 | +
|
| 228 | + Returns: |
| 229 | + ModelConfig object containing parsed configuration |
| 230 | + """ |
| 231 | + # Load your custom configuration format here |
| 232 | + # Examples: |
| 233 | + # - Parse YAML/TOML files |
| 234 | + # - Convert from proprietary formats |
| 235 | + |
| 236 | + pretrained_config = self._load_pretrained_config(checkpoint_dir, **kwargs) |
| 237 | + |
| 238 | + return ModelConfig( |
| 239 | + pretrained_config=pretrained_config, |
| 240 | + # Add other ModelConfig parameters as needed |
| 241 | + ) |
| 242 | + |
| 243 | + def _load_pretrained_config(self, checkpoint_dir: str, **kwargs): |
| 244 | + """Load the raw configuration from your custom format.""" |
| 245 | + # Implement as needed |
| 246 | + pass |
| 247 | +``` |
| 248 | + |
| 249 | +### Step 4: Create the Checkpoint Weight Mapper |
| 250 | + |
| 251 | +```python |
| 252 | +from torch import nn |
| 253 | +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper |
| 254 | +from tensorrt_llm._torch.models.modeling_utils import register_mapper |
| 255 | + |
| 256 | +@register_mapper("CUSTOM_FORMAT") |
| 257 | +class CustomWeightMapper(BaseWeightMapper): |
| 258 | + def __init__(self): |
| 259 | + super().__init__() |
| 260 | + # Define any weight transformation callbacks |
| 261 | + self._callbacks = [ |
| 262 | + # Add your custom weight transformation functions |
| 263 | + # self._custom_transform_function, |
| 264 | + ] |
| 265 | + |
| 266 | + def map_weights(self) -> None: |
| 267 | + """ |
| 268 | + Define mappings between source and target weight names. |
| 269 | + """ |
| 270 | + self.mapping.update({ |
| 271 | + # Map source names to target names |
| 272 | + # 'target_module_name': ['source_param1', 'source_param2'], |
| 273 | + # For example: 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'] |
| 274 | + }) |
| 275 | + |
| 276 | + def apply_callbacks(self, module: nn.Module, module_name: str, |
| 277 | + module_names_breakdown: list[str], |
| 278 | + weights: dict) -> list[dict]: |
| 279 | + """ |
| 280 | + Apply weight transformations for modules that require special handling. |
| 281 | +
|
| 282 | + Args: |
| 283 | + module: The target module |
| 284 | + module_name: The specific module name being processed |
| 285 | + module_names_breakdown: Module path components |
| 286 | + weights: Source weights dictionary |
| 287 | +
|
| 288 | + Returns: |
| 289 | + List of transformed weight dictionaries |
| 290 | + """ |
| 291 | + module_weights = [] |
| 292 | + |
| 293 | + for new_name in self._mapping[module_name]: |
| 294 | + # Filter weights for this specific parameter |
| 295 | + fw = self.filter_weights( |
| 296 | + '.'.join(module_names_breakdown + [new_name]), weights) |
| 297 | + |
| 298 | + # Apply transformation callbacks |
| 299 | + for callback in self._callbacks: |
| 300 | + fw = callback(module, new_name, fw) |
| 301 | + |
| 302 | + module_weights.append(fw) |
| 303 | + |
| 304 | + return module_weights |
| 305 | + |
| 306 | + def should_skip_module(self, module_name: str) -> bool: |
| 307 | + """ |
| 308 | + Define which modules should be skipped during loading. |
| 309 | + """ |
| 310 | + # Add logic to skip specific modules based on your requirements |
| 311 | + # Examples: |
| 312 | + # - Skip LoRA-specific modules |
| 313 | + # - Skip temporary/auxiliary modules |
| 314 | + |
| 315 | + return super().should_skip_module(module_name) |
| 316 | +``` |
| 317 | + |
| 318 | +Note: When creating a custom mapper, you can define either a checkpoint-format-specific mapper. For example: |
| 319 | + |
| 320 | +```python |
| 321 | +@register_mapper("CUSTOM_FORMAT") |
| 322 | +class CustomWeightMapper(BaseWeightMapper) |
| 323 | +``` |
| 324 | + |
| 325 | +Alternatively, you can define a checkpoint-model-specific mapper. For example: |
| 326 | + |
| 327 | +```python |
| 328 | +@register_mapper("CUSTOM_FORMAT", "Gemma3ForCausalLM") |
| 329 | +class CustomWeightMapper(BaseWeightMapper) |
| 330 | +``` |
| 331 | + |
| 332 | +By setting the model name, the registered mapper will be associated with the specific model. |
0 commit comments