
Commit a9e1e2f

[moe training] add benchmark script for moe layer
stack-info: PR: #2671, branch: danielvegamyhre/stack/29
1 parent 546d90d commit a9e1e2f

File tree

torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py

1 file changed: 179 additions, 0 deletions
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#######################################################################
#
# To run these benchmarks, use the following command:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py
#
#######################################################################
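
# To also capture PyTorch profiler traces, append the --profile flag (defined in the
# argparse block at the bottom of this file) to the same torchrun command, e.g.:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py --profile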

import argparse
import copy
import os
import statistics
from time import perf_counter_ns

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this benchmark requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.expert_parallel import (
        set_token_group_alignment_size_m,
    )
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    pytest.skip(
        "torchtitan not installed, skipping MoE tests.", allow_module_level=True
    )


def bench_moe_float8_training_fsdp(enable_profile=False):
    assert torch.cuda.is_available()

    # setup distributed for fsdp
    setup_distributed()

    # define model args
    target_fqns = ["experts"]
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=16,
        dim=5120,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
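
    # Assumed behavior of the prototype (see MoETrainingConfig): quantize_ swaps the
    # parameters of the matched "experts" modules for a training tensor subclass so
    # their grouped GEMMs run via float8 rowwise kernels in forward/backward, while
    # ref_model keeps plain bf16 parameters for comparison.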

    # FSDP2
    fully_shard(model)
    fully_shard(ref_model)

    # inputs (llama4 shapes)
    batch, seq, dim = 1, 8192, 5120
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    def bench_fn_microseconds(model, input):
        labels = torch.ones_like(input)
        times = []
        for _ in range(10):
            start_ns = perf_counter_ns()
            out = model(input)
            loss = F.mse_loss(out, labels)
            loss.backward()
            torch.cuda.synchronize()
            end_ns = perf_counter_ns()
            duration_us = (end_ns - start_ns) / 1000
            times.append(duration_us)
        return statistics.median(times)
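
    # Timing notes: each iteration measures forward + MSE loss + backward, and the
    # torch.cuda.synchronize() before the end timestamp ensures queued GPU work is
    # counted. The first of the 10 iterations also pays the torch.compile warm-up
    # cost, which the median keeps out of the reported number.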

    def profile_fn(model, input, profile_name="profile"):
        # Only profile on rank 0
        if torch.distributed.get_rank() == 0:
            labels = torch.ones_like(input)
            wait, warmup, active = 1, 3, 1
            total_steps = wait + warmup + active
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                schedule=torch.profiler.schedule(
                    wait=wait, warmup=warmup, active=active, repeat=0
                ),
                record_shapes=True,
                with_stack=True,
            ) as prof:
                for _ in range(total_steps):
                    out = model(input)
                    loss = F.mse_loss(out, labels)
                    loss.backward()
                    prof.step()

            # Save profiler results
            prof.export_chrome_trace(f"{profile_name}.json")
            print(f"Saved: {profile_name}.json")

    # Compile models
    ref_model = torch.compile(ref_model, fullgraph=False)
    model = torch.compile(model, fullgraph=False)
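    # fullgraph=False lets torch.compile tolerate graph breaks and fall back to eager
    # for those regions; the MoE's data-dependent token routing is a likely source of
    # breaks (an assumption, not verified here).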

    print("Benchmarking MoE with FSDP2 using bf16 training")
    bf16_us = bench_fn_microseconds(ref_model, ref_x)
    print(f"bf16 time: {bf16_us} us")
    if enable_profile:
        print("Profiling bf16 model")
        profile_fn(ref_model, ref_x, profile_name="bf16_profile")

    # Token group alignment size must be 16 for fp8 rowwise training
    set_token_group_alignment_size_m(16)
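    # The alignment is raised only now so that the bf16 run above used torchtitan's
    # default token group alignment; fp8 rowwise grouped GEMMs need groups padded to
    # multiples of 16 (assumption based on the comment above and the prototype docs).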

    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
    fp8_us = bench_fn_microseconds(model, x)
    print(f"fp8 time: {fp8_us} us")
    if enable_profile:
        print("Profiling fp8 model")
        profile_fn(model, x, profile_name="fp8_profile")

    dist.destroy_process_group()


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
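    # RANK and WORLD_SIZE are provided by torchrun; using the global rank as the CUDA
    # device index assumes a single-node run where global and local ranks coincide.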


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark MoE layer with FSDP2")
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable PyTorch profiling and save results to file",
    )
    args = parser.parse_args()
    bench_moe_float8_training_fsdp(enable_profile=args.profile)
