18 changes: 18 additions & 0 deletions paddlenlp/rl/models/ppo_model_utils.py
@@ -30,6 +30,12 @@
from paddle.distributed.fleet.layers.mpu import mp_ops
from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp
except:
    pass


from ...transformers.llama.modeling import (
    LlamaPretrainingCriterion as PretrainingCriterion,
)
@@ -436,11 +442,23 @@ def forward(
            logits = logits / self.temperature if self.temperature > 0.0 else logits
        else:
            hidden_states, weight, bias, transpose_y = logits

            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
                hidden_states = GatherOp.apply(hidden_states)
                hidden_states = hidden_states.reshape(
                    [
                        input_ids.shape[0],
                        -1,
                        hidden_states.shape[-1],
                    ]
                )

            if self.use_fp32_compute and hidden_states.dtype != paddle.float32:
                hidden_states = hidden_states.cast(paddle.float32)
                weight = weight.cast(paddle.float32)
                if bias is not None:
                    bias = bias.cast(paddle.float32)

            total_loss, pg_loss, entropy_loss, kl_loss = actor_fused_pg_entropy_kl_loss(
                hidden_states,
                weight,
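Note (not part of the PR): the new block above reassembles sequence-parallel activations before handing them to actor_fused_pg_entropy_kl_loss. GatherOp.apply collects the per-rank slices, and the reshape restores the [batch, seq_len, hidden] layout, with input_ids.shape[0] giving the batch size and -1 inferring the sequence length. A minimal, shape-only sketch of that transformation, using plain NumPy concatenation in place of the distributed GatherOp and hypothetical example sizes (tp_degree, bsz, seq_len, hidden):

import numpy as np

# Hypothetical example sizes; under sequence parallelism each of the tp_degree
# ranks holds a [bsz * seq_len // tp_degree, hidden] slice of the activations.
tp_degree, bsz, seq_len, hidden = 4, 2, 32, 16
local_chunk = np.zeros((bsz * seq_len // tp_degree, hidden))

# GatherOp.apply all-gathers the slices along axis 0; plain concatenation
# stands in for the collective here.
gathered = np.concatenate([local_chunk] * tp_degree, axis=0)  # [bsz * seq_len, hidden]

# The reshape in the diff then restores the layout the fused loss expects.
restored = gathered.reshape(bsz, -1, hidden)
assert restored.shape == (bsz, seq_len, hidden)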
17 changes: 4 additions & 13 deletions paddlenlp/transformers/qwen2/modeling.py
@@ -1484,6 +1484,10 @@ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=Fals
        self.weight.split_axis = 0 if self.transpose_y else 1

    def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
        # add this for fused_head_and_loss_fn
        if self.config.use_fused_head_and_loss_fn:
            return hidden_states, self.weight, None, self.transpose_y

        if self.config.sequence_parallel:
            hidden_states = GatherOp.apply(hidden_states)
            hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
@@ -1670,19 +1674,6 @@ def forward(

        hidden_states = outputs[0]

        # add this for fused_head_and_loss_fn
        if self.config.use_fused_head_and_loss_fn and self.training:
            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
                hidden_states = GatherOp.apply(hidden_states)
                hidden_states = hidden_states.reshape(
                    [
                        batch_size,
                        -1,
                        hidden_states.shape[-1],
                    ]
                )
            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y

        # if labels is None,means we need full output, instead of tensor_parallel_output
        # tensor_parallel_output is together with ParallelCrossEntropy
        tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
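Note (not part of the PR): with use_fused_head_and_loss_fn enabled, Qwen2LMHead.forward now returns the raw (hidden_states, weight, bias, transpose_y) tuple instead of projecting to logits, and the gather/reshape block deleted from Qwen2ForCausalLM.forward above now lives in the RL criterion (see ppo_model_utils.py earlier in this diff). A minimal sketch of that hand-off contract, using a hypothetical TinyHead layer and plain paddle.matmul in place of the tensor-parallel projection:

import paddle

class TinyHead(paddle.nn.Layer):
    def __init__(self, hidden, vocab, use_fused_head_and_loss_fn=False):
        super().__init__()
        self.use_fused_head_and_loss_fn = use_fused_head_and_loss_fn
        self.weight = self.create_parameter(shape=[hidden, vocab])

    def forward(self, hidden_states):
        if self.use_fused_head_and_loss_fn:
            # Defer the projection: the fused loss consumes these pieces directly.
            return hidden_states, self.weight, None, False  # (hidden, weight, bias, transpose_y)
        return paddle.matmul(hidden_states, self.weight)  # ordinary logits path

head = TinyHead(hidden=16, vocab=32, use_fused_head_and_loss_fn=True)
out = head(paddle.randn([2, 8, 16]))
assert isinstance(out, tuple) and len(out) == 4

The apparent intent of the relocation is to let the loss, rather than the model, decide when to gather and reshape sequence-parallel activations.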