Skip to content

Adding torch accelerator and requirements file to FSDP2 example #1375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions distributed/FSDP2/README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
## FSDP2
To run FSDP2 on transformer model:

```
cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
pip install -r requirements.txt
torchrun --nproc_per_node 2 example.py
```
* For 1st time, it creates a "checkpoints" folder and saves state dicts there
* For 2nd time, it loads from previous checkpoints

To enable explicit prefetching
```
torchrun --nproc_per_node 2 train.py --explicit-prefetch
torchrun --nproc_per_node 2 example.py --explicit-prefetch
```

To enable mixed precision
```
torchrun --nproc_per_node 2 train.py --mixed-precision
torchrun --nproc_per_node 2 example.py --mixed-precision
```

To showcase DCP API
```
torchrun --nproc_per_node 2 train.py --dcp-api
torchrun --nproc_per_node 2 example.py --dcp-api
```

## Ensure you are running a recent version of PyTorch:
Expand Down
27 changes: 23 additions & 4 deletions distributed/FSDP2/train.py → distributed/FSDP2/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from utils import inspect_mixed_precision, inspect_model

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
""" verification that we have at least 2 gpus to run dist examples """
has_gpu = torch.accelerator.is_available()
gpu_count = torch.accelerator.device_count()
return has_gpu and gpu_count >= min_gpus

def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
for i, layer in enumerate(model.layers):
Expand All @@ -29,10 +34,23 @@ def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):


def main(args):
_min_gpu_count = 2
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
exit()
rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.distributed.init_process_group(backend="nccl", device_id=device)
if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
else:
device = torch.device("cpu")
print(f"Running on device {device}")

backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, device_id=device)

torch.manual_seed(0)
vocab_size = 1024
batch_size = 32
Expand Down Expand Up @@ -64,7 +82,7 @@ def main(args):

checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
if checkpointer.last_training_time is None:
model.to_empty(device="cuda")
model.to_empty(device=device)
model.reset_parameters()
else:
checkpointer.load_model(model)
Expand Down Expand Up @@ -96,4 +114,5 @@ def main(args):
parser.add_argument("--mixed-precision", action="store_true", default=False)
parser.add_argument("--dcp-api", action="store_true", default=False)
args = parser.parse_args()

main(args)
2 changes: 2 additions & 0 deletions distributed/FSDP2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch>=2.7
numpy
11 changes: 11 additions & 0 deletions distributed/FSDP2/run_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# /bin/bash
# bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to run. Default = 'example.py'
# num_gpus = num local gpus to use (must be at least 2). Default = 4

# samples to run include:
# example.py

echo "Launching ${1:-example.py} with ${2:-4} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-4} ${1:-example.py}

4 changes: 4 additions & 0 deletions run_distributed_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ function distributed_tensor_parallelism() {
uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
}

function distributed_FSDP2() {
uv run bash run_example.sh example.py || error "FSDP2 example failed"
}

function distributed_ddp() {
uv run main.py || error "ddp example failed"
}
Expand Down