|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +###################################################################### |
| 7 | +# |
| 8 | +# To run these benchmarks, use the following command: |
| 9 | +# |
| 10 | +# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py |
| 11 | +# |
| 12 | +####################################################################### |
| 13 | + |
| 14 | +import argparse |
| 15 | +import copy |
| 16 | +import os |
| 17 | +import statistics |
| 18 | +from time import perf_counter_ns |
| 19 | + |
| 20 | +import pytest |
| 21 | +import torch |
| 22 | +from torch import distributed as dist |
| 23 | +from torch import nn |
| 24 | +from torch.distributed._composable.fsdp import fully_shard |
| 25 | +from torch.nn import functional as F |
| 26 | + |
| 27 | +# this feature requires CUDA and SM89+ |
| 28 | +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): |
| 29 | + pytest.skip( |
| 30 | + "CUDA not available or compute capability < 8.9", allow_module_level=True |
| 31 | + ) |
| 32 | + |
| 33 | +from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig |
| 34 | +from torchao.quantization.quant_api import quantize_ |
| 35 | + |
| 36 | +# this test requires torchtitan |
| 37 | +try: |
| 38 | + from torchtitan.experiments.llama4.infra.expert_parallel import ( |
| 39 | + set_token_group_alignment_size_m, |
| 40 | + ) |
| 41 | + from torchtitan.experiments.llama4.model.args import TransformerModelArgs |
| 42 | + from torchtitan.experiments.llama4.model.moe import MoE |
| 43 | +except ImportError: |
| 44 | + pytest.skip( |
| 45 | + "torchtitan not installed, skipping MoE tests.", allow_module_level=True |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def bench_moe_float8_training_fsdp(enable_profile=False): |
| 50 | + assert torch.cuda.is_available() |
| 51 | + |
| 52 | + # setup distributed for fsdp |
| 53 | + setup_distributed() |
| 54 | + |
| 55 | + # define model args |
| 56 | + target_fqns = ["experts"] |
| 57 | + model_args = TransformerModelArgs( |
| 58 | + moe_enabled=True, |
| 59 | + num_experts=16, |
| 60 | + dim=5120, |
| 61 | + ) |
| 62 | + init_std = 0.02 |
| 63 | + device = torch.device("cuda") |
| 64 | + |
| 65 | + # reference bf16 MoE |
| 66 | + ref_model = MoE(model_args).to(torch.bfloat16).cuda() |
| 67 | + torch.manual_seed(42) |
| 68 | + ref_model.init_weights(init_std, device) |
| 69 | + |
| 70 | + # target MoE for testing conversion |
| 71 | + model = copy.deepcopy(ref_model) |
| 72 | + |
| 73 | + # assert starting params are identical for both models |
| 74 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 75 | + assert torch.equal(param1, param2) |
| 76 | + |
| 77 | + # convert MoE to float8 training |
| 78 | + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: |
| 79 | + for target_fqn in target_fqns: |
| 80 | + if target_fqn in cur_fqn: |
| 81 | + return True |
| 82 | + return False |
| 83 | + |
| 84 | + # quantize test model |
| 85 | + config = MoETrainingConfig() |
| 86 | + quantize_(model, config=config, filter_fn=moe_module_filter_fn) |
| 87 | + |
| 88 | + # FSDP2 |
| 89 | + fully_shard(model) |
| 90 | + fully_shard(ref_model) |
| 91 | + |
| 92 | + # inputs (llama4 shapes) |
| 93 | + batch, seq, dim = 1, 8192, 5120 |
| 94 | + ref_x = torch.randn( |
| 95 | + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device |
| 96 | + ) |
| 97 | + x = ref_x.detach().clone().requires_grad_(True) |
| 98 | + |
| 99 | + def bench_fn_microseconds(model, input): |
| 100 | + labels = torch.ones_like(input) |
| 101 | + times = [] |
| 102 | + for _ in range(10): |
| 103 | + start_ns = perf_counter_ns() |
| 104 | + out = model(input) |
| 105 | + loss = F.mse_loss(out, labels) |
| 106 | + loss.backward() |
| 107 | + torch.cuda.synchronize() |
| 108 | + end_ns = perf_counter_ns() |
| 109 | + duration_us = (end_ns - start_ns) / 1000 |
| 110 | + times.append(duration_us) |
| 111 | + return statistics.median(times) |
| 112 | + |
| 113 | + def profile_fn(model, input, profile_name="profile"): |
| 114 | + # Only profile on rank 0 |
| 115 | + if torch.distributed.get_rank() == 0: |
| 116 | + labels = torch.ones_like(input) |
| 117 | + wait, warmup, active = 1, 3, 1 |
| 118 | + total_steps = wait + warmup + active |
| 119 | + with torch.profiler.profile( |
| 120 | + activities=[ |
| 121 | + torch.profiler.ProfilerActivity.CPU, |
| 122 | + torch.profiler.ProfilerActivity.CUDA, |
| 123 | + ], |
| 124 | + schedule=torch.profiler.schedule( |
| 125 | + wait=wait, warmup=warmup, active=active, repeat=0 |
| 126 | + ), |
| 127 | + record_shapes=True, |
| 128 | + with_stack=True, |
| 129 | + ) as prof: |
| 130 | + for _ in range(total_steps): |
| 131 | + out = model(input) |
| 132 | + loss = F.mse_loss(out, labels) |
| 133 | + loss.backward() |
| 134 | + prof.step() |
| 135 | + |
| 136 | + # Save profiler results |
| 137 | + prof.export_chrome_trace(f"{profile_name}.json") |
| 138 | + print(f"Saved: {profile_name}.json") |
| 139 | + |
| 140 | + # Compile models |
| 141 | + ref_model = torch.compile(ref_model, fullgraph=False) |
| 142 | + model = torch.compile(model, fullgraph=False) |
| 143 | + |
| 144 | + print("Benchmarking MoE with FSDP2 using bf16 training") |
| 145 | + bf16_us = bench_fn_microseconds(ref_model, ref_x) |
| 146 | + print(f"bf16 time: {bf16_us} us") |
| 147 | + if enable_profile: |
| 148 | + print("Profiling bf16 model") |
| 149 | + profile_fn(ref_model, ref_x, profile_name="bf16_profile") |
| 150 | + |
| 151 | + # Token group alignment size must be 16 for fp8 rowwise training |
| 152 | + set_token_group_alignment_size_m(16) |
| 153 | + |
| 154 | + print("Benchmarking MoE with FSDP2 using fp8 rowwise training") |
| 155 | + fp8_us = bench_fn_microseconds(model, x) |
| 156 | + print(f"fp8 time: {fp8_us} us") |
| 157 | + if enable_profile: |
| 158 | + print("Profiling fp8 model") |
| 159 | + profile_fn(model, x, profile_name="fp8_profile") |
| 160 | + |
| 161 | + dist.destroy_process_group() |
| 162 | + |
| 163 | + |
| 164 | +def setup_distributed(): |
| 165 | + rank = int(os.environ["RANK"]) |
| 166 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 167 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 168 | + torch.cuda.set_device(rank) |
| 169 | + |
| 170 | + |
| 171 | +if __name__ == "__main__": |
| 172 | + parser = argparse.ArgumentParser(description="Benchmark MoE layer with FSDP2") |
| 173 | + parser.add_argument( |
| 174 | + "--profile", |
| 175 | + action="store_true", |
| 176 | + help="Enable PyTorch profiling and save results to file", |
| 177 | + ) |
| 178 | + args = parser.parse_args() |
| 179 | + bench_moe_float8_training_fsdp(enable_profile=args.profile) |
0 commit comments