@@ -23,6 +23,7 @@ class ParallelDims:
2323 tp : int
2424 pp : int
2525 ep : int
26+ etp : int
2627 world_size : int
2728
2829 _world_mesh : DeviceMesh = None
@@ -31,18 +32,19 @@ def __post_init__(self):
3132 self ._validate ()
3233
3334 def _validate (self ):
34- dp_replicate , dp_shard , cp , tp , pp , ep = (
35+ dp_replicate , dp_shard , cp , tp , pp , ep , etp = (
3536 self .dp_replicate ,
3637 self .dp_shard ,
3738 self .cp ,
3839 self .tp ,
3940 self .pp ,
4041 self .ep ,
42+ self .etp ,
4143 )
42- for d in (dp_replicate , cp , tp , pp , ep ):
44+ for d in (dp_replicate , cp , tp , pp , ep , etp ):
4345 assert d >= 1 , "Parallelism degree should be >= 1, except for dp_shard"
4446
45- assert dp_shard == - 1 or dp_shard >= 1 , " dp_shard must -1 or >=1."
47+ assert dp_shard == - 1 or dp_shard >= 1 , "dp_shard must be -1 or >=1."
4648 if dp_shard < 0 :
4749 self .dp_shard = dp_shard = self .world_size // (dp_replicate * cp * tp * pp )
4850 assert dp_shard >= 1
@@ -53,8 +55,13 @@ def _validate(self):
5355 )
5456
5557 if ep > 1 :
56- # EP would borrow all cp and some dp_shard degree
57- assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
58+ assert etp == tp or etp == 1 , "Currently we only support ETP=TP or ETP=1"
59+ if etp == tp :
60+ # EP would borrow all cp and some dp_shard degree
61+ assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
62+ elif etp == 1 :
63+ # EP would borrow all cp and tp and some dp_shard degree
64+ assert ep % (cp * tp ) == 0 and (dp_shard * cp * tp ) % ep == 0
5865
5966 def build_mesh (self ) -> DeviceMesh :
6067 # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
@@ -68,9 +75,15 @@ def build_mesh(self) -> DeviceMesh:
6875 def _build_mesh_with_ep (self ) -> DeviceMesh :
6976 # With ep, dp_shard and ep are derived submeshes:
7077 # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
71- # ep = dp_shard_in_ep * cp
72- dp_shard_mod_ep = self .dp_shard * self .cp // self .ep
73- dp_shard_in_ep = self .ep // self .cp
78+ if self .etp == self .tp :
79+ # ep = dp_shard_in_ep * cp
80+ dp_shard_mod_ep = self .dp_shard * self .cp // self .ep
81+ dp_shard_in_ep = self .ep // self .cp
82+ else :
83+ assert self .etp == 1
84+ # ep = dp_shard_in_ep * cp * tp
85+ dp_shard_mod_ep = self .dp_shard * self .cp * self .tp // self .ep
86+ dp_shard_in_ep = self .ep // (self .cp * self .tp )
7487
7588 dims = []
7689 names = []
@@ -121,6 +134,8 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
121134 dp_shard_cp_mesh_dim_names .append ("cp" )
122135 dp_cp_mesh_dim_names .append ("cp" )
123136 ep_mesh_dim_names .append ("cp" )
137+ if self .etp == 1 and self .tp_enabled :
138+ ep_mesh_dim_names .append ("tp" )
124139
125140 mesh [tuple (dp_mesh_dim_names )]._flatten (mesh_dim_name = "dp" )
126141 mesh [tuple (dp_shard_cp_mesh_dim_names )]._flatten (mesh_dim_name = "dp_shard_cp" )
@@ -218,6 +233,10 @@ def pp_enabled(self):
218233 def ep_enabled (self ):
219234 return self .ep > 1
220235
236+ @property
237+ def etp_enabled (self ):
238+ return self .etp > 1
239+
221240 @property
222241 def fsdp_gradient_divide_factor (self ) -> int :
223242 # This is needed for FSDP-sharded experts when Expert Parallel is enabled.
0 commit comments