
Commit 38a9d30

publish instructions on adding a new model (#1451)
as titled
1 parent 58afc5f commit 38a9d30

File tree: 5 files changed, +86 / -5 lines changed

README.md

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@ To use the latest features of `torchtitan`, we recommend using the most recent P

 ## Latest News
+- [2025/07] We published [instructions](/torchtitan/models/README.md) on how to add a model to `torchtitan`.
 - [2025/07] We released `torchtitan` [v0.1.0](https://github.com/pytorch/torchtitan/releases), and also set up nightly builds.
 - [2025/04] Our paper was accepted by [ICLR 2025](https://iclr.cc/virtual/2025/poster/29620).
 - [2025/04] [Llama 4](torchtitan/experiments/llama4/) initial support is available as an experiment.

@@ -37,7 +38,7 @@ To use the latest features of `torchtitan`, we recommend using the most recent P

 Our mission is to accelerate innovation in the field of generative AI by empowering researchers and developers to explore new modeling architectures and infrastructure techniques.

-The guiding principles when building `torchtitan`
+The Guiding Principles when building `torchtitan`
 * Designed to be easy to understand, use and extend for different training purposes.
 * Minimal changes to the model code when applying multi-dimensional parallelism.
 * Bias towards a clean, minimal codebase while providing basic reusable / swappable components.

torchtitan/experiments/README.md

Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@ To accelerate contributions to and innovations around `torchtitan`, we are addin

 We provide this `experiments/` folder to host experiments that add significant value to `torchtitan`, with the following principles. We refer to the part of `torchtitan` outside `experiments` as `core`.
 1. Each subfolder in `experiments` will be an experiment, with a clear theme which can be flexible, such as
-   - a new model, or preferably a new model architecture, with its training infrastructure including parallelization functions;
-   - an enhancement or addition to the existing infrastructure of `torchtitan`.
+   - A new model, or preferably a new model architecture, with its training infrastructure including parallelization functions. Please see the [instructions](/torchtitan/models/README.md) on how to contribute a new model.
+   - An enhancement or addition to the existing infrastructure of `torchtitan`.
 2. It is the contributors' responsibility to justify the value of an experiment. `torchtitan` team will review proposals on a case-by-case basis. As part of the contribution, the contributors should provide documentation that clearly showcases the motivation and innovation of an experiment, including reports on performance and loss convergence.
 3. An experiment should reuse existing `torchtitan` code as much as possible, such as modules in [`components/`](../components/) (via a new [`TrainSpec`](../protocols/train_spec.py)) and [`train.py`](../train.py). For a list of extension points we provide, please refer to [docs/extension.md](../../docs/extension.md).
    - The extension points are subject to change. We kindly request that contributors provide feedback if they encounter issues reusing any components, rather than simply using a copy-and-paste approach.

torchtitan/models/README.md

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@

This note outlines the process of adding a new model to the `torchtitan` repo. In most cases, new models should be added first under the `torchtitan/experiments` folder. For the contribution criteria, please see the [Contributing Guidelines](/torchtitan/experiments/README.md). In general, please adhere to the [Guiding Principles](/README.md#overview) of `torchtitan`.

For offline explorations, we recommend following the same steps, unless otherwise noted.

## Adding the model

Please refer to the [Llama 3 folder](./llama3) as an example.

The folder should be organized as follows:
- `model` folder: a self-contained folder of model definition and args (a minimal skeleton of `args.py` and `model.py` is sketched after this list)
  - `args.py`
    - Inherit [`BaseModelArgs`](/torchtitan/protocols/model.py) and implement the interfaces.
      - `get_nparams_and_flops()` will be used to understand model size and compute throughput.
      - `update_from_config()` updates the model args from training configs. To extend training configs, see the bullet point below on `job_config.py`.
  - `model.py`
    - NOTE: Please adhere to the guiding principles and write single-device model code.
    - NOTE: We prioritize readability over flexibility. The preferred style is to not share modules among different models, except for the most common and complicated ones.
    - Inherit [`ModelProtocol`](/torchtitan/protocols/model.py) and implement the interfaces.
      - `__init__()` consumes a `ModelArgs` input to build the model.
      - `init_weights()` is used to properly initialize the parameters and buffers in the model. Please define it in a recursive way so that every submodule has its own `init_weights()`.
    - Add additional files to reduce the complexity of `model.py` if it grows too large or complex, e.g. `moe.py` to host the `MoE`, `Router`, and `GroupedExperts` modules.
  - `state_dict_adapter.py` (see the adapter sketch after this list)
    - Inherit [`StateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between the `torchtitan` model definition and other model definitions (e.g. from HuggingFace, so that we can save / load model checkpoints in HF formats).
    - There are multiple ways such adapters could be used:
      - Checkpoint conversion scripts in `scripts/checkpoint_conversion/` will use them to adapt state dicts containing non-sharded `torch.Tensor` on CPU.
      - During training, [`CheckpointManager`](/torchtitan/components/checkpoint.py) will use them to adapt state dicts containing (potentially sharded) `DTensor` on GPUs to save / load checkpoints in HF format.
      - In post-training, `to_hf()` helps convert a `torchtitan` model to an HF model, which can be used for inference by other frameworks.
    - This is optional for offline exploration.
- `infra` folder: containing the functions used to parallelize the model using PyTorch-native techniques
  - `parallelize.py` (see the parallelization sketch after this list)
    - Apply training techniques in the following order:
      - TP (and EP if the model has an MoE architecture)
      - activation checkpointing
      - `torch.compile`
      - FSDP / HSDP
    - NOTE: currently CP support for language models is enabled via a context manager in `torchtitan/train.py`. Ideally no extra work is needed to enable CP.
  - `pipeline.py` (optional if the model size is small)
    - Apply PP.
  - Include other util files if necessary.
- `__init__.py` (see the registration sketch after this list)
  - A dictionary of the actual model configurations, of the type `dict[str, ModelArgs]`.
  - Call `register_train_spec` to specify a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting of
    - model name, model class, model args
    - parallelizing function, pipelining function
    - builder functions for the optimizer, lr scheduler, data loader, tokenizer, and loss function
      - More often than not, existing components can be reused.
      - Adding new datasets requires the `torchtitan` team's review and legal approval.
      - Try to have minimal dependency on external libraries, if any.
    - state dict adapter
  - Read [more](/docs/extension.md#trainspec) on `TrainSpec`.
- `README.md`
  - Include [instructions](/README.md#downloading-a-tokenizer) to download tokenizers / encoders.
  - Include instructions to download model checkpoints for continued pretraining or post-training.
  - Update the current status of development, including supported and upcoming features.
  - This is optional for offline exploration.
- `job_config.py` (if necessary)
  - Sometimes a new model needs to access additional configs, to be consumed by various training components. Read the [guidance](/docs/extension.md#train-script) on extending `JobConfig`.
- `train.py` (only if absolutely necessary)
  - Sometimes `torchtitan/train.py` may not be enough to run the model. There is a [tradeoff](/docs/extension.md#train-script) between extending the existing one vs. having a new one.
  - Even if a new one needs to be added, it should reuse `torchtitan/train.py` as much as possible. See `torchtitan/experiments/flux/train.py` as an example.
- `train_configs` folder
  - There should be one `.toml` file for each model variant (e.g. Llama 3.1 8B / 70B / 405B), as well as a `debug_model.toml`.
  - They should be verified with real training jobs, in terms of optimized throughput and loss convergence.
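To make the `model` folder layout concrete, here is a minimal, hypothetical sketch combining `args.py` and `model.py` into one snippet. The model name, dimensions, and the exact method signatures of `update_from_config()` / `get_nparams_and_flops()` are illustrative assumptions; treat `torchtitan/protocols/model.py` as the source of truth for the interfaces.

```python
# Hypothetical skeleton for a model named "mymodel"; signatures may differ from
# the actual BaseModelArgs / ModelProtocol interfaces in torchtitan/protocols/model.py.
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.protocols.model import BaseModelArgs, ModelProtocol


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 1024
    n_layers: int = 8
    vocab_size: int = 32000

    def update_from_config(self, job_config, **kwargs) -> None:
        # Pull values the model needs from the training config, e.g. sequence length.
        self.max_seq_len = job_config.training.seq_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, float]:
        # Report parameter count and per-token FLOPs so the trainer can log size and MFU.
        nparams = sum(p.numel() for p in model.parameters())
        return nparams, 6 * nparams * seq_len  # rough dense-transformer estimate


class MyModel(nn.Module, ModelProtocol):
    def __init__(self, model_args: MyModelArgs):
        super().__init__()
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(model_args.dim, nhead=8, batch_first=True)
            for _ in range(model_args.n_layers)
        )
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

    def init_weights(self) -> None:
        # Initialize recursively so every submodule owns its own init logic.
        nn.init.normal_(self.tok_embeddings.weight, std=0.02)
        nn.init.normal_(self.output.weight, std=0.02)
        for layer in self.layers:
            for p in layer.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h)
        return self.output(h)
```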
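Next, a hedged sketch of `state_dict_adapter.py`. Only `to_hf()` is named above; the `from_hf()` counterpart, the constructor, and the key mapping shown here are assumptions about what a minimal adapter could look like, not the exact abstract interface.

```python
# Hypothetical adapter: rename keys between the torchtitan layout and an HF-style layout.
from typing import Any

from torchtitan.protocols.state_dict_adapter import StateDictAdapter


class MyModelStateDictAdapter(StateDictAdapter):
    # torchtitan key -> HF key (illustrative mapping only)
    to_hf_map = {
        "tok_embeddings.weight": "model.embed_tokens.weight",
        "output.weight": "lm_head.weight",
    }

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        return {self.to_hf_map.get(k, k): v for k, v in state_dict.items()}

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        from_hf_map = {v: k for k, v in self.to_hf_map.items()}
        return {from_hf_map.get(k, k): v for k, v in hf_state_dict.items()}
```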
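To illustrate the ordering in `parallelize.py`, here is a hedged sketch of what an entry point could look like for a model whose transformer blocks live in `model.layers`. The function signature, the TP plan entries, and the mesh handling are assumptions for illustration; the real parallelize functions in `torchtitan` take its own config objects and are more involved.

```python
# Sketch only: apply TP, then activation checkpointing, then torch.compile, then FSDP.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


def parallelize_my_model(
    model: nn.Module,
    tp_mesh: DeviceMesh | None = None,
    dp_mesh: DeviceMesh | None = None,
    enable_ac: bool = True,
    enable_compile: bool = False,
) -> nn.Module:
    # 1. Tensor parallelism first (plus expert parallelism for MoE models).
    if tp_mesh is not None:
        layer_tp_plan = {
            # Illustrative submodule names; the plan is model-specific.
            "attention.wq": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
        }
        for block in model.layers:
            parallelize_module(block, tp_mesh, layer_tp_plan)

    # 2. Activation checkpointing, per transformer block.
    # 3. torch.compile, also per block, so recompilations stay local.
    for idx, block in enumerate(model.layers):
        if enable_ac:
            block = checkpoint_wrapper(block)
        if enable_compile:
            block = torch.compile(block)
        model.layers[idx] = block

    # 4. FSDP / HSDP last: shard each block, then the root module.
    if dp_mesh is not None:
        for block in model.layers:
            fully_shard(block, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)

    return model
```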
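Finally, a sketch of what the package `__init__.py` registration could look like. The `TrainSpec` field names and the builder imports follow the bullets above and existing models' patterns, but they are assumptions; check `torchtitan/protocols/train_spec.py` for the actual dataclass fields and builder types.

```python
# Hypothetical __init__.py for the "mymodel" package; field and builder names may
# differ from the current TrainSpec definition.
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec

from .infra.parallelize import parallelize_my_model
from .infra.pipeline import pipeline_my_model
from .model.args import MyModelArgs
from .model.model import MyModel
from .model.state_dict_adapter import MyModelStateDictAdapter

# One entry per model variant, keyed by flavor name.
my_model_configs = {
    "debugmodel": MyModelArgs(dim=256, n_layers=2),
    "8B": MyModelArgs(dim=4096, n_layers=32),
}

register_train_spec(
    TrainSpec(
        name="mymodel",
        model_cls=MyModel,
        model_args=my_model_configs,
        parallelize_fn=parallelize_my_model,
        pipelining_fn=pipeline_my_model,
        build_optimizers_fn=build_optimizers,       # existing torchtitan builders
        build_lr_schedulers_fn=build_lr_schedulers, # can usually be reused as-is
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=MyModelStateDictAdapter,
    )
)
```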
## Testing and Benchmarking
- Numerics testing
  - One way of doing this E2E is to load the same model checkpoint into both the `torchtitan` model and the HF model, and compare the model outputs given the same input (see the sketch after this list). This assumes the HF implementation is correct.
  - Since this check exercises the `torchtitan` model and the corresponding state dict adapter together, a match indicates the correctness of both.
- Loss convergence
  - If there is a verified baseline, compare the loss curves with the baseline.
  - For comparisons within `torchtitan`, see the [guidelines](/docs/converging.md).
- Performance benchmarking
  - Please refer to the [benchmarks](/benchmarks/) folder.
- CI tests
  - Include unit tests and integration tests; see [examples](/tests/).
  - If the model folder is under the `experiments` folder, put the tests under the model folder. Otherwise, put the tests under the `/tests` folder.
  - Add the necessary GitHub [workflows](/.github/workflows/).
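Continuing the hypothetical names from the sketches above, a minimal version of that E2E numerics check could look like the following; the checkpoint name, package path, and tolerances are placeholders.

```python
# Sketch: load the same HF checkpoint into both models and compare logits.
import torch
from transformers import AutoModelForCausalLM

from my_model import MyModel, MyModelArgs, MyModelStateDictAdapter  # hypothetical package

hf_model = AutoModelForCausalLM.from_pretrained("org/checkpoint", torch_dtype=torch.float32)
hf_model.eval()

args = MyModelArgs(dim=4096, n_layers=32)
tt_model = MyModel(args)
adapter = MyModelStateDictAdapter()
tt_model.load_state_dict(adapter.from_hf(hf_model.state_dict()))
tt_model.eval()

tokens = torch.randint(0, args.vocab_size, (1, 128))
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tt_logits = tt_model(tokens)

# A tight tolerance is reasonable in fp32; relax it for bf16 runs.
torch.testing.assert_close(tt_logits, hf_logits, rtol=1e-4, atol=1e-4)
```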

torchtitan/protocols/state_dict_adapter.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from typing import Any

-from torchtitan.protocols import BaseModelArgs
+from .model import BaseModelArgs


 class StateDictAdapter(ABC):

torchtitan/protocols/train_spec.py

Lines changed: 3 additions & 1 deletion

@@ -19,7 +19,9 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.components.validate import BaseValidator
 from torchtitan.config import LRScheduler
-from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter
+
+from .model import BaseModelArgs, ModelProtocol
+from .state_dict_adapter import StateDictAdapter


 ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
