Commit b6baa9e
[TRTLLM-6823][doc] Add checkpoint refactor docs (#6592)
Signed-off-by: Shahar Mor <[email protected]>
1 parent 4142320 commit b6baa9e

3 files changed: +335 -3 lines changed

New documentation file: 332 additions, 0 deletions
# Checkpoint Loading

The PyTorch backend provides a flexible and extensible infrastructure for loading model checkpoints from different sources and formats, such as HuggingFace (HF) or custom formats. A new format is supported by implementing the required components: the checkpoint's weight loader, weight mapper, and configuration parser.

## Table of Contents

1. [Overview](#overview)
2. [Core Components](#core-components)
3. [Built-in Checkpoint Formats](#built-in-checkpoint-formats)
4. [Using Checkpoint Loaders](#using-checkpoint-loaders)
5. [Creating Custom Checkpoint Loaders](#creating-custom-checkpoint-loaders)
## Overview

The checkpoint loading design is built around a plugin-like architecture separated into four distinct components:

- **Checkpoint Loaders**: Orchestrate the loading process for a specific checkpoint format.
- **Config Loaders**: Handle model configuration parsing and validation.
- **Weight Loaders**: Manage the actual loading of model weights from storage into memory.
- **Weight Mappers**: Map and transform loaded weights to match the TRTLLM model's definition.

This modular design makes it easy to add support for new checkpoint formats while maintaining backward compatibility and performance optimizations. Because checkpoint loading is split into four subcomponents, users can reuse existing implementations and introduce only the custom, checkpoint-specific components they need.

To support a new checkpoint format, you must implement all four components. If the format shares components with an existing framework (such as HF), you only need to implement the components that differ.
## Core Components

### BaseCheckpointLoader

The `BaseCheckpointLoader` is the central interface for all checkpoint loading operations. It provides a unified API regardless of the underlying checkpoint format. This interface is responsible for holding and exposing all objects required for the loading and parsing process.

**Key Methods:**

- `load_config(checkpoint_dir, **kwargs)`: Loads and returns a `ModelConfig` object
- `load_weights(checkpoint_dir, **kwargs)`: Loads and returns a dictionary of weights
- `get_initialized_weight_mapper(model, config)`: Returns a weight mapper initialized at runtime for the model
- `cleanup()`: Releases resources and cleans up internal state
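For example, a checkpoint loader can be used on its own to inspect a checkpoint without constructing an `LLM` instance. The following is a minimal sketch using the built-in `HfCheckpointLoader`; the checkpoint path is illustrative, and the exact keyword arguments accepted by each method may vary.

```python
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader

loader = HfCheckpointLoader()

# Parse the HF configuration into a TRTLLM ModelConfig
model_config = loader.load_config("llama-models-v2/llama-v2-13b-hf")

# Load the checkpoint weights as a {parameter name: tensor} dictionary
weights = loader.load_weights("llama-models-v2/llama-v2-13b-hf")

# Release any resources held by the loader
loader.cleanup()
```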
### BaseConfigLoader

Loads model configurations from checkpoint directories and parses them into a TRTLLM `ModelConfig`:

```python
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader

class CustomConfigLoader(BaseConfigLoader):
    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        # Load and parse configuration from your custom format
        pretrained_config = self._get_pretrained_config(checkpoint_dir, **kwargs)

        return ModelConfig(pretrained_config=pretrained_config,
                           ...)

    def _get_pretrained_config(self, checkpoint_dir, **kwargs):
        ...
```
### BaseWeightLoader

Handles the loading of model weights from storage:

```python
from typing import Any

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader

class CustomWeightLoader(BaseWeightLoader):
    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
        # Load weights from your custom format and return a dictionary
        # mapping parameter names to tensors
        weights_dict = {}
        ...
        return weights_dict
```

### BaseWeightMapper

Transforms weights between different naming conventions and applies model-specific transformations to the TRTLLM model object.
## Built-in Checkpoint Formats

### HuggingFace Format

Currently, the HF checkpoint loader is the primary built-in format and supports:

- **Weights loading** (`.safetensors`, `.bin`, `.pth`): Load HF-compatible weights from disk
- **Configuration parsing**: Parse configuration information stored by HF into a TRTLLM `ModelConfig` object
- **Weights mapping**: Convert HF weights into a TRTLLM-compatible representation
## Using Checkpoint Loaders

### Basic Usage

There are two main approaches to using the checkpoint loading objects.

The first approach is through the LLM API, as shown in the following example:

```python
from tensorrt_llm import LLM

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"

llm = LLM(model=hf_model_dir)
```

In this example, the `HfCheckpointLoader` is selected by default.
To explicitly set the checkpoint loader, pass the required checkpoint-specific loader:

```python
from tensorrt_llm import LLM
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"

llm = LLM(model=hf_model_dir,
          checkpoint_loader=HfCheckpointLoader())
```

Similarly, to use a built-in checkpoint loader with a specific custom subcomponent, provide the desired subcomponent as needed:

```python
from tensorrt_llm import LLM
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader

hf_model_dir = "llama-models-v2/llama-v2-13b-hf"

# MyCustomWeightLoader is a user-defined BaseWeightLoader subclass
llm = LLM(model=hf_model_dir,
          checkpoint_loader=HfCheckpointLoader(weight_loader=MyCustomWeightLoader()))
```

In the second approach, you can use the individual checkpoint loading components directly:

```python
from tensorrt_llm._torch.models.checkpoints.hf.gemma3_weight_mapper import \
    Gemma3HfWeightMapper
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3ForCausalLM

gemma3 = Gemma3ForCausalLM(model_config)
weight_mapper = Gemma3HfWeightMapper()
weight_mapper.init_model_and_config(gemma3, model_config)
# hf_gemma3 is the source model (for example, the HF Gemma 3 checkpoint loaded
# as a torch module) whose state_dict supplies the weights to load
gemma3.load_weights(hf_gemma3.state_dict(), weight_mapper)
```

## Creating Custom Checkpoint Loaders

To support a new checkpoint format, implement all four components. This section provides minimal templates for each.

### When to Create Custom Components

- **Complete New Format**: Implement all four components to support a new checkpoint format
- **Custom Weight Storage**: Implement only a custom weight loader if you have a unique weight storage format (such as a custom binary format or database storage)
- **Custom Configuration**: Implement only a custom config loader if your configuration format cannot be parsed by existing loaders
- **Custom Weight Mapping**: Implement only a custom weight mapper if your model has unique weight naming or transformation requirements that are checkpoint-specific
### Step 1: Create the Checkpoint Loader

```python
from typing import Optional

from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader

@register_checkpoint_loader("CUSTOM_FORMAT")
class CustomCheckpointLoader(BaseCheckpointLoader):
    def __init__(self,
                 *,
                 weight_loader: Optional[BaseWeightLoader] = None,
                 weight_mapper: Optional[BaseWeightMapper] = None,
                 config_loader: Optional[BaseConfigLoader] = None):
        self._weight_loader = weight_loader or self.get_default_weight_loader()
        self._config_loader = config_loader or self.get_default_config_loader()
        self._weight_mapper = weight_mapper
        self._checkpoint_format = "CUSTOM_FORMAT"  # Set the checkpoint format name

    def get_default_weight_loader(self) -> BaseWeightLoader:
        return CustomWeightLoader()

    def get_default_config_loader(self) -> BaseConfigLoader:
        return CustomConfigLoader()
```
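Once registered, the custom loader can be passed to the `LLM` constructor in the same way as the built-in `HfCheckpointLoader` shown earlier. A minimal sketch, assuming the `CustomCheckpointLoader` above is in scope and with an illustrative checkpoint path:

```python
from tensorrt_llm import LLM

llm = LLM(model="/path/to/custom/checkpoint",
          checkpoint_loader=CustomCheckpointLoader())
```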
### Step 2: Create the Checkpoint Weight Loader

```python
from typing import Any

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_weight_loader

@register_checkpoint_weight_loader("CUSTOM_FORMAT")
class CustomWeightLoader(BaseWeightLoader):
    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        """
        Load weights from your custom format.

        Args:
            checkpoint_dir: Directory containing checkpoint files
            **kwargs: Additional loading parameters

        Returns:
            Dictionary mapping parameter names to tensors
        """
        weights = {}  # Implement your custom weight loading logic here

        # Examples:
        # - Load from custom binary files
        # - Load from databases
        # - Load from compressed archives
        # - Apply custom preprocessing

        return weights
```
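As a concrete illustration, the sketch below fills in `load_weights` for a directory of plain PyTorch `.pth` shards. This is not part of the TRT-LLM API, only one possible implementation of the hook; the shard layout and file naming are assumptions.

```python
import glob
import os
from typing import Any

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader


class PthShardWeightLoader(BaseWeightLoader):
    """Hypothetical loader for checkpoints stored as one or more .pth shards."""

    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        weights: dict[str, Any] = {}
        for shard in sorted(glob.glob(os.path.join(checkpoint_dir, "*.pth"))):
            # Each shard is assumed to hold a {parameter name: tensor} state dict
            weights.update(torch.load(shard, map_location="cpu"))
        return weights
```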
### Step 3: Create the Checkpoint Config Loader

```python
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_config_loader

@register_config_loader("CUSTOM_FORMAT")
class CustomConfigLoader(BaseConfigLoader):
    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        """
        Load and parse configuration from your custom format.

        Args:
            checkpoint_dir: Directory containing configuration files
            **kwargs: Additional loading parameters

        Returns:
            ModelConfig object containing parsed configuration
        """
        # Load your custom configuration format here
        # Examples:
        # - Parse YAML/TOML files
        # - Convert from proprietary formats

        pretrained_config = self._load_pretrained_config(checkpoint_dir, **kwargs)

        return ModelConfig(
            pretrained_config=pretrained_config,
            # Add other ModelConfig parameters as needed
        )

    def _load_pretrained_config(self, checkpoint_dir: str, **kwargs):
        """Load the raw configuration from your custom format."""
        # Implement as needed
        pass
```
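For illustration, the sketch below parses a hypothetical `model_config.yaml` into an HF-style `PretrainedConfig` and wraps it in a `ModelConfig`. The YAML file name, its fields, and the use of `transformers.PretrainedConfig` as the `pretrained_config` object are assumptions for this example, not requirements of the API.

```python
import os

import yaml
from transformers import PretrainedConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader


class YamlConfigLoader(BaseConfigLoader):
    """Hypothetical config loader for a YAML-based checkpoint format."""

    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        pretrained_config = self._load_pretrained_config(checkpoint_dir)
        return ModelConfig(pretrained_config=pretrained_config)

    def _load_pretrained_config(self, checkpoint_dir: str) -> PretrainedConfig:
        with open(os.path.join(checkpoint_dir, "model_config.yaml")) as f:
            # e.g. {"hidden_size": 4096, "num_hidden_layers": 32, ...}
            raw = yaml.safe_load(f)
        # PretrainedConfig stores arbitrary keyword arguments as attributes
        return PretrainedConfig(**raw)
```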
### Step 4: Create the Checkpoint Weight Mapper

```python
from torch import nn

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper

@register_mapper("CUSTOM_FORMAT")
class CustomWeightMapper(BaseWeightMapper):
    def __init__(self):
        super().__init__()
        # Define any weight transformation callbacks
        self._callbacks = [
            # Add your custom weight transformation functions
            # self._custom_transform_function,
        ]

    def map_weights(self) -> None:
        """
        Define mappings between source and target weight names.
        """
        self.mapping.update({
            # Map source names to target names
            # 'target_module_name': ['source_param1', 'source_param2'],
            # For example: 'qkv_proj': ['q_proj', 'k_proj', 'v_proj']
        })

    def apply_callbacks(self, module: nn.Module, module_name: str,
                        module_names_breakdown: list[str],
                        weights: dict) -> list[dict]:
        """
        Apply weight transformations for modules that require special handling.

        Args:
            module: The target module
            module_name: The specific module name being processed
            module_names_breakdown: Module path components
            weights: Source weights dictionary

        Returns:
            List of transformed weight dictionaries
        """
        module_weights = []

        for new_name in self._mapping[module_name]:
            # Filter weights for this specific parameter
            fw = self.filter_weights(
                '.'.join(module_names_breakdown + [new_name]), weights)

            # Apply transformation callbacks
            for callback in self._callbacks:
                fw = callback(module, new_name, fw)

            module_weights.append(fw)

        return module_weights

    def should_skip_module(self, module_name: str) -> bool:
        """
        Define which modules should be skipped during loading.
        """
        # Add logic to skip specific modules based on your requirements
        # Examples:
        # - Skip LoRA-specific modules
        # - Skip temporary/auxiliary modules

        return super().should_skip_module(module_name)
```
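The callback hook in the template above is easiest to see with a concrete example. Given the call pattern `fw = callback(module, new_name, fw)`, a callback receives the target module, the parameter name being processed, and the filtered weights dictionary, and returns a (possibly transformed) dictionary. The transformation below is purely hypothetical and only illustrates the expected shape of such a callback; register it in `self._callbacks` as a plain function or bound method.

```python
import torch
from torch import nn


def transpose_2d_weights(module: nn.Module, name: str,
                         weights: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical callback: convert 2-D weights stored column-major in a
    custom format back to the row-major layout the target module expects."""
    return {
        key: value.t().contiguous()
        if key.endswith("weight") and value.dim() == 2 else value
        for key, value in weights.items()
    }
```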
Note: When creating a custom mapper, you can register it as a checkpoint-format-specific mapper:

```python
@register_mapper("CUSTOM_FORMAT")
class CustomWeightMapper(BaseWeightMapper):
    ...
```

Alternatively, you can register it as a checkpoint- and model-specific mapper:

```python
@register_mapper("CUSTOM_FORMAT", "Gemma3ForCausalLM")
class CustomWeightMapper(BaseWeightMapper):
    ...
```

By setting the model name, the registered mapper is associated with that specific model.

tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py (2 additions, 2 deletions)

```diff
@@ -67,8 +67,8 @@ def get(cls, checkpoint_format: str, **kwargs) -> "BaseCheckpointLoader":
             f"available formats are: {CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING.keys()}"
         )
 
-    def get_initilized_weight_mapper(self, model: nn.Module,
-                                     config: ModelConfig) -> BaseWeightMapper:
+    def get_initialized_weight_mapper(self, model: nn.Module,
+                                      config: ModelConfig) -> BaseWeightMapper:
         weight_mapper = None
         if self.weight_mapper is not None:
             self.weight_mapper.init_model_and_config(model, config)
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (1 addition, 1 deletion)

```diff
@@ -1079,7 +1079,7 @@ def init_meta_tensor(t: torch.Tensor):
         else:
             weights = checkpoint_loader.load_weights(checkpoint_dir)
 
-        weight_mapper = checkpoint_loader.get_initilized_weight_mapper(
+        weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
             model, config)
         self._call_load_weights(model.load_weights, weights,
                                 weight_mapper)
```
