
Commit 45e4583

[moe training] add benchmark script for moe layer
stack-info: PR: #2671, branch: danielvegamyhre/stack/29
1 parent 5dc7338 commit 45e4583

File tree

torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py

1 file changed: 172 additions, 0 deletions
@@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
######################################################################
#
# To run these benchmarks, use the following command:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 torchao/prototype/moe_training/benchmarks/benchmark_moe_layer.py
#
#######################################################################
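#
# Optionally, append --profile to the command above to also capture PyTorch
# profiler traces; rank 0 saves them as bf16_profile.json and fp8_profile.json
# (see the --profile flag parsed in __main__ below).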

import argparse
import copy
import os
import statistics
from time import perf_counter_ns

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    pytest.skip(
        "torchtitan not installed, skipping MoE tests.", allow_module_level=True
    )


def bench_moe_float8_training_fsdp(enable_profile=False):
    assert torch.cuda.is_available()

    # setup distributed for fsdp
    setup_distributed()

    # define model args
    target_fqns = ["experts"]
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=16,
        dim=5120,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
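    # the filtered expert modules now use the float8 rowwise training path
    # (benchmarked as "fp8 rowwise" below); the bf16 reference model is left unchanged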

    # FSDP2
    fully_shard(model)
    fully_shard(ref_model)

    # inputs (llama4 shapes)
    batch, seq, dim = 1, 8192, 5120
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    def bench_fn_microseconds(model, input):
        labels = torch.ones_like(input)
        times = []
        for _ in range(10):
            # time a full forward + backward step
            start_ns = perf_counter_ns()
            out = model(input)
            loss = F.mse_loss(out, labels)
            loss.backward()
            # wait for queued GPU work to finish before stopping the timer
            torch.cuda.synchronize()
            end_ns = perf_counter_ns()
            duration_us = (end_ns - start_ns) / 1000
            times.append(duration_us)
        # report the median over 10 iterations to reduce the impact of outliers
        return statistics.median(times)

    def profile_fn(model, input, profile_name="profile"):
        # Only profile on rank 0
        if torch.distributed.get_rank() == 0:
            labels = torch.ones_like(input)
            wait, warmup, active = 1, 3, 1
            total_steps = wait + warmup + active
            with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                schedule=torch.profiler.schedule(
                    wait=wait, warmup=warmup, active=active, repeat=0
                ),
                record_shapes=True,
                with_stack=True,
            ) as prof:
                for _ in range(total_steps):
                    out = model(input)
                    loss = F.mse_loss(out, labels)
                    loss.backward()
                    prof.step()

            # Save profiler results
            prof.export_chrome_trace(f"{profile_name}.json")
            print(f"Saved: {profile_name}.json")

    # Compile models
    ref_model, model = torch.compile(ref_model), torch.compile(model)

    print("Benchmarking MoE with FSDP2 using bf16 training")
    bf16_us = bench_fn_microseconds(ref_model, ref_x)
    print(f"bf16 time: {bf16_us} us")
    if enable_profile:
        print("Profiling bf16 model")
        profile_fn(ref_model, ref_x, profile_name="bf16_profile")

    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
    fp8_us = bench_fn_microseconds(model, x)
    print(f"fp8 time: {fp8_us} us")
    if enable_profile:
        print("Profiling fp8 model")
        profile_fn(model, x, profile_name="fp8_profile")

    dist.destroy_process_group()


def setup_distributed():
    # RANK and WORLD_SIZE are set by torchrun
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark MoE layer with FSDP2")
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable PyTorch profiling and save results to file",
    )
    args = parser.parse_args()
    bench_moe_float8_training_fsdp(enable_profile=args.profile)
