
Commit 816ba4e

token_dispatcher support expert_num 64 (#10905)
* token_dispatcher support expert_num 64
* token_dispatcher support expert_num 64
1 parent b2bb5d7 commit 816ba4e

File tree

  paddlenlp/trainer/trainer_utils.py
  paddlenlp/trainer/training_args.py
  paddlenlp/trainer/utils/zero_cost_checkpoint.py
  slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/utils.h

4 files changed: +16 additions, -16 deletions


paddlenlp/trainer/trainer_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1256,6 +1256,7 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
             f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
         )

+
 def parse_nccl_config_file(config_dir):
     json_file = Path(config_dir)
     if json_file.exists():

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 11 deletions
@@ -1223,9 +1223,7 @@ def __post_init__(self):
         if self.sharding_parallel_degree == -1:
             if len(self.sharding) > 0:
                 self.sharding_parallel_degree = world_size // (
-                    tensor_parallel_degree
-                    * sep_parallel_degree
-                    * pipeline_parallel_degree
+                    tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
                 )

         sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
@@ -1234,10 +1232,7 @@ def __post_init__(self):
             self.sharding = []

         self.data_parallel_degree = world_size // (
-            sharding_parallel_degree
-            * tensor_parallel_degree
-            * sep_parallel_degree
-            * pipeline_parallel_degree
+            sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )

         if expert_parallel_degree > 1:
@@ -1513,7 +1508,9 @@ def is_segment_parallel_supported():
        def is_context_parallel_supported():
            import inspect

-            members = [name for (name, date) in inspect.getmembers(fleet.base.topology.EPHybridCommunicateGroup)]
+            members = [
+                name for (name, date) in inspect.getmembers(fleet.base.topology.EPHybridCommunicateGroup)
+            ]
            support_cp = "get_context_parallel_world_size" in members
            if not support_cp:
                logger.warning("context parallel is not supported!!! Ignore it.")
@@ -1714,9 +1711,7 @@ def is_context_parallel_supported():
        if self.sharding_parallel_degree == -1:
            if len(self.sharding) > 0:
                self.sharding_parallel_degree = world_size // (
-                    self.tensor_parallel_degree
-                    * self.sep_parallel_degree
-                    * self.pipeline_parallel_degree
+                    self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
                )

        self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1)

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 8 additions & 5 deletions
@@ -191,10 +191,13 @@ def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
                _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
                updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
                self.ema_buffer_model_params[index] = updated_ema
-            logger.info(f"[ZCC EMA] accmulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
+            logger.info(
+                f"[ZCC EMA] accmulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done"
+            )
        else:
-            logger.info(f"[ZCC EMA] accmulating SKIP for global_step:{global_step}, because loss:{loss} > threshold:{zcc_ema_loss_threshold}")
-
+            logger.info(
+                f"[ZCC EMA] accmulating SKIP for global_step:{global_step}, because loss:{loss} > threshold:{zcc_ema_loss_threshold}"
+            )

    @imperative_base.no_grad()
    def ema_state_dict(self):
@@ -788,9 +791,9 @@ def process_offload_task(self, dump, global_step):

        if self.ema_coef is not None:
            self.zcc_ema_processor.ema_accumulate(
-                self.trainer_state.global_step, 
+                self.trainer_state.global_step,
                self.trainer_state.loss,
-                self.training_args_content.zcc_ema_loss_threshold
+                self.training_args_content.zcc_ema_loss_threshold,
            )

        # continue to process dumping task at the last chunk
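The substance of ema_accumulate is the exponential-moving-average update updated_ema = ema_coef * ema_buf + (1 - ema_coef) * cpu_buf, applied only when the step's loss stays under zcc_ema_loss_threshold. A minimal NumPy sketch of that update follows, using plain arrays in place of the fused parameter buffers; the helper name ema_update is illustrative, not part of PaddleNLP.

# Sketch of the ZCC EMA update shown in the hunk above (illustrative helper name).
import numpy as np

def ema_update(ema_buf, cpu_buf, ema_coef, loss, loss_threshold):
    """Blend current parameters into the EMA buffer unless the step's loss
    exceeds the threshold, in which case accumulation is skipped."""
    if loss_threshold is not None and loss > loss_threshold:
        return ema_buf  # skip accumulation for this step
    return ema_coef * ema_buf + (1 - ema_coef) * cpu_buf

ema = np.zeros(4, dtype=np.float32)
params = np.ones(4, dtype=np.float32)
ema = ema_update(ema, params, ema_coef=0.9, loss=2.3, loss_threshold=5.0)
print(ema)  # [0.1 0.1 0.1 0.1]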

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/utils.h

Lines changed: 1 addition & 0 deletions
@@ -124,5 +124,6 @@ __device__ __forceinline__ void vectorized_memcpy(const T* src,
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 8, __VA_ARGS__);  \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 16, __VA_ARGS__); \
    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 32, __VA_ARGS__); \
+    PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 64, __VA_ARGS__); \
    PD_THROW("Unsupported expert number %d", int(__num_expert)); \
  } while (0)
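This hunk is the core of the commit: the dispatch macro tries each supported compile-time expert count in turn and throws for anything else, and 64 is now one of the accepted counts. Below is a conceptual Python sketch of that dispatch pattern, not the actual C++ macro; only the counts visible in this hunk are listed, and the full macro may cover additional values.

# Conceptual sketch of the PD_SWITCH_NUM_EXPERTS dispatch (not the real macro).
SUPPORTED_NUM_EXPERTS = (8, 16, 32, 64)  # 64 is the value added by this commit

def dispatch_num_experts(num_experts, kernels):
    """Return the kernel specialized for num_experts, mirroring the chain of
    PD_SWITCH_NUM_EXPERTS_IMPL cases followed by PD_THROW."""
    for n in SUPPORTED_NUM_EXPERTS:
        if num_experts == n:
            return kernels[n]
    raise ValueError(f"Unsupported expert number {num_experts}")

# Example: pick the (hypothetical) kernel compiled for 64 experts.
kernels = {n: (lambda n=n: f"kernel<{n}>") for n in SUPPORTED_NUM_EXPERTS}
print(dispatch_num_experts(64, kernels)())  # kernel<64>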
