@@ -241,6 +241,32 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
     return torch.max(tensor[~torch.isnan(tensor)])
 
 
+def entropy_from_logits_memory_efficient(logits: torch.Tensor, chunk_size: int = 32) -> torch.Tensor:
+    """
+    Compute entropy by processing sequence positions in chunks.
+    Args:
+        logits: (B, L, V) tensor
+        chunk_size: Number of sequence positions to process at once
+    """
+    with torch.no_grad():
+
+        B, L, V = logits.shape
+        entropy = torch.empty(B, L, device=logits.device, dtype=logits.dtype)
+
+        for start_idx in range(0, L, chunk_size):
+            end_idx = min(start_idx + chunk_size, L)
+            logits_chunk = logits[:, start_idx:end_idx, :]  # (B, chunk_size, V)
+
+            # More memory-efficient entropy calculation
+            log_probs = torch.log_softmax(logits_chunk, dim=-1)
+            probs = torch.softmax(logits_chunk, dim=-1)
+            entropy_chunk = -(probs * log_probs).sum(dim=-1)  # (B, chunk_size)
+
+            entropy[:, start_idx:end_idx] = entropy_chunk
+
+        return entropy
+
+
 class GRPOTrainer(Trainer):
     def __init__(
         self,
@@ -496,6 +522,7 @@ def data_collator(features):
         self.log_completions = args.log_completions
         self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
+        self.log_policy_entropy = args.log_policy_entropy
 
         # Environment integration parameters
         self.mask_env_responses = args.mask_env_responses
@@ -714,12 +741,14 @@ def _get_last_hidden_state(
 
     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(
-        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None
-    ) -> torch.Tensor:
+        self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         batch_size = batch_size or input_ids.size(
             0
         )  # Chunk inputs into smaller batches to reduce memory peak
         all_logps = []
+        all_entropies = []
+
         for i in range(0, input_ids.size(0), batch_size):
             input_ids_batch = input_ids[i : i + batch_size]
             attention_mask_batch = attention_mask[i : i + batch_size]
@@ -731,18 +760,31 @@ def _get_per_token_logps(
             logits = logits[
                 :, :-1, :
             ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
             input_ids_batch = input_ids_batch[:, -logits_to_keep:]
             # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
             # See https://github.com/huggingface/trl/issues/2770
             logits = logits[:, -logits_to_keep:]
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
+
+            if compute_entropy:
+                entropy = entropy_from_logits_memory_efficient(logits, chunk_size=32)
+                all_entropies.append(entropy)
+
             logps = selective_log_softmax(
                 logits, input_ids_batch
             )  # compute logprobs for the input tokens
             all_logps.append(logps)
-        return torch.cat(all_logps, dim=0)
+
+        logps_result = torch.cat(all_logps, dim=0)
+
+        if compute_entropy:
+            entropies_result = torch.cat(all_entropies, dim=0)
+            return logps_result, entropies_result
+        else:
+            return logps_result, None
 
     def _move_model_to_vllm(self):
         # For DeepSpeed ZeRO-3 we need to gather all parameters before operations
@@ -1187,8 +1229,8 @@ def compute_loss( # type: ignore
         # prompt is at least 1 token
         completion_mask = attention_mask[:, 1:]
         logits_to_keep = completion_mask.size(1)
-        per_token_logps = self._get_per_token_logps(
-            model, input_ids, attention_mask, logits_to_keep
+        per_token_logps, per_token_entropy = self._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep, compute_entropy=self.log_policy_entropy
         )
         # Compute the loss
         advantages = inputs["advantages"]
@@ -1218,12 +1260,12 @@ def compute_loss( # type: ignore
         if self.beta != 0.0:
             with torch.no_grad():
                 if self.ref_model is not None:
-                    ref_per_token_logps = self._get_per_token_logps(
+                    ref_per_token_logps, _ = self._get_per_token_logps(
                         self.ref_model, input_ids, attention_mask, logits_to_keep
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():  # type: ignore
-                        ref_per_token_logps = self._get_per_token_logps(
+                        ref_per_token_logps, _ = self._get_per_token_logps(
                             self.model, input_ids, attention_mask, logits_to_keep
                         )
             per_token_kl = (
@@ -1281,6 +1323,17 @@ def compute_loss( # type: ignore
         self._metrics[mode]["clip_ratio/region_mean"].append(
             gathered_clip_ratio.nanmean().item()  # type: ignore
         )
+
+        if self.log_policy_entropy:
+            masked_entropy = per_token_entropy * completion_mask
+            total_completion_tokens = completion_mask.sum()
+            if total_completion_tokens > 0:
+                mean_entropy = masked_entropy.sum() / total_completion_tokens
+                gathered_entropy = self.accelerator.gather_for_metrics(mean_entropy)
+                self._metrics[mode]["entropy/mean"].append(gathered_entropy.nanmean().item())
+                self._metrics[mode]["entropy/min"].append(nanmin(gathered_entropy).item())
+                self._metrics[mode]["entropy/max"].append(nanmax(gathered_entropy).item())
+
         return loss
 
     def _sanitize_tool_calls(
@@ -1603,4 +1656,4 @@ def _log_completion_metrics_primary(
             )
             self._metrics[mode]["completions/max_terminated_length"].append(
                 float(max(term_lengths))
-            )
+            )
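
As a quick sanity check (not part of the diff), the chunked helper can be compared against a direct full-sequence entropy computation. The snippet below is a minimal sketch, assuming only torch and the entropy_from_logits_memory_efficient function added above; the tensor shapes are illustrative.

import torch

# Illustrative shapes: batch of 2, sequence length 100, small vocab.
logits = torch.randn(2, 100, 512)

# Direct computation over the full sequence at once.
log_probs = torch.log_softmax(logits, dim=-1)
expected = -(torch.softmax(logits, dim=-1) * log_probs).sum(dim=-1)  # (B, L)

# Chunked computation from the helper introduced in this diff.
actual = entropy_from_logits_memory_efficient(logits, chunk_size=32)

assert torch.allclose(actual, expected, atol=1e-5)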