@@ -23,6 +23,7 @@ class ParallelDims:
23
23
tp : int
24
24
pp : int
25
25
ep : int
26
+ etp : int
26
27
world_size : int
27
28
28
29
_world_mesh : DeviceMesh = None
@@ -31,18 +32,19 @@ def __post_init__(self):
31
32
self ._validate ()
32
33
33
34
def _validate (self ):
34
- dp_replicate , dp_shard , cp , tp , pp , ep = (
35
+ dp_replicate , dp_shard , cp , tp , pp , ep , etp = (
35
36
self .dp_replicate ,
36
37
self .dp_shard ,
37
38
self .cp ,
38
39
self .tp ,
39
40
self .pp ,
40
41
self .ep ,
42
+ self .etp ,
41
43
)
42
- for d in (dp_replicate , cp , tp , pp , ep ):
44
+ for d in (dp_replicate , cp , tp , pp , ep , etp ):
43
45
assert d >= 1 , "Parallelism degree should be >= 1, except for dp_shard"
44
46
45
- assert dp_shard == - 1 or dp_shard >= 1 , " dp_shard must -1 or >=1."
47
+ assert dp_shard == - 1 or dp_shard >= 1 , "dp_shard must -1 or >=1."
46
48
if dp_shard < 0 :
47
49
self .dp_shard = dp_shard = self .world_size // (dp_replicate * cp * tp * pp )
48
50
assert dp_shard >= 1
@@ -53,8 +55,13 @@ def _validate(self):
53
55
)
54
56
55
57
if ep > 1 :
56
- # EP would borrow all cp and some dp_shard degree
57
- assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
58
+ assert etp == tp or etp == 1 , "Currently we only support ETP=TP or ETP=1"
59
+ if etp == tp :
60
+ # EP would borrow all cp and some dp_shard degree
61
+ assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
62
+ elif etp == 1 :
63
+ # EP would borrow all cp and tp and some dp_shard degree
64
+ assert ep % (cp * tp ) == 0 and (dp_shard * cp * tp ) % ep == 0
58
65
59
66
def build_mesh (self ) -> DeviceMesh :
60
67
# TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
@@ -68,9 +75,15 @@ def build_mesh(self) -> DeviceMesh:
68
75
def _build_mesh_with_ep (self ) -> DeviceMesh :
69
76
# With ep, dp_shard and ep are derived submeshes:
70
77
# dp_shard = dp_shard_mod_ep * dp_shard_in_ep
71
- # ep = dp_shard_in_ep * cp
72
- dp_shard_mod_ep = self .dp_shard * self .cp // self .ep
73
- dp_shard_in_ep = self .ep // self .cp
78
+ if self .etp == self .tp :
79
+ # ep = dp_shard_in_ep * cp
80
+ dp_shard_mod_ep = self .dp_shard * self .cp // self .ep
81
+ dp_shard_in_ep = self .ep // self .cp
82
+ else :
83
+ assert self .etp == 1
84
+ # ep = dp_shard_in_ep * cp * tp
85
+ dp_shard_mod_ep = self .dp_shard * self .cp * self .tp // self .ep
86
+ dp_shard_in_ep = self .ep // (self .cp * self .tp )
74
87
75
88
dims = []
76
89
names = []
@@ -121,6 +134,8 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
121
134
dp_shard_cp_mesh_dim_names .append ("cp" )
122
135
dp_cp_mesh_dim_names .append ("cp" )
123
136
ep_mesh_dim_names .append ("cp" )
137
+ if self .etp == 1 and self .tp_enabled :
138
+ ep_mesh_dim_names .append ("tp" )
124
139
125
140
mesh [tuple (dp_mesh_dim_names )]._flatten (mesh_dim_name = "dp" )
126
141
mesh [tuple (dp_shard_cp_mesh_dim_names )]._flatten (mesh_dim_name = "dp_shard_cp" )
@@ -218,6 +233,10 @@ def pp_enabled(self):
218
233
def ep_enabled(self):
    """Tell whether Expert Parallel is in use, i.e. the ``ep`` degree exceeds 1."""
    expert_degree = self.ep
    return expert_degree > 1
220
235
236
@property
def etp_enabled(self):
    """Tell whether the ``etp`` degree exceeds 1 (a degree of 1 means disabled)."""
    etp_degree = self.etp
    return etp_degree > 1
239
+
221
240
@property
222
241
def fsdp_gradient_divide_factor (self ) -> int :
223
242
# This is needed for FSDP-sharded experts when Expert Parallel is enabled.
0 commit comments