From a5e6b1e1094d478e0dddedefce997558aa2c765e Mon Sep 17 00:00:00 2001
From: Kunbo Ding
Date: Fri, 15 Aug 2025 11:12:37 +0800
Subject: [PATCH] Unified FuseHeadLoss

---
 paddlenlp/rl/models/ppo_model_utils.py   | 18 ++++++++++++++++++
 paddlenlp/transformers/qwen2/modeling.py | 17 ++++-------------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/paddlenlp/rl/models/ppo_model_utils.py b/paddlenlp/rl/models/ppo_model_utils.py
index 601e977591f7..86e016512072 100644
--- a/paddlenlp/rl/models/ppo_model_utils.py
+++ b/paddlenlp/rl/models/ppo_model_utils.py
@@ -30,6 +30,12 @@
 from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp
+except:
+    pass
+
+
 from ...transformers.llama.modeling import (
     LlamaPretrainingCriterion as PretrainingCriterion,
 )
@@ -436,11 +442,23 @@ def forward(
             logits = logits / self.temperature if self.temperature > 0.0 else logits
         else:
             hidden_states, weight, bias, transpose_y = logits
+
+            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+                hidden_states = GatherOp.apply(hidden_states)
+                hidden_states = hidden_states.reshape(
+                    [
+                        input_ids.shape[0],
+                        -1,
+                        hidden_states.shape[-1],
+                    ]
+                )
+
             if self.use_fp32_compute and hidden_states.dtype != paddle.float32:
                 hidden_states = hidden_states.cast(paddle.float32)
                 weight = weight.cast(paddle.float32)
                 if bias is not None:
                     bias = bias.cast(paddle.float32)
+
             total_loss, pg_loss, entropy_loss, kl_loss = actor_fused_pg_entropy_kl_loss(
                 hidden_states,
                 weight,
diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py
index 9328229f8702..fd422054cd2d 100644
--- a/paddlenlp/transformers/qwen2/modeling.py
+++ b/paddlenlp/transformers/qwen2/modeling.py
@@ -1484,6 +1484,10 @@ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=Fals
             self.weight.split_axis = 0 if self.transpose_y else 1
 
     def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
+        # add this for fused_head_and_loss_fn
+        if self.config.use_fused_head_and_loss_fn:
+            return hidden_states, self.weight, None, self.transpose_y
+
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
             hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
@@ -1670,19 +1674,6 @@ def forward(
 
         hidden_states = outputs[0]
 
-        # add this for fused_head_and_loss_fn
-        if self.config.use_fused_head_and_loss_fn and self.training:
-            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
-                hidden_states = GatherOp.apply(hidden_states)
-                hidden_states = hidden_states.reshape(
-                    [
-                        batch_size,
-                        -1,
-                        hidden_states.shape[-1],
-                    ]
-                )
-            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y
-
         # if labels is None,means we need full output, instead of tensor_parallel_output
         # tensor_parallel_output is together with ParallelCrossEntropy
         tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
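
Note on the pattern being unified: with `use_fused_head_and_loss_fn` enabled, the LM head no longer computes logits itself; it returns the raw ingredients `(hidden_states, weight, bias, transpose_y)` and the loss side is responsible for any sequence-parallel gather/reshape plus the fused projection-and-loss computation. The snippet below is a minimal, single-process sketch of that hand-off, not PaddleNLP's actual implementation: `ToyLMHead` and `fused_head_and_loss` are hypothetical stand-ins for `Qwen2LMHead.forward` and the RL criterion's call to `actor_fused_pg_entropy_kl_loss`, and the tensor-parallel `GatherOp` step is omitted.

```python
# Minimal single-process sketch of the "return the pieces, not the logits" pattern.
# Names here (ToyLMHead, fused_head_and_loss) are illustrative only.
import paddle


class ToyLMHead(paddle.nn.Layer):
    def __init__(self, hidden_size, vocab_size, use_fused_head_and_loss_fn=True):
        super().__init__()
        self.use_fused_head_and_loss_fn = use_fused_head_and_loss_fn
        self.transpose_y = True  # weight stored as [vocab_size, hidden_size]
        self.weight = self.create_parameter([vocab_size, hidden_size])

    def forward(self, hidden_states):
        if self.use_fused_head_and_loss_fn:
            # Defer the projection: hand the ingredients to the loss function,
            # which can fuse matmul + loss (and do any gather/reshape first).
            return hidden_states, self.weight, None, self.transpose_y
        return paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y)


def fused_head_and_loss(hidden_states, weight, bias, transpose_y, labels):
    # Stand-in for the fused kernel: project to logits, then cross-entropy.
    logits = paddle.matmul(hidden_states, weight, transpose_y=transpose_y)
    if bias is not None:
        logits = logits + bias
    return paddle.nn.functional.cross_entropy(
        logits.reshape([-1, logits.shape[-1]]), labels.reshape([-1])
    )


if __name__ == "__main__":
    head = ToyLMHead(hidden_size=16, vocab_size=32)
    h = paddle.randn([2, 8, 16])                  # [batch, seq, hidden]
    labels = paddle.randint(0, 32, shape=[2, 8])  # next-token targets
    loss = fused_head_and_loss(*head(h), labels)
    print(loss)
```

The patch moves the gather/reshape of sequence-parallel hidden states out of `Qwen2ForCausalLM.forward` and into the RL criterion, so the model-side change reduces to the early return in the LM head shown above.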