
Commit 67b4306

Unified FuseHeadLoss (#10954)
1 parent 80e13e8 · commit 67b4306

File tree

2 files changed: 22 additions & 13 deletions


paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 18 additions & 0 deletions
@@ -30,6 +30,12 @@
 from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp
+except:
+    pass
+
+
 from ...transformers.llama.modeling import (
     LlamaPretrainingCriterion as PretrainingCriterion,
 )
@@ -436,11 +442,23 @@ def forward(
             logits = logits / self.temperature if self.temperature > 0.0 else logits
         else:
             hidden_states, weight, bias, transpose_y = logits
+
+            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+                hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = hidden_states.reshape(
+                [
+                    input_ids.shape[0],
+                    -1,
+                    hidden_states.shape[-1],
+                ]
+            )
+
             if self.use_fp32_compute and hidden_states.dtype != paddle.float32:
                 hidden_states = hidden_states.cast(paddle.float32)
                 weight = weight.cast(paddle.float32)
                 if bias is not None:
                     bias = bias.cast(paddle.float32)
+
             total_loss, pg_loss, entropy_loss, kl_loss = actor_fused_pg_entropy_kl_loss(
                 hidden_states,
                 weight,
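The guarded import keeps ppo_model_utils importable on Paddle builds that lack sequence_parallel_utils; GatherOp is only reached when sequence parallelism is actually enabled. The gather-then-reshape added above restores the per-sample layout before the fused loss runs. Below is a minimal NumPy sketch of the intended shape round-trip, assuming sequence-parallel shards are flattened along the token axis (the shapes and the concatenate stand-in for GatherOp are illustrative, not Paddle's communication kernel):

import numpy as np

# Assumed layout: with sequence parallelism each of `mp` ranks holds a
# flattened [batch * seq / mp, hidden] shard of the hidden states.
batch, seq, hidden, mp = 2, 8, 4, 4
full = np.arange(batch * seq * hidden, dtype=np.float32).reshape(batch * seq, hidden)
shards = np.split(full, mp, axis=0)

# GatherOp.apply(hidden_states) conceptually all-gathers the shards back
# into a single [batch * seq, hidden] tensor ...
gathered = np.concatenate(shards, axis=0)

# ... and the new reshape recovers [batch, seq, hidden], the layout the
# fused loss consumes (input_ids.shape[0] plays the role of `batch`).
restored = gathered.reshape(batch, -1, hidden)
assert restored.shape == (batch, seq, hidden)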

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 4 additions & 13 deletions
@@ -1484,6 +1484,10 @@ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False)
         self.weight.split_axis = 0 if self.transpose_y else 1
 
     def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
+        # add this for fused_head_and_loss_fn
+        if self.config.use_fused_head_and_loss_fn:
+            return hidden_states, self.weight, None, self.transpose_y
+
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
             hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
@@ -1670,19 +1674,6 @@ def forward(
 
         hidden_states = outputs[0]
 
-        # add this for fused_head_and_loss_fn
-        if self.config.use_fused_head_and_loss_fn and self.training:
-            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
-                hidden_states = GatherOp.apply(hidden_states)
-            hidden_states = hidden_states.reshape(
-                [
-                    batch_size,
-                    -1,
-                    hidden_states.shape[-1],
-                ]
-            )
-            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y
-
         # if labels is None,means we need full output, instead of tensor_parallel_output
         # tensor_parallel_output is together with ParallelCrossEntropy
         tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
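Net effect: the sequence-parallel gather/reshape now happens once, on the loss side, and Qwen2LMHead short-circuits as soon as the fused path is configured (the head-level check also drops the `self.training` guard the deleted model-level block had). A hedged sketch of the resulting head-side contract follows; `head_forward` is an illustrative stand-in, not the actual PaddleNLP method:

import paddle

def head_forward(hidden_states, weight, transpose_y=False, use_fused_head_and_loss_fn=True):
    # Illustrative stand-in for the lm_head forward after this commit.
    if use_fused_head_and_loss_fn:
        # Skip the output projection: hand (hidden_states, weight, bias,
        # transpose_y) to the fused loss, which performs matmul + loss in
        # one pass after any sequence-parallel gather/reshape it needs.
        return hidden_states, weight, None, transpose_y
    # Normal path: materialize the [batch, seq, vocab] logits.
    return paddle.matmul(hidden_states, weight, transpose_y=transpose_y)

Deferring the projection this way avoids materializing the full vocab-sized logits tensor on the model side, which is the usual motivation for fusing head and loss.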
