File tree Expand file tree Collapse file tree 5 files changed +12
-6
lines changed Expand file tree Collapse file tree 5 files changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -110,7 +110,6 @@ def __init__(
110
110
assert len (
111
111
self .initial_local_expert_ids ) == self .expert_size_per_partition
112
112
113
- max_num_tokens = model_config .max_num_tokens
114
113
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
115
114
moe_max_num_tokens = model_config .max_num_tokens * model_config .mapping .dp_size
116
115
self .moe_max_num_tokens = model_config .moe_max_num_tokens or moe_max_num_tokens
Original file line number Diff line number Diff line change @@ -339,7 +339,9 @@ def __init__(
339
339
# It can avoid OOM for 8k/1k cases.
340
340
default_moe_max_num_tokens = 18688
341
341
if moe_max_num_tokens > default_moe_max_num_tokens :
342
+ model_config ._frozen = False
342
343
model_config .moe_max_num_tokens = default_moe_max_num_tokens
344
+ model_config ._frozen = True
343
345
344
346
super ().__init__ (
345
347
routing_method = routing_method ,
@@ -600,9 +602,12 @@ def forward(
600
602
else :
601
603
num_rows = x .shape [0 ]
602
604
603
- # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
604
- num_chunks = (num_rows + self .moe_max_num_tokens -
605
- 1 ) // self .moe_max_num_tokens
605
+ # If num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks.
606
+ # This is because chunked MoE uses two streams and preallocates two workspaces.
607
+ num_chunks = 1
608
+ if num_rows > self .moe_max_num_tokens * 2 :
609
+ num_chunks = (num_rows + self .moe_max_num_tokens -
610
+ 1 ) // self .moe_max_num_tokens
606
611
607
612
if use_dp_padding :
608
613
all_rank_num_tokens_padded = [all_rank_max_num_tokens
Original file line number Diff line number Diff line change @@ -81,7 +81,6 @@ def __init__(
81
81
self .num_experts )
82
82
self .expert_size_per_partition = self .expert_end - self .expert_start
83
83
84
- max_num_tokens = model_config .max_num_tokens
85
84
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
86
85
moe_max_num_tokens = model_config .max_num_tokens * model_config .mapping .dp_size
87
86
self .moe_max_num_tokens = model_config .moe_max_num_tokens or moe_max_num_tokens
Original file line number Diff line number Diff line change @@ -150,7 +150,6 @@ def __init__(
150
150
assert len (
151
151
self .initial_local_expert_ids ) == self .expert_size_per_partition
152
152
153
- max_num_tokens = model_config .max_num_tokens
154
153
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
155
154
moe_max_num_tokens = model_config .max_num_tokens * model_config .mapping .dp_size
156
155
self .moe_max_num_tokens = model_config .moe_max_num_tokens or moe_max_num_tokens
Original file line number Diff line number Diff line change @@ -372,6 +372,10 @@ def node_rank(self):
372
372
def local_rank (self ):
373
373
return self .rank % self .gpus_per_node
374
374
375
+ @property
376
+ def dp_size (self ):
377
+ return self .tp_size if self .enable_attention_dp else 1
378
+
375
379
def has_cp (self ):
376
380
return self .cp_size > 1
377
381
You can’t perform that action at this time.
0 commit comments