Commit 8bd8c93

Move fqn mapping logic to StateDictAdapter (#1557)
This moves the logic that parses `model.safetensors.index.json` and generates the `fqn_to_index_mapping` to `StateDictAdapter`, since this logic should be shared by all classes that inherit from `StateDictAdapter`.
1 parent a6972ae commit 8bd8c93

File tree: 8 files changed (+45, -33 lines)
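For context, the mapping being moved is derived from the standard Hugging Face sharded-checkpoint index, whose `weight_map` ties each parameter fqn to a shard file. Below is a minimal sketch of that derivation; the key and file names are illustrative, not taken from this commit:

```python
import re

# Illustrative shape of model.safetensors.index.json ("weight_map" is the
# standard HF key; these entries are made-up examples).
index = {
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
        "lm_head.weight": "model-00004-of-00004.safetensors",
    }
}

# The shared logic keeps only the first run of digits (the shard number)
# for each fqn, mirroring the re.search(r"\d+", ...) call in the commit.
fqn_to_index_mapping = {
    fqn: re.search(r"\d+", shard).group(0)
    for fqn, shard in index["weight_map"].items()
}
print(fqn_to_index_mapping)
# {'model.embed_tokens.weight': '00001', 'lm_head.weight': '00004'}
```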

scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -65,7 +65,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
         "--hf_assets_path",
         type=Path,
         help="Path to HF assets directory. This is used to get the model.safetensors.index.json mapping",
-        default="./assets/hf/Llama3.1-8B",
+        default="./assets/hf/Llama-3.1-8B",
     )
     parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
     parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
```

torchtitan/components/checkpoint.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -37,7 +37,7 @@
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP
-from torchtitan.protocols import StateDictAdapter
+from torchtitan.protocols import BaseStateDictAdapter
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import GarbageCollection

@@ -177,7 +177,7 @@ class CheckpointManager:
         checkpoint_config (Checkpoint): The config used to configure the checkpointing.
         base_folder (str): The base folder to save the checkpoint. Will be concatenated
             with checkpoint_config.folder
-        sd_adapter (Optional[type[StateDictAdapter]]): The adapter used to convert model state
+        sd_adapter (Optional[type[BaseStateDictAdapter]]): The adapter used to convert model state
             dicts between native format and other formats.
         ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.

@@ -191,7 +191,7 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         checkpoint_config: CheckpointConfig,
-        sd_adapter: StateDictAdapter | None,
+        sd_adapter: BaseStateDictAdapter | None,
         base_folder: str = "",
         ft_manager: FTManager | None = None,
     ) -> None:
```

torchtitan/experiments/forge/train_spec.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -8,7 +8,7 @@

 # Import torchtitan.models to ensure all train specs are registered
 import torchtitan.models  # noqa: F401
-from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter
+from torchtitan.protocols import BaseModelArgs, BaseStateDictAdapter, ModelProtocol
 from torchtitan.protocols.train_spec import (
     _train_specs,
     LossFunctionBuilder,
@@ -30,7 +30,7 @@ class ForgeTrainSpec:
     build_optimizers_fn: OptimizersBuilder
     build_lr_schedulers_fn: LRSchedulersBuilder
     build_loss_fn: LossFunctionBuilder
-    state_dict_adapter: type[StateDictAdapter] | None = None
+    state_dict_adapter: type[BaseStateDictAdapter] | None = None


 # Copy and transform train specs from torchtitan.protocols.train_spec._train_specs
```

torchtitan/models/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -20,7 +20,7 @@ The folder should be organized as follows
   - `init_weights()` is used to properly initialize the parameters and buffers in the model. Please define it in a recursive way so that every submodule has its own `init_weights()`.
   - Add additional files to reduce the complexity of `model.py` if it grows too large or complex, e.g. moe.py to host the `MoE`, `Router`, and `GroupedExperts` modules.
 - `state_dict_adapter.py`
-  - Inherit [`StateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between `torchtitan` model definition and other model definitions (e.g. from HuggingFace so that we can save / load model checkpoints in HF formats).
+  - Inherit [`BaseStateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between `torchtitan` model definition and other model definitions (e.g. from HuggingFace so that we can save / load model checkpoints in HF formats).
   - There are multiple ways such adapters could be used
     - Checkpoint conversion scripts in `scripts/checkpoint_conversion/` will use them to adapt state dicts containing non-sharded `torch.Tensor` on CPU.
     - During training, [`CheckpointManager`](/torchtitan/components/checkpoint.py) will use them to adapt state dicts containing (potentially sharded) `DTensor` on GPUs to save / load checkpoints in HF format.
```
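As a concrete illustration of the pattern this README entry describes, here is a hedged sketch of a new model's adapter. `MyModelStateDictAdapter`, its model args, and the single weight-name mapping are hypothetical; it assumes `to_hf` is the abstract counterpart of the `from_hf` method shown later in this commit, and it inherits the concrete `StateDictAdapter` base so `fqn_to_index_mapping` comes for free:

```python
from typing import Any

from torchtitan.protocols import StateDictAdapter


class MyModelStateDictAdapter(StateDictAdapter):
    """Hypothetical adapter; the single mapping entry is illustrative."""

    def __init__(self, model_args, hf_assets_path: str | None):
        # Builds self.fqn_to_index_mapping from model.safetensors.index.json
        super().__init__(model_args, hf_assets_path)
        self.from_hf_map = {"model.embed_tokens.weight": "tok_embeddings.weight"}

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # Invert the HF -> native mapping to go native -> HF.
        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
        return {to_hf_map[fqn]: tensor for fqn, tensor in state_dict.items()}

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return {self.from_hf_map[fqn]: tensor for fqn, tensor in hf_state_dict.items()}
```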

torchtitan/models/llama3/model/state_dict_adapter.py

Lines changed: 2 additions & 22 deletions

```diff
@@ -4,9 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import json
 import logging
-import os
 import re
 from typing import Any

@@ -19,6 +17,8 @@

 class Llama3StateDictAdapter(StateDictAdapter):
     def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None):
+        super().__init__(model_args, hf_assets_path)
+
         self.model_args = model_args
         self.hf_assets_path = hf_assets_path
         self.from_hf_map = {
@@ -37,26 +37,6 @@ def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None)
             "lm_head.weight": "output.weight",
         }

-        if hf_assets_path:
-            mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
-            try:
-                with open(mapping_path, "r") as f:
-                    hf_safetensors_indx = json.load(f)
-            except FileNotFoundError:
-                logger.warning(
-                    "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \
-                    Defaulting to saving a single safetensors file if checkpoint is saved in HF format.",
-                )
-                hf_safetensors_indx = None
-
-            if hf_safetensors_indx:
-                self.fqn_to_index_mapping = {}
-                for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
-                    indx = re.search(r"\d+", raw_indx).group(0)
-                    self.fqn_to_index_mapping[hf_key] = indx
-            else:
-                self.fqn_to_index_mapping = None
-
     # HuggingFace permutation function (exact copy from their conversion script)
     def _permute(self, w, n_heads_arg, dim1=None, dim2=None):
         if dim1 is None:
```
torchtitan/protocols/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -6,12 +6,13 @@

 from .model import BaseModelArgs, ModelProtocol
 from .model_converter import ModelConverter, ModelConvertersContainer
-from .state_dict_adapter import StateDictAdapter
+from .state_dict_adapter import BaseStateDictAdapter, StateDictAdapter

 __all__ = [
     "BaseModelArgs",
     "ModelProtocol",
     "ModelConverter",
     "ModelConvertersContainer",
     "StateDictAdapter",
+    "BaseStateDictAdapter",
 ]
```

torchtitan/protocols/state_dict_adapter.py

Lines changed: 32 additions & 1 deletion

```diff
@@ -4,13 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import json
+import logging
+import os
+import re
 from abc import ABC, abstractmethod
 from typing import Any

+logger = logging.getLogger()
+
 from .model import BaseModelArgs


-class StateDictAdapter(ABC):
+class BaseStateDictAdapter(ABC):
     """Abstract base class for state dict transformations.

     This class defines the interface for converting between native model
@@ -47,3 +53,28 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
             The converted native model state dict
         """
         pass
+
+
+class StateDictAdapter(BaseStateDictAdapter):
+    """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
+
+    def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
+        if hf_assets_path:
+            mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
+            try:
+                with open(mapping_path, "r") as f:
+                    hf_safetensors_indx = json.load(f)
+            except FileNotFoundError:
+                logger.warning(
+                    "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \
+                    Defaulting to saving a single safetensors file if checkpoint is saved in HF format.",
+                )
+                hf_safetensors_indx = None
+
+            if hf_safetensors_indx:
+                self.fqn_to_index_mapping = {}
+                for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
+                    indx = re.search(r"\d+", raw_indx).group(0)
+                    self.fqn_to_index_mapping[hf_key] = indx
+            else:
+                self.fqn_to_index_mapping = None
```
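To exercise the new shared behavior end to end, here is a small self-contained sketch that writes a throwaway index file and instantiates a trivial subclass. `DummyAdapter` is hypothetical, and this assumes `to_hf` and `from_hf` are the only abstract methods on the base class:

```python
import json
import os
import tempfile
from typing import Any

from torchtitan.protocols import StateDictAdapter


class DummyAdapter(StateDictAdapter):
    """Hypothetical passthrough adapter, just to exercise the new __init__."""

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        return state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return hf_state_dict


with tempfile.TemporaryDirectory() as assets:
    # A one-entry index file standing in for a real HF checkpoint index.
    index = {"weight_map": {"lm_head.weight": "model-00002-of-00002.safetensors"}}
    with open(os.path.join(assets, "model.safetensors.index.json"), "w") as f:
        json.dump(index, f)

    # model_args is unused by the base __init__, so None suffices here.
    adapter = DummyAdapter(model_args=None, hf_assets_path=assets)
    print(adapter.fqn_to_index_mapping)  # {'lm_head.weight': '00002'}
```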

torchtitan/protocols/train_spec.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -21,7 +21,7 @@
 from torchtitan.config import LRScheduler

 from .model import BaseModelArgs, ModelProtocol
-from .state_dict_adapter import StateDictAdapter
+from .state_dict_adapter import BaseStateDictAdapter


 ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
@@ -53,7 +53,7 @@ class TrainSpec:
     build_loss_fn: LossFunctionBuilder
     build_validator_fn: ValidatorBuilder | None = None
     build_metrics_processor_fn: MetricsProcessorBuilder | None = None
-    state_dict_adapter: type[StateDictAdapter] | None = None
+    state_dict_adapter: type[BaseStateDictAdapter] | None = None


 _train_specs = {}
```