Commit aeb3a4b
[EP] add support for ETP=1 (#1555)
This is a follow-up of the original EP support in #1324.

### PR summary

[TBA] description + figure

### Numerics verification

Setup:
- optimizer: Adam
- steps: 100, warmup_steps: 20
- seed: 42

Comparison set:
- FSDP 2
- FSDP 2, CP 2, TP 2, EP 8, ETP 1
- FSDP 2 (EP 2), PP 2, TP 2 (ETP 2)

(figure: loss-curve comparison of the configurations above)
1 parent 48b6520 commit aeb3a4b

File tree: 17 files changed, +230 −45 lines

scripts/estimate/estimation.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def estimate_memory(job_config: JobConfig):
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
+        etp=parallelism_config.expert_tensor_parallel_degree,
         world_size=world_size,
     )
     # ParallelDims.build_mesh has to happen outside of the FakeTensorMode

scripts/generate/test_generate.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def test_generate(
         tp=world_size,
         pp=1,
         ep=1,
+        etp=1,
         world_size=world_size,
     )
     world_mesh = parallel_dims.world_mesh

tests/unit_tests/test_model_converter.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
+        etp=parallelism_config.expert_tensor_parallel_degree,
         world_size=world_size,
     )
     return parallel_dims

torchtitan/config/job_config.py

Lines changed: 21 additions & 3 deletions
@@ -374,9 +374,27 @@ class Parallelism:
 
     expert_parallel_degree: int = 1
     """
-    Expert parallelism degree. 1 means disabled.
-    Currently, only "dp2ep" is supported, with the following constraints:
-    context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
+    Expert parallelism degree. 1 means disabled. No effect for non-MoE models.
+    Currently, it is supported with the following constraints:
+    - when etp = tp:
+        - cp <= ep <= dp_shard * cp
+        - ep % cp == 0
+        - dp_shard * cp % ep == 0
+    - when etp = 1:
+        - cp * tp <= ep <= dp_shard * cp * tp
+        - ep % (cp * tp) == 0
+        - dp_shard * cp * tp % ep == 0
+    Note that this is still an experimental feature. Some constraints will be
+    relaxed soon when we have more flexible DeviceMesh support.
+    """
+
+    expert_tensor_parallel_degree: int = 1
+    """
+    Expert tensor parallelism degree. 1 means disabled. No effect for non-MoE models, or when ep = 1.
+    With this option, the tensor parallel degree on routed experts can be different from that on other params.
+    Currently, we only support either
+    - [partial dp -> ep] etp = tp
+    - [partial dp + all tp -> ep] etp = 1
     Note that this is still an experimental feature.
     """
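
To make the constraints above concrete, here is a minimal standalone sketch (a hypothetical helper, not torchtitan code) that checks the documented rules against the ETP=1 configuration used in the numerics verification (dp_shard=2, cp=2, tp=2, ep=8, etp=1):

```python
# Hypothetical helper illustrating the documented EP/ETP constraints.
def check_ep_constraints(dp_shard: int, cp: int, tp: int, ep: int, etp: int) -> bool:
    if ep == 1:
        return True  # EP disabled, nothing to check
    if etp == tp:
        # EP borrows all of cp and part of dp_shard
        return cp <= ep <= dp_shard * cp and ep % cp == 0 and (dp_shard * cp) % ep == 0
    if etp == 1:
        # EP borrows all of cp and tp and part of dp_shard
        return (
            cp * tp <= ep <= dp_shard * cp * tp
            and ep % (cp * tp) == 0
            and (dp_shard * cp * tp) % ep == 0
        )
    return False  # only etp == tp or etp == 1 is supported


# The ETP=1 config from the commit message: 2 * 2 * 2 = 8 ranks, all borrowed by EP.
assert check_ep_constraints(dp_shard=2, cp=2, tp=2, ep=8, etp=1)
# With etp == tp, the same degrees would violate ep <= dp_shard * cp (8 > 4).
assert not check_ep_constraints(dp_shard=2, cp=2, tp=2, ep=8, etp=2)
```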

torchtitan/distributed/expert_parallel.py

Lines changed: 49 additions & 0 deletions
@@ -363,3 +363,52 @@ def wrapper(
         return out
 
     return wrapper
+
+
+# This class is to support Sequence Parallel for ETP=1
+# when EP borrows from all TP and part of DP
+class ReordererSequenceParallel(ParallelStyle):
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        top_scores, selected_experts_indices = inputs
+
+        top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),))
+        selected_experts_indices = DTensor.from_local(
+            selected_experts_indices, device_mesh, (Replicate(),)
+        )
+
+        # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
+        # if top_scores.shape[0] % device_mesh.size() != 0:
+        #     num_tokens = top_scores.shape[0]
+        #     tp_size = device_mesh.size()
+        #     n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens
+        #     selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad])
+        #     top_scores = F.pad(top_scores, [0, 0, 0, n_pad])
+        assert top_scores.shape[0] % device_mesh.size() == 0
+
+        # split on the bs*slen dimension
+        top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local()
+        selected_experts_indices = selected_experts_indices.redistribute(
+            device_mesh, (Shard(0),)
+        ).to_local()
+
+        return top_scores, selected_experts_indices
+
+    def _prepare_output_fn(self, mod, outputs, device_mesh):
+        top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs
+
+        # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
+        # the MoE gather and scatter still require global token indices.
+        num_tokens = top_scores.shape[0]
+        local_rank = device_mesh.get_local_rank()
+        token_indices_experts_sorted += num_tokens // device_mesh.size() * local_rank
+
+        return top_scores, token_indices_experts_sorted, num_tokens_per_expert
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn=None,
+            input_fn=self._prepare_input_fn,
+            output_fn=self._prepare_output_fn,
+        )
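
A standalone toy illustration (hypothetical sizes, not the exact torchtitan data flow) of why `_prepare_output_fn` applies a rank-dependent offset: each TP rank reorders only its local shard of the folded bs*slen tokens, so indices computed locally have to be shifted into the global range that the MoE gather/scatter expects.

```python
import torch

# Toy example: 8 routed token slots in total, TP degree 2,
# so each rank holds 4 slots after the Shard(0) redistribute.
num_tokens, tp_size = 8, 2
tokens_per_rank = num_tokens // tp_size

for local_rank in range(tp_size):
    # Indices produced by the local reorder are in [0, tokens_per_rank).
    local_indices = torch.tensor([2, 0, 3, 1])
    # Shift them into the global [0, num_tokens) range, analogous to the
    # offset added in _prepare_output_fn above.
    global_indices = local_indices + tokens_per_rank * local_rank
    print(local_rank, global_indices.tolist())
# rank 0 -> [2, 0, 3, 1]; rank 1 -> [6, 4, 7, 5]
```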

torchtitan/distributed/parallel_dims.py

Lines changed: 27 additions & 8 deletions
@@ -23,6 +23,7 @@ class ParallelDims:
     tp: int
     pp: int
     ep: int
+    etp: int
     world_size: int
 
     _world_mesh: DeviceMesh = None
@@ -31,18 +32,19 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp, ep = (
+        dp_replicate, dp_shard, cp, tp, pp, ep, etp = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
             self.ep,
+            self.etp,
         )
-        for d in (dp_replicate, cp, tp, pp, ep):
+        for d in (dp_replicate, cp, tp, pp, ep, etp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
 
-        assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
+        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must -1 or >=1."
         if dp_shard < 0:
             self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp)
         assert dp_shard >= 1
@@ -53,8 +55,13 @@ def _validate(self):
         )
 
         if ep > 1:
-            # EP would borrow all cp and some dp_shard degree
-            assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+            assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1"
+            if etp == tp:
+                # EP would borrow all cp and some dp_shard degree
+                assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+            elif etp == 1:
+                # EP would borrow all cp and tp and some dp_shard degree
+                assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0
 
     def build_mesh(self) -> DeviceMesh:
         # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
@@ -68,9 +75,15 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
     def _build_mesh_with_ep(self) -> DeviceMesh:
         # With ep, dp_shard and ep are derived submeshes:
         # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
-        # ep = dp_shard_in_ep * cp
-        dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
-        dp_shard_in_ep = self.ep // self.cp
+        if self.etp == self.tp:
+            # ep = dp_shard_in_ep * cp
+            dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
+            dp_shard_in_ep = self.ep // self.cp
+        else:
+            assert self.etp == 1
+            # ep = dp_shard_in_ep * cp * tp
+            dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep
+            dp_shard_in_ep = self.ep // (self.cp * self.tp)
 
         dims = []
         names = []
@@ -121,6 +134,8 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
             dp_shard_cp_mesh_dim_names.append("cp")
             dp_cp_mesh_dim_names.append("cp")
             ep_mesh_dim_names.append("cp")
+        if self.etp == 1 and self.tp_enabled:
+            ep_mesh_dim_names.append("tp")
 
         mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
         mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
@@ -218,6 +233,10 @@ def pp_enabled(self):
     def ep_enabled(self):
         return self.ep > 1
 
+    @property
+    def etp_enabled(self):
+        return self.etp > 1
+
     @property
     def fsdp_gradient_divide_factor(self) -> int:
         # This is needed for FSDP-sharded experts when Expert Parallel is enabled.
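
A minimal standalone sketch (hypothetical helper mirroring the branch added in `_build_mesh_with_ep` above) of how the two derived dp_shard submeshes work out for the ETP=1 verification config (dp_shard=2, cp=2, tp=2, ep=8):

```python
# Hypothetical helper reproducing the submesh arithmetic from _build_mesh_with_ep.
def derive_ep_submeshes(dp_shard: int, cp: int, tp: int, ep: int, etp: int):
    if etp == tp:
        # EP borrows all of cp and part of dp_shard: ep = dp_shard_in_ep * cp
        dp_shard_mod_ep = dp_shard * cp // ep
        dp_shard_in_ep = ep // cp
    else:  # etp == 1: EP also borrows all of tp, ep = dp_shard_in_ep * cp * tp
        dp_shard_mod_ep = dp_shard * cp * tp // ep
        dp_shard_in_ep = ep // (cp * tp)
    return dp_shard_mod_ep, dp_shard_in_ep


print(derive_ep_submeshes(dp_shard=2, cp=2, tp=2, ep=8, etp=1))  # (1, 2): 1 * 2 == dp_shard
print(derive_ep_submeshes(dp_shard=4, cp=2, tp=2, ep=4, etp=2))  # (2, 2): etp == tp branch
```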

torchtitan/experiments/forge/engine.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ def __init__(self, job_config: ForgeJobConfig):
             tp=parallelism_config.tensor_parallel_degree,
             pp=parallelism_config.pipeline_parallel_degree,
             ep=parallelism_config.expert_parallel_degree,
+            etp=parallelism_config.expert_tensor_parallel_degree,
             world_size=world_size,
         )
 
torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 16 additions & 4 deletions
@@ -25,6 +25,7 @@
     ExpertParallel,
     ExpertTensorParallel,
     NoParallel,
+    ReordererSequenceParallel,
     TensorParallel,
 )
 
@@ -87,17 +88,19 @@ def parallelize_llama(
             enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
 
-    # TODO: shall we support tensorwise float8 comms for MoE TP
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
                 world_mesh["ep", "tp"]
-                if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+                if parallel_dims.tp_enabled
+                and parallel_dims.ep_enabled
+                and parallel_dims.etp_enabled
                 else None
             ),
+            etp_enabled=parallel_dims.etp_enabled,
         )
 
     if job_config.activation_checkpoint.mode != "none":
@@ -344,6 +347,7 @@ def apply_moe_ep_tp(
     tp_mesh: DeviceMesh | None,
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
+    etp_enabled: bool,
 ):
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
@@ -365,13 +369,17 @@ def apply_moe_ep_tp(
             # input Replicate, output Partial
             "moe.shared_expert": TensorParallel(),
         }
+        if not etp_enabled:
+            # If TP is borrowed for EP, then split the tokens across TP ranks so that
+            # the reorderer, the all-to-all comms, and routed experts computation
+            # are effectively running Sequence Parallel (split along the folded bs*slen dim)
+            moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
         parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )
 
-        # if ep_mesh is not None:
         experts_mesh, experts_plan = None, None
        if ep_mesh is None:
            experts_mesh = tp_mesh
@@ -381,9 +389,13 @@ def apply_moe_ep_tp(
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
             experts_plan = ExpertParallel()
-        else:
+        elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+        else:
+            experts_mesh = ep_mesh
+            experts_plan = ExpertParallel()
+
         parallelize_module(
             module=transformer_block.moe.experts,
             device_mesh=experts_mesh,
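
To summarize the branches added above, here is a hedged standalone sketch (a hypothetical helper, not the torchtitan API; the EP-disabled plan is outside this hunk, so it is left unspecified) of which mesh and ParallelStyle the routed experts end up with:

```python
# Hypothetical summary of the experts mesh/plan selection in apply_moe_ep_tp.
def pick_experts_parallelism(tp_enabled: bool, ep_enabled: bool, etp_enabled: bool):
    if not ep_enabled:
        # EP disabled: experts are parallelized on the tp mesh (plan not shown in this hunk).
        return ("tp", None)
    if not tp_enabled:
        # EP only: ExpertParallel on the ep mesh.
        return ("ep", "ExpertParallel")
    if etp_enabled:
        # EP + TP with etp == tp: ExpertTensorParallel on the combined (ep, tp) mesh.
        return (("ep", "tp"), "ExpertTensorParallel")
    # New ETP=1 path: EP borrows all of TP, so routed experts use plain ExpertParallel
    # on the ep mesh while the reorderer runs ReordererSequenceParallel on the tp mesh.
    return ("ep", "ExpertParallel")


print(pick_experts_parallelism(tp_enabled=True, ep_enabled=True, etp_enabled=False))
# ('ep', 'ExpertParallel')
```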

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
 expert_parallel_degree = 1
+expert_tensor_parallel_degree = 1
 
 [checkpoint]
 enable_checkpoint = false

torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ pipeline_parallel_degree = 4
 # pipeline_parallel_schedule = "interleaved1f1b"
 # pipeline_parallel_microbatches = 2
 context_parallel_degree = 1
+expert_parallel_degree = 1
+expert_tensor_parallel_degree = 8
 
 [checkpoint]
 enable_checkpoint = false
