
Commit 94bd6c9

syedazi and nghtm authored
Trainium llama3 (#363)
* First commit
* Updates to run 70B model
* Updated ReadMe for SMHP-Tranium Example, tested on dev cluster
* Updated ReadMe minor changes
* Updated ReadMe minor changes v2
* Updated ReadMe minor changes v3
* Updated ReadMe minor changes v4

Co-authored-by: nghtm <[email protected]>
1 parent e944da0 commit 94bd6c9

File tree

14 files changed: +2837 −0 lines changed
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 8192,
  "initializer_range": 0.02,
  "intermediate_size": 28672,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 64,
  "num_hidden_layers": 80,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256,
  "sequence_parallel_enabled": false,
  "selective_checkpoint_enabled": false,
  "move_model_to_device": false
}
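As a rough sanity check (a sketch, not part of this commit; the file name `70B_config.json` is illustrative), the config above can be loaded with Hugging Face transformers to estimate the parameter count it implies:

```python
# Rough sanity check (illustrative only; file name is assumed): estimate the
# parameter count implied by the config above.
from transformers import LlamaConfig

cfg = LlamaConfig.from_json_file("70B_config.json")

head_dim = cfg.hidden_size // cfg.num_attention_heads
# q and o projections are hidden_size x hidden_size; k and v use 8 grouped KV heads.
attn = 2 * cfg.hidden_size * cfg.hidden_size \
     + 2 * cfg.hidden_size * head_dim * cfg.num_key_value_heads
# SwiGLU MLP has gate, up, and down projections.
mlp = 3 * cfg.hidden_size * cfg.intermediate_size
# Input embedding plus untied output head.
embeddings = 2 * cfg.vocab_size * cfg.hidden_size

total = cfg.num_hidden_layers * (attn + mlp) + embeddings
print(f"~{total / 1e9:.1f}B parameters")  # about 70.6B with the values above
```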
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 128256,
  "sequence_parallel_enabled": false,
  "selective_checkpoint_enabled": false,
  "move_model_to_device": false
}
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@

# How to pre-train Llama3 with SageMaker Hyperpod using Amazon Trainium

## What is SageMaker Hyperpod?
[Amazon SageMaker Hyperpod](https://aws.amazon.com/sagemaker/hyperpod/) offers advanced training tools to help you accelerate scalable, reliable, and secure generative AI application development. It removes the undifferentiated heavy lifting involved in building and optimizing machine learning (ML) infrastructure for training foundation models (FMs), significantly reducing training time. SageMaker Hyperpod ensures customers can continue FM training uninterrupted by periodically saving checkpoints. When a hardware failure occurs during training, SageMaker Hyperpod automatically detects the failure, repairs or replaces the faulty instance, and resumes training from the last saved checkpoint, removing the need for customers to manage this process manually and helping them train for weeks or months in a distributed setting without disruption.

SageMaker Hyperpod also allows customers to run their FM training workloads on [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/). AWS Trainium is the machine learning (ML) chip that AWS purpose built for deep learning (DL) training of 100B+ parameter models. Each Amazon Elastic Compute Cloud (Amazon EC2) [Trn1 instance](https://aws.amazon.com/ec2/instance-types/trn1) deploys up to 16 Trainium accelerators to deliver a high-performance, low-cost solution for DL training in the cloud. The [AWS Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) helps developers train models on Trainium accelerators (and deploy them on [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/) accelerators). It natively integrates popular frameworks such as PyTorch and TensorFlow, so you can keep your existing code and workflows while training on Trainium accelerators.
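As a minimal illustration of that integration (a sketch for orientation, not part of this example's scripts), torch-neuronx exposes each Trainium NeuronCore to PyTorch as an XLA device, so ordinary tensor code runs on it unchanged:

```python
# Minimal smoke test on a trn1 node (assumes the torch-neuronx environment from
# Section 1 is active); standard PyTorch ops execute on a NeuronCore via XLA.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()               # first available NeuronCore
x = torch.randn(4, 4, device=device)   # tensor allocated on Trainium
y = (x @ x.T).sum()
xm.mark_step()                         # materialize the lazily-traced XLA graph
print(y.cpu())
```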
## 0. Prerequisites
You will need a SageMaker Hyperpod cluster with 4 [trn1.32xlarge](https://aws.amazon.com/ec2/instance-types/trn1/) instances and a shared parallel filesystem such as [Amazon FSx for Lustre](https://docs.aws.amazon.com/fsx/latest/LustreGuide/getting-started.html). See the [SageMaker Hyperpod](https://github.com/aws-samples/awsome-distributed-training/tree/main/1.architectures/5.sagemaker-hyperpod) architectures folder for setup instructions.

## 1. Create Environment

1. Once the cluster is set up, SSH into the cluster head/controller node and switch to the `ubuntu` user:
```bash
sudo su - ubuntu
```
> [!NOTE]
> You will run the following steps from the head/controller node of your cluster.

2. Make sure your home directory is `/fsx/ubuntu`, as this allows us to install the required dependencies only once, on the head node:

```bash
pwd
```

3. Next, install and create a Python virtual environment:

```bash
# Install Python venv
sudo apt-get install -y python3.8-venv g++

# Create Python venv
python3.8 -m venv aws_neuron_venv_pytorch
```

Now let's activate the virtual environment:
```bash
# Activate Python venv
source aws_neuron_venv_pytorch/bin/activate
python -m pip install -U pip
```

4. Install PyTorch Neuron (an optional import check follows these commands):

```bash
# Install Jupyter notebook kernel
pip install ipykernel
python3.8 -m ipykernel install --user --name aws_neuron_venv_pytorch --display-name "Python (torch-neuronx)"
pip install jupyter notebook
pip install environment_kernels

# Set pip repository pointing to the Neuron repository
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Install wget, awscli
python -m pip install wget
python -m pip install awscli

# Install Neuron Compiler and Framework
python -m pip install neuronx-cc==2.* torch-neuronx torchvision
python -m pip install neuronx_distributed --extra-index-url https://pip.repos.neuron.amazonaws.com
```
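As an optional sanity check (a sketch assuming the installs above succeeded), you can confirm the Neuron packages import cleanly from inside the virtual environment:

```python
# Optional sanity check: run inside aws_neuron_venv_pytorch after the installs above.
import torch
import torch_neuronx          # PyTorch Neuron integration
import neuronx_distributed    # tensor/pipeline parallelism library used by this example

print("torch", torch.__version__)
print("Neuron packages imported successfully")
```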
On your cluster head node, clone this repo:
```bash
git clone https://github.com/aws-samples/awsome-distributed-training/
cd awsome-distributed-training/3.test_cases/22.SMHP-trainium-llama3
```

With the repo cloned, let's install the libraries defined in its `requirements.txt` file:

```bash
# Install requirements.txt
pip install -r requirements.txt
```

## 2. Prepare Dataset

Next, we need to tokenize our dataset. To tokenize the data, you must request the tokenizer from HuggingFace and Meta by following the instructions at [HuggingFace Llama 3 8B Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B). Use of the Llama 3 model is governed by the Meta license; visit that page and accept the license before requesting access. After access has been granted, you can download the model weights and tokenizer to your cluster as described below.
1. Install the huggingface CLI:

```bash
pip install huggingface-hub
```

2. Authenticate with your [HuggingFace Access Token](https://huggingface.co/settings/tokens).
> [!IMPORTANT]
> Ensure your HuggingFace Access Token has permissions to public gated repos. You can configure your token to access public gated repos with: (*Edit Access Token Permissions* > check the box *Read access to contents of all public gated repos you can access* > Save).
```bash
huggingface-cli login
```
3. Download the [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repo from HuggingFace:
> [!NOTE]
> We pass the `--include "*"` flag to ensure we clone the entire repo, including the sub-directory `original/`, which contains the `tokenizer.model` file we will copy in the next step.
```bash
huggingface-cli download meta-llama/Meta-Llama-3-8B --include "*" --local-dir Meta-Llama-3-8B
```

4. Within the current working directory `22.SMHP-trainium-llama3`, copy `config.json`, `tokenizer_config.json`, `tokenizer.json`, and `original/tokenizer.model` from the cloned `Meta-Llama-3-8B` directory into the working directory so they can be picked up by our scripts:

```bash
cp Meta-Llama-3-8B/tokenizer_config.json Meta-Llama-3-8B/config.json Meta-Llama-3-8B/tokenizer.json Meta-Llama-3-8B/original/tokenizer.model .
```

5. Run the `get_dataset.py` script to prepare the dataset for training (a conceptual sketch of what such a script does follows this list). We run it via `srun` so that it executes on a compute (trn1) node:

```bash
srun --job-name=get_dataset_job --output=get_dataset_output.log --nodes=1 python get_dataset.py &
```

>[!IMPORTANT]
>The `get_dataset.py` job will take several minutes to run; do not proceed until it has completed. You can monitor the job logs with:
>```bash
>tail -f get_dataset_output.log
>```
> Once `squeue` shows the job has completed and `sinfo` shows all nodes as idle, you can proceed to the next section, **Compile the Model**.
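Conceptually, dataset preparation for this example looks like the following (a hypothetical sketch, not the committed `get_dataset.py`; the corpus and output path are illustrative):

```python
# Hypothetical sketch of dataset preparation (not the committed get_dataset.py).
# It tokenizes a public text corpus with the Llama 3 tokenizer files copied in
# step 4 and saves the result to the shared FSx filesystem for the training job.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(".")  # tokenizer files live in the working dir

raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")  # illustrative corpus

def tokenize(batch):
    return tokenizer(batch["text"])

tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
tokenized.save_to_disk("/fsx/ubuntu/llama3_tokenized")  # illustrative output location
```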
## 3. Compile the Model

Next, we will compile the model graphs using the [neuron parallel compile](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/training/pytorch-neuron-parallel-compile.html#pytorch-neuronx-parallel-compile-cli) tool. `neuron_parallel_compile` runs the training script just long enough to extract its XLA graphs and compiles them ahead of time, populating the Neuron cache so the actual training run does not pay the compilation cost.

```bash
sbatch --exclusive \
--nodes 4 \
--cpus-per-task 64 \
--wrap="srun neuron_parallel_compile bash $(pwd)/run_llama_8B_tp_pp.sh"
```
## 4. Run Training

Once the graphs are compiled, we can run model training:

```bash
sbatch --exclusive \
--nodes 4 \
--cpus-per-task 64 \
--wrap="srun bash $(pwd)/run_llama_8B_tp_pp.sh"
```

## Running the 70B model

If you would like to compile and train the 70B model instead, run the `run_llama_70B_tp_pp.sh` script as below:

- Model Compilation

```bash
sbatch --exclusive \
--nodes 4 \
--cpus-per-task 64 \
--wrap="srun neuron_parallel_compile bash $(pwd)/run_llama_70B_tp_pp.sh"
```

- Model Training

```bash
sbatch --exclusive \
--nodes 4 \
--cpus-per-task 64 \
--wrap="srun bash $(pwd)/run_llama_70B_tp_pp.sh"
```
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
from typing import Any, Dict, Iterator, Tuple
import torch.nn as nn

import torch
from torch_xla.utils.checkpoint import checkpoint as torch_checkpoint
from neuronx_distributed.parallel_layers.parallel_state import rmsg
from neuronx_distributed.utils.logger import get_logger
from torch.distributed.utils import _replace_by_prefix

logger = get_logger()

_CHECKPOINT_WRAPPED_MODULE = "mod"
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."


class CheckPointWrapper(torch.nn.Module):
    def __init__(self, mod) -> None:
        super().__init__()
        self.mod = mod
        # state_dict post hook to remove prefix to allow loading into a
        # non-checkpoint wrapped module.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # load_state_dict pre-hook to allow loading back into
        # checkpoint-wrapped module.
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
        )

    def forward(self, *args, **kwargs):
        ordered_args = list(args)
        for value in kwargs.values():
            ordered_args += [value]

        # Note: checkpoint cannot accept kwargs
        return torch_checkpoint(self.mod, *ordered_args, use_reentrant=True)

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Overrides :meth:`named_parameters()` to intercept parameter names and
        remove all occurrences of ``_CHECKPOINT_PREFIX``.
        """
        for param_name, param in super().named_parameters(*args, **kwargs):
            updated_name = param_name.replace(_CHECKPOINT_PREFIX, "")
            yield updated_name, param

    def named_modules(self, *args, **kwargs):
        for module_name, module in super().named_modules(*args, **kwargs):
            updated_name = module_name.replace(_CHECKPOINT_PREFIX, "")
            yield updated_name, module

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        """
        _post_state_dict_hook() is called after the state_dict() of this
        FSDP module is executed. For ``checkpoint_wrapper``, it will strip the
        checkpoint-wrapped module prefix so that this module can be loaded into
        non-checkpointed modules. It would still be able to be loaded into
        checkpoint-wrapped modules as this class adds the prefix back before
        loading the state_dict.
        """
        _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
        is called. For ``checkpoint_wrapper``, it will add back the module
        prefix so that non-checkpointed modules can be loaded into
        checkpoint_wrapper modules properly.
        """
        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")


def apply_checkpoint(dist_model, layers_to_checkpoint=None):
    checkpoint_wrapper_added = False
    if layers_to_checkpoint is not None and len(layers_to_checkpoint) == 0:
        raise RuntimeError(
            rmsg(f"invalid input layers_to_checkpoint {layers_to_checkpoint}, can't be empty")
        )
    for name, module in dist_model.local_module.named_children():
        # checkpoint layers that are provided in input
        # if layers are not provided in input, then checkpoint if it is a transformer layer
        if (layers_to_checkpoint and name in layers_to_checkpoint) or (
            not layers_to_checkpoint and type(module) == dist_model.transformer_layer_cls
        ):
            # add_module replaces old module with our own custom module.
            # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module
            dist_model.local_module.add_module(name, CheckPointWrapper(module))
            checkpoint_wrapper_added = True
    if layers_to_checkpoint is not None and not checkpoint_wrapper_added:
        logger.warning(
            rmsg(f"layers_to_checkpoint {layers_to_checkpoint} do not exist in the graph")
        )
    elif layers_to_checkpoint is None and not checkpoint_wrapper_added:
        logger.warning(
            rmsg(
                f"During applying activation checkpointing, transformer_layer_cls "
                f"{dist_model.transformer_layer_cls.__name__} can not be found in stage "
                f"{dist_model.pipeline_parallel_rank}, skipping..."
            )
        )
