
Commit fc6152d

Add policy entropy reporting
1 parent 9cf207a commit fc6152d

File tree

2 files changed: +68 -8 lines changed

verifiers/trainers/grpo_config.py
verifiers/trainers/grpo_trainer.py

verifiers/trainers/grpo_config.py

Lines changed: 7 additions & 0 deletions
@@ -335,6 +335,13 @@ class GRPOConfig(TrainingArguments):
             "all prompts are logged."
         },
     )
+    log_policy_entropy: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to log the policy entropy during training. If `True`, the policy entropy is logged to "
+            "`wandb` and printed to the console."
+        },
+    )
 
     def __post_init__(self):
         super().__post_init__()
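For reference, the new flag is toggled like any other GRPOConfig field; a minimal sketch, assuming the module path mirrors the file path above (the other arguments are illustrative placeholders, not part of this commit):

    from verifiers.trainers.grpo_config import GRPOConfig

    # Entropy logging is on by default; set it to False to opt out.
    config = GRPOConfig(
        output_dir="outputs/grpo-run",   # hypothetical output path
        log_policy_entropy=False,
    )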

verifiers/trainers/grpo_trainer.py

Lines changed: 61 additions & 8 deletions
@@ -241,6 +241,32 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
     return torch.max(tensor[~torch.isnan(tensor)])
 
 
+def entropy_from_logits_memory_efficient(logits: torch.Tensor, chunk_size: int = 32):
+    """
+    Compute entropy by processing sequence positions in chunks.
+    Args:
+        logits: (B, L, V) tensor
+        chunk_size: Number of sequence positions to process at once
+    """
+    with torch.no_grad():
+
+        B, L, V = logits.shape
+        entropy = torch.empty(B, L, device=logits.device, dtype=logits.dtype)
+
+        for start_idx in range(0, L, chunk_size):
+            end_idx = min(start_idx + chunk_size, L)
+            logits_chunk = logits[:, start_idx:end_idx, :]  # (B, chunk_size, V)
+
+            # More memory-efficient entropy calculation
+            log_probs = torch.log_softmax(logits_chunk, dim=-1)
+            probs = torch.softmax(logits_chunk, dim=-1)
+            entropy_chunk = -(probs * log_probs).sum(dim=-1)  # (B, chunk_size)
+
+            entropy[:, start_idx:end_idx] = entropy_chunk
+
+    return entropy
+
+
 class GRPOTrainer(Trainer):
     def __init__(
         self,
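The chunked loop above should agree with a full-vocabulary entropy computed in one shot; a quick sanity-check sketch (not part of the commit, and it assumes the helper is importable from verifiers.trainers.grpo_trainer once this change lands):

    import torch
    from verifiers.trainers.grpo_trainer import entropy_from_logits_memory_efficient

    B, L, V = 2, 7, 11                                    # tiny shapes for a fast check
    logits = torch.randn(B, L, V)

    # Reference: per-position categorical entropy over the full vocab at once
    reference = torch.distributions.Categorical(logits=logits).entropy()   # (B, L)
    chunked = entropy_from_logits_memory_efficient(logits, chunk_size=3)   # (B, L)

    assert torch.allclose(chunked, reference, atol=1e-5)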
@@ -496,6 +522,7 @@ def data_collator(features):
         self.log_completions = args.log_completions
         self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
+        self.log_policy_entropy = args.log_policy_entropy
 
         # Environment integration parameters
         self.mask_env_responses = args.mask_env_responses
@@ -714,12 +741,14 @@ def _get_last_hidden_state(
 
     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(
-        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None
-    ) -> torch.Tensor:
+        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         batch_size = batch_size or input_ids.size(
             0
         )  # Chunk inputs into smaller batches to reduce memory peak
         all_logps = []
+        all_entropies = []
+
         for i in range(0, input_ids.size(0), batch_size):
             input_ids_batch = input_ids[i : i + batch_size]
             attention_mask_batch = attention_mask[i : i + batch_size]
@@ -731,18 +760,31 @@ def _get_per_token_logps(
             logits = logits[
                 :, :-1, :
             ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
             input_ids_batch = input_ids_batch[:, -logits_to_keep:]
             # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
             # See https://github.com/huggingface/trl/issues/2770
             logits = logits[:, -logits_to_keep:]
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
+
+            if compute_entropy:
+                entropy = entropy_from_logits_memory_efficient(logits, chunk_size=32)
+                all_entropies.append(entropy)
+
             logps = selective_log_softmax(
                 logits, input_ids_batch
             )  # compute logprobs for the input tokens
             all_logps.append(logps)
-        return torch.cat(all_logps, dim=0)
+
+        logps_result = torch.cat(all_logps, dim=0)
+
+        if compute_entropy:
+            entropies_result = torch.cat(all_entropies, dim=0)
+            return logps_result, entropies_result
+        else:
+            return logps_result, None
 
     def _move_model_to_vllm(self):
         # For DeepSpeed ZeRO-3 we need to gather all parameters before operations
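The point of routing the temperature-scaled (B, L, V) logits through entropy_from_logits_memory_efficient, rather than materializing softmax and log-softmax over the whole sequence, is peak memory; rough arithmetic with assumed shapes (numbers are illustrative, not measurements from this commit):

    # Peak size of the probs + log_probs temporaries, float32
    B, L, V = 1, 4096, 128_000                     # assumed batch, completion length, vocab size
    full = 2 * B * L * V * 4 / 2**30               # whole sequence at once  -> ~3.9 GiB
    chunked = 2 * B * 32 * V * 4 / 2**20           # one chunk_size=32 slice -> ~31 MiB
    print(f"full: {full:.1f} GiB, per-chunk: {chunked:.0f} MiB")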
@@ -1187,8 +1229,8 @@ def compute_loss( # type: ignore
         # prompt is at least 1 token
         completion_mask = attention_mask[:, 1:]
         logits_to_keep = completion_mask.size(1)
-        per_token_logps = self._get_per_token_logps(
-            model, input_ids, attention_mask, logits_to_keep
+        per_token_logps, per_token_entropy = self._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep, compute_entropy=self.log_policy_entropy
         )
         # Compute the loss
         advantages = inputs["advantages"]
@@ -1218,12 +1260,12 @@ def compute_loss( # type: ignore
         if self.beta != 0.0:
             with torch.no_grad():
                 if self.ref_model is not None:
-                    ref_per_token_logps = self._get_per_token_logps(
+                    ref_per_token_logps, _ = self._get_per_token_logps(
                         self.ref_model, input_ids, attention_mask, logits_to_keep
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():  # type: ignore
-                        ref_per_token_logps = self._get_per_token_logps(
+                        ref_per_token_logps, _ = self._get_per_token_logps(
                             self.model, input_ids, attention_mask, logits_to_keep
                         )
             per_token_kl = (
@@ -1281,6 +1323,17 @@ def compute_loss( # type: ignore
             self._metrics[mode]["clip_ratio/region_mean"].append(
                 gathered_clip_ratio.nanmean().item()  # type: ignore
             )
+
+        if self.log_policy_entropy:
+            masked_entropy = per_token_entropy * completion_mask
+            total_completion_tokens = completion_mask.sum()
+            if total_completion_tokens > 0:
+                mean_entropy = masked_entropy.sum() / total_completion_tokens
+                gathered_entropy = self.accelerator.gather_for_metrics(mean_entropy)
+                self._metrics[mode]["entropy/mean"].append(gathered_entropy.nanmean().item())
+                self._metrics[mode]["entropy/min"].append(nanmin(gathered_entropy).item())
+                self._metrics[mode]["entropy/max"].append(nanmax(gathered_entropy).item())
+
         return loss
 
     def _sanitize_tool_calls(
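The masked sum above averages entropy over completion tokens only, so prompt and padding positions don't dilute the metric; a standalone toy illustration (tensors invented for the example):

    import torch

    per_token_entropy = torch.tensor([[2.0, 1.0, 0.5],
                                      [1.5, 0.0, 0.0]])
    completion_mask = torch.tensor([[1, 1, 1],
                                    [1, 0, 0]])           # second sequence has one completion token

    masked_entropy = per_token_entropy * completion_mask
    mean_entropy = masked_entropy.sum() / completion_mask.sum()
    print(mean_entropy)                                    # (2.0 + 1.0 + 0.5 + 1.5) / 4 = 1.25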
@@ -1603,4 +1656,4 @@ def _log_completion_metrics_primary(
             )
             self._metrics[mode]["completions/max_terminated_length"].append(
                 float(max(term_lengths))
-            )
+            )
