This note outlines the process of adding a new model to the `torchtitan` repo. In most cases, new models should first be added under the `torchtitan/experiments` folder. For contribution criteria, please see the [Contributing Guidelines](/torchtitan/experiments/README.md) therein. In general, please adhere to the [Guiding Principles](/README.md#overview) of `torchtitan`.

For offline explorations, we recommend following the same steps, unless otherwise noted.
## Adding the model

Please refer to the [Llama 3 folder](./llama3) as an example.

The folder should be organized as follows:
- `model` folder: a self-contained folder of model definition and args
  - `args.py`
    - Inherit [`BaseModelArgs`](/torchtitan/protocols/model.py) and implement the interfaces.
      - `get_nparams_and_flops()` will be used to understand model size and compute throughput.
      - `update_from_config()` updates the model args from training configs. To extend training configs, see the bullet point below on `job_config.py`.
    - A hedged sketch of `args.py` is given below, after this list.
  - `model.py`
    - NOTE: Please adhere to the guiding principles and write single-device model code.
    - NOTE: We prioritize readability over flexibility. The preferred style is to not share modules among different models, except for the most common and complicated ones.
    - Inherit [`ModelProtocol`](/torchtitan/protocols/model.py) and implement the interfaces.
      - `__init__()` consumes a `ModelArgs` input to build the model.
      - `init_weights()` properly initializes the parameters and buffers in the model. Please define it recursively, so that every submodule has its own `init_weights()`. (A sketch of this pattern is given below, after this list.)
    - Add additional files to reduce the complexity of `model.py` if it grows too large or complex, e.g. a `moe.py` to host the `MoE`, `Router`, and `GroupedExperts` modules.
  - `state_dict_adapter.py`
    - Inherit [`StateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between the `torchtitan` model definition and other model definitions (e.g. from HuggingFace, so that we can save / load model checkpoints in HF format). A sketch is given below, after this list.
    - There are multiple ways such adapters can be used:
      - Checkpoint conversion scripts in `scripts/checkpoint_conversion/` use them to adapt state dicts containing non-sharded `torch.Tensor`s on CPU.
      - During training, [`CheckpointManager`](/torchtitan/components/checkpoint.py) uses them to adapt state dicts containing (potentially sharded) `DTensor`s on GPUs, to save / load checkpoints in HF format.
      - In post-training, `to_hf()` helps convert a `torchtitan` model to an HF model, which can then be used for inference by other frameworks.
    - This is optional for offline exploration.
- `infra` folder: contains the functions used to parallelize the model with PyTorch-native techniques
  - `parallelize.py`
    - Apply training techniques in the following order (see the sketch below, after this list):
      - TP (and EP if the model has an MoE architecture)
      - activation checkpointing
      - `torch.compile`
      - FSDP / HSDP
    - NOTE: currently CP support for language models is enabled via a context manager in `torchtitan/train.py`. Ideally no extra work is needed to enable CP.
  - `pipeline.py` (optional if the model size is small)
    - Apply PP.
  - Include other util files if necessary.
- `__init__.py`
  - A dictionary of the actual model configurations, of type `dict[str, ModelArgs]`.
  - Call `register_train_spec` to register a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting of
    - model name, model class, model args
    - parallelizing function, pipelining function
    - builder functions for optimizer, lr scheduler, data loader, tokenizer, and loss function
      - More often than not, existing components can be reused.
      - Adding new datasets requires the `torchtitan` team's review and legal approval.
      - Try to have minimal dependency on external libraries, if any.
    - state dict adapter
  - Read [more](/docs/extension.md#trainspec) on `TrainSpec`. A sketch of the registration is given below, after this list.
- `README.md`
  - Include [instructions](/README.md#downloading-a-tokenizer) to download tokenizers / encoders.
  - Include instructions to download model checkpoints for continued pretraining or post-training.
  - Keep the current status of development up to date, including supported and upcoming features.
  - This is optional for offline exploration.
- `job_config.py` (if necessary)
  - Sometimes a new model needs access to additional configs, to be consumed by various training components. Read the [guidance](/docs/extension.md#train-script) on extending `JobConfig`.
- `train.py` (only if absolutely necessary)
  - Sometimes `torchtitan/train.py` may not be enough to run the model. There is a [tradeoff](/docs/extension.md#train-script) between extending the existing script and adding a new one.
  - Even if a new one needs to be added, it should reuse `torchtitan/train.py` as much as possible. See `torchtitan/experiments/flux/train.py` as an example.
- `train_configs` folder
  - There should be one `.toml` file for each model variant (e.g. Llama 3.1 8B / 70B / 405B), as well as a `debug_model.toml`.
  - They should be verified with real training jobs, in terms of optimized throughput and loss convergence.
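
The following sketches illustrate the items above. First, a minimal sketch of `args.py`. The exact `BaseModelArgs` interfaces live in [/torchtitan/protocols/model.py](/torchtitan/protocols/model.py); the field names, method signatures, and config attributes here (e.g. `MyModelArgs`, `job_config.training.seq_len`) are illustrative assumptions, not torchtitan APIs.

```python
# Hypothetical `args.py` sketch; check /torchtitan/protocols/model.py
# for the actual abstract interfaces and signatures.
from dataclasses import dataclass

from torch import nn

from torchtitan.protocols.model import BaseModelArgs


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 128256
    max_seq_len: int = 8192

    def update_from_config(self, job_config) -> None:
        # Pull values driven by training configs, e.g. sequence length.
        self.max_seq_len = job_config.training.seq_len  # assumed config field

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        nparams = sum(p.numel() for p in model.parameters())
        # Rough 6 * N * T per-token estimate (fwd + bwd) for a dense transformer.
        return nparams, 6 * nparams * seq_len
```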
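
Next, a sketch of the recursive `init_weights()` pattern for `model.py`, continuing from `MyModelArgs` above. `TransformerBlock` and `MyTransformer` are placeholder names, and the forwards are minimal stubs; only the initialization pattern is the point.

```python
# Hypothetical `model.py` sketch; shows recursive init_weights() only.
import torch
import torch.nn as nn

from torchtitan.protocols.model import ModelProtocol


class TransformerBlock(nn.Module):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.attention_norm = nn.RMSNorm(args.dim)
        self.ffn_norm = nn.RMSNorm(args.dim)
        # real blocks also hold attention / feed-forward submodules

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # placeholder; a real block applies attention and FFN with residuals
        return self.ffn_norm(self.attention_norm(x))

    def init_weights(self):
        # each submodule initializes its own parameters
        self.attention_norm.reset_parameters()
        self.ffn_norm.reset_parameters()


class MyTransformer(nn.Module, ModelProtocol):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleDict(
            {str(i): TransformerBlock(args) for i in range(args.n_layers)}
        )
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        return self.output(h)

    def init_weights(self):
        # recursive: initialize local tensors, then delegate to submodules
        nn.init.normal_(self.tok_embeddings.weight, std=0.02)
        nn.init.normal_(self.output.weight, std=0.02)
        for layer in self.layers.values():
            layer.init_weights()
```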
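
A sketch of `state_dict_adapter.py`. The HF key names and the flat two-entry mapping are toy assumptions; real adapters also need templated per-layer keys and sometimes tensor reshaping, and the actual abstract methods are defined in [/torchtitan/protocols/state_dict_adapter.py](/torchtitan/protocols/state_dict_adapter.py).

```python
# Hypothetical `state_dict_adapter.py` sketch; key names are illustrative.
from typing import Any

from torchtitan.protocols.state_dict_adapter import StateDictAdapter


class MyStateDictAdapter(StateDictAdapter):
    # Assumed torchtitan-name -> HF-name mapping for a toy two-entry case;
    # real models need templated per-layer keys and possibly reshaping.
    TO_HF = {
        "tok_embeddings.weight": "model.embed_tokens.weight",
        "output.weight": "lm_head.weight",
    }
    FROM_HF = {v: k for k, v in TO_HF.items()}

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # Rename tensors into the HF layout.
        return {self.TO_HF[k]: v for k, v in state_dict.items()}

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        # Inverse mapping: HF layout back to the torchtitan layout.
        return {self.FROM_HF[k]: v for k, v in hf_state_dict.items()}
```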
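
A sketch of the ordering in `parallelize.py`, using PyTorch-native APIs on a recent PyTorch. The `parallel_dims` attributes and the TP plan contents are assumptions for illustration, not torchtitan's actual interfaces.

```python
# Hypothetical `parallelize.py` sketch; shows only the ordering of techniques.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import fully_shard  # FSDP2, recent PyTorch
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def parallelize_my_model(model, parallel_dims, job_config):
    # 1. Tensor parallel (and expert parallel for MoE models).
    if parallel_dims.tp_enabled:  # assumed attribute
        parallelize_module(
            model,
            parallel_dims.world_mesh["tp"],
            {"output": ColwiseParallel()},  # placeholder plan; real plans
        )                                   # cover attention / FFN modules

    # 2. Activation checkpointing, applied per transformer block.
    for name, block in model.layers.items():
        model.layers[name] = checkpoint_wrapper(block)

    # 3. torch.compile, also per block, so blocks share compile caches.
    for name, block in model.layers.items():
        model.layers[name] = torch.compile(block)

    # 4. FSDP / HSDP last, wrapping the compiled, checkpointed blocks.
    if parallel_dims.dp_shard_enabled:  # assumed attribute
        dp_mesh = parallel_dims.world_mesh["dp_shard"]
        for block in model.layers.values():
            fully_shard(block, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)

    return model
```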
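
Finally, a sketch of the registration in `__init__.py`, reusing the placeholder classes above. The exact `TrainSpec` fields are defined in [/torchtitan/protocols/train_spec.py](/torchtitan/protocols/train_spec.py); the keyword names below follow the bullet list above but should be checked against the protocol, and the builder imports are elided since existing components can usually be reused.

```python
# Hypothetical `__init__.py` sketch; verify field names against
# /torchtitan/protocols/train_spec.py before relying on them.
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec

my_model_configs = {
    "debug_model": MyModelArgs(dim=256, n_layers=2, n_heads=4),
    "8B": MyModelArgs(),  # defaults above approximate an 8B variant
}

register_train_spec(
    TrainSpec(
        name="my_model",
        model_cls=MyTransformer,
        model_args=my_model_configs,
        parallelize_fn=parallelize_my_model,
        pipelining_fn=None,  # optional for small models
        # The builders below stand in for torchtitan's reusable components
        # (optimizer, lr scheduler, data loader, tokenizer, loss); their
        # imports are elided here.
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=MyStateDictAdapter,
    )
)
```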

## Testing and Benchmarking
- Numerics testing
  - One way of doing this E2E is to load the same model checkpoint into the `torchtitan` model and the HF model, and compare the model outputs given the same input (see the sketch after this list). This assumes
    - the HF implementation is correct;
    - matching outputs indicate that both the `torchtitan` model and the corresponding state dict adapter are correct, since an error in either would surface in the comparison.
- Loss convergence
  - If there is a verified baseline, compare the loss curves with the baseline.
  - For comparisons within `torchtitan`, see the [guidelines](/docs/converging.md).
- Performance benchmarking
  - Please refer to the [benchmarks](/benchmarks/) folder.
- CI tests
  - Include unit tests and integration tests; see [examples](/tests/).
  - If the model folder is under the experiments folder, put the tests under the model folder. Otherwise, put the tests under the `/tests` folder.
  - Add necessary GitHub [workflows](/.github/workflows/).
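
Below is a hedged sketch of such an E2E numerics check, reusing the placeholder classes from the previous section. The HF model id and the adapter constructor are assumptions, and a `transformers` install is required.

```python
# Hypothetical numerics check; "org/my-model-hf" is a placeholder model id.
import torch
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("org/my-model-hf")
hf_model.eval()

args = my_model_configs["8B"]
adapter = MyStateDictAdapter()  # constructor args depend on the base class
tt_model = MyTransformer(args)
tt_model.load_state_dict(adapter.from_hf(hf_model.state_dict()))
tt_model.eval()

tokens = torch.randint(0, args.vocab_size, (1, 64))
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tt_logits = tt_model(tokens)

# Loose tolerances: small drift from op reordering / fused kernels is expected.
torch.testing.assert_close(tt_logits, hf_logits, rtol=1e-2, atol=1e-2)
```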