verifiers/trainers/grpo_config.py (7 additions, 0 deletions)
@@ -335,6 +335,13 @@ class GRPOConfig(TrainingArguments):
             "all prompts are logged."
         },
     )
+    log_policy_entropy: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to log the policy entropy during training. If `True`, the policy entropy is logged to "
+            "`wandb` and printed to the console."
+        },
+    )
 
     def __post_init__(self):
         super().__post_init__()
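For reference, a minimal sketch of enabling the new flag from user code, assuming the GRPOConfig import path used in this repo; the other arguments shown are illustrative and not part of this PR:

from verifiers.trainers.grpo_config import GRPOConfig

# Hypothetical configuration; only `log_policy_entropy` is introduced by this diff
# (it defaults to True, so passing it explicitly is optional).
config = GRPOConfig(
    output_dir="outputs/grpo-entropy",  # standard TrainingArguments argument
    log_policy_entropy=True,            # log entropy/mean, entropy/min, entropy/max each step
)

# Setting log_policy_entropy=False skips the extra entropy computation in the trainer entirely.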
verifiers/trainers/grpo_trainer.py (39 additions, 9 deletions)
@@ -22,7 +22,7 @@
 from transformers.trainer_utils import seed_worker
 from trl.models import create_reference_model, prepare_deepspeed
 from trl.trainer.callbacks import SyncRefModelCallback
-from trl.trainer.utils import disable_dropout_in_model, pad, selective_log_softmax
+from trl.trainer.utils import disable_dropout_in_model, pad, selective_log_softmax, entropy_from_logits
 
 from verifiers import Environment
 from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest
@@ -241,6 +241,7 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
     return torch.max(tensor[~torch.isnan(tensor)])
 
 
+
 class GRPOTrainer(Trainer):
     def __init__(
         self,
@@ -496,6 +497,7 @@ def data_collator(features):
         self.log_completions = args.log_completions
         self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
+        self.log_policy_entropy = args.log_policy_entropy
 
         # Environment integration parameters
         self.mask_env_responses = args.mask_env_responses
@@ -714,12 +716,14 @@ def _get_last_hidden_state(
 
     # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(
-        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None
-    ) -> torch.Tensor:
+        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         batch_size = batch_size or input_ids.size(
             0
         ) # Chunk inputs into smaller batches to reduce memory peak
         all_logps = []
+        all_entropies = []
+
         for i in range(0, input_ids.size(0), batch_size):
             input_ids_batch = input_ids[i : i + batch_size]
             attention_mask_batch = attention_mask[i : i + batch_size]
@@ -731,18 +735,31 @@ def _get_per_token_logps(
             logits = logits[
                 :, :-1, :
             ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
             input_ids_batch = input_ids_batch[:, -logits_to_keep:]
             # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
             # See https://github.com/huggingface/trl/issues/2770
             logits = logits[:, -logits_to_keep:]
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
+
+            if compute_entropy:
+                entropy = entropy_from_logits(logits)
+                all_entropies.append(entropy)
+
             logps = selective_log_softmax(
                 logits, input_ids_batch
             ) # compute logprobs for the input tokens
             all_logps.append(logps)
-        return torch.cat(all_logps, dim=0)
+
+        logps_result = torch.cat(all_logps, dim=0)
+
+        if compute_entropy:
+            entropies_result = torch.cat(all_entropies, dim=0)
+            return logps_result, entropies_result
+        else:
+            return logps_result, None
 
     def _move_model_to_vllm(self):
         # For DeepSpeed ZeRO-3 we need to gather all parameters before operations
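The per-token entropy appended above comes from TRL's entropy_from_logits helper. As a point of reference only, a minimal equivalent based on the standard identity H = logsumexp(z) - sum(softmax(z) * z); the actual TRL implementation may differ in details such as chunking:

import torch

def per_token_entropy(logits: torch.Tensor) -> torch.Tensor:
    # logits: (B, L, V), already divided by the sampling temperature as in the diff above.
    # Categorical entropy H(p) = -sum_i p_i * log(p_i) with p = softmax(z) rearranges to
    # logsumexp(z) - sum_i softmax(z)_i * z_i, avoiding an explicit log-prob tensor.
    probs = torch.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (probs * logits).sum(dim=-1)

# Sanity check against torch.distributions (entropy reduces over the vocab dimension):
z = torch.randn(2, 5, 11)
expected = torch.distributions.Categorical(logits=z).entropy()
assert torch.allclose(per_token_entropy(z), expected, atol=1e-5)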
@@ -1187,8 +1204,8 @@ def compute_loss( # type: ignore
         # prompt is at least 1 token
         completion_mask = attention_mask[:, 1:]
         logits_to_keep = completion_mask.size(1)
-        per_token_logps = self._get_per_token_logps(
-            model, input_ids, attention_mask, logits_to_keep
+        per_token_logps, per_token_entropy = self._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep, compute_entropy=self.log_policy_entropy
         )
         # Compute the loss
         advantages = inputs["advantages"]
@@ -1218,12 +1235,12 @@
         if self.beta != 0.0:
             with torch.no_grad():
                 if self.ref_model is not None:
-                    ref_per_token_logps = self._get_per_token_logps(
+                    ref_per_token_logps, _ = self._get_per_token_logps(
                         self.ref_model, input_ids, attention_mask, logits_to_keep
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore
-                        ref_per_token_logps = self._get_per_token_logps(
+                        ref_per_token_logps, _ = self._get_per_token_logps(
                             self.model, input_ids, attention_mask, logits_to_keep
                         )
             per_token_kl = (
@@ -1281,6 +1298,19 @@ def compute_loss( # type: ignore
         self._metrics[mode]["clip_ratio/region_mean"].append(
             gathered_clip_ratio.nanmean().item() # type: ignore
         )
+
+        if self.log_policy_entropy:
+            masked_entropy = per_token_entropy * completion_mask
+            total_completion_tokens = completion_mask.sum()
+
+            if total_completion_tokens > 0:
+                valid_entropy_values = masked_entropy[completion_mask.bool()]
+                gathered_entropy_values = self.accelerator.gather_for_metrics(valid_entropy_values)
+
+                self._metrics[mode]["entropy/mean"].append(gathered_entropy_values.nanmean().item())
+                self._metrics[mode]["entropy/min"].append(nanmin(gathered_entropy_values).item())
+                self._metrics[mode]["entropy/max"].append(nanmax(gathered_entropy_values).item())
+
         return loss
 
     def _sanitize_tool_calls(
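To make the reduction above explicit: only completion-token positions enter the logged statistics, and the per-token values are gathered across processes with gather_for_metrics before the mean/min/max are taken. A single-process sketch of the same aggregation (names are illustrative; the distributed gather from the diff is omitted):

import torch

def entropy_stats(per_token_entropy: torch.Tensor, completion_mask: torch.Tensor) -> dict[str, float]:
    # Both tensors have shape (B, L-1), aligned with completion positions.
    # Prompt and padding tokens (mask == 0) are dropped before any reduction, so the
    # mean is a per-token average over completion tokens only, not a per-sequence average.
    valid = per_token_entropy[completion_mask.bool()]
    finite = valid[~valid.isnan()]
    if finite.numel() == 0:
        return {}
    return {
        "entropy/mean": finite.mean().item(),
        "entropy/min": finite.min().item(),
        "entropy/max": finite.max().item(),
    }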
@@ -1603,4 +1633,4 @@ def _log_completion_metrics_primary(
         )
         self._metrics[mode]["completions/max_terminated_length"].append(
             float(max(term_lengths))
-        )
+        )