
Conversation

Contributor

@SamComber SamComber commented Aug 18, 2025

Would be good to have an option to track potential policy entropy collapse. Entropy exhaustion is typically indicative of growing determinism in token decoding and diminished exploration behaviour (see the paper at https://arxiv.org/pdf/2505.22617).
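For concreteness, the quantity usually tracked for this is the mean per-token policy entropy (standard definition, not something specific to this PR):

$$H_{b,t} = -\sum_{v \in \mathcal{V}} \pi_\theta(v \mid x_{b,<t}) \log \pi_\theta(v \mid x_{b,<t}), \qquad \bar{H} = \frac{1}{BL} \sum_{b=1}^{B} \sum_{t=1}^{L} H_{b,t}$$

Collapse shows up as this batch average trending towards zero as training progresses.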


WIP

Need to optimise so we don't create B*L*V intermediate tensors inside the entropy calculation (i.e. roughly 3x the memory required right now)... will likely chunk it and emit a warning when entropy calculation is enabled.
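For reference, a minimal sketch of the chunked approach, assuming PyTorch. The function name and chunk_size=32 default are taken from the diff below; the body itself is an assumption about how such a helper could be written:

    import torch
    import torch.nn.functional as F

    def entropy_from_logits_memory_efficient(logits: torch.Tensor, chunk_size: int = 32) -> torch.Tensor:
        """Per-token entropy H = -sum_v p_v * log p_v for logits of shape (B, L, V).

        Splitting along the batch dimension means only a (chunk_size, L, V)
        log-probability tensor is materialised at a time, instead of a full
        extra (B, L, V) copy living next to the logits.
        """
        entropies = []
        for chunk in logits.split(chunk_size, dim=0):        # (chunk, L, V)
            logps = F.log_softmax(chunk.float(), dim=-1)     # upcast for numerical stability
            entropies.append(-(logps.exp() * logps).sum(-1)) # (chunk, L)
        return torch.cat(entropies, dim=0)                   # (B, L)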

@SamComber SamComber marked this pull request as draft August 18, 2025 15:52
@SamComber SamComber changed the title Add policy entropy WIP (need to test thoroughly): Add policy entropy Aug 18, 2025
@SamComber SamComber changed the title WIP (need to test thoroughly): Add policy entropy WIP: Add policy entropy Aug 18, 2025
@SamComber SamComber changed the title WIP: Add policy entropy [Draft]: Add policy entropy Aug 18, 2025
@SamComber SamComber changed the title [Draft]: Add policy entropy [Draft]: Add policy entropy metric Aug 18, 2025
@SamComber SamComber force-pushed the add-policy-entropy branch 3 times, most recently from bec0673 to 16a14d4 on August 18, 2025 19:54
logits = logits / self.temperature

if compute_entropy:
    entropy = entropy_from_logits_memory_efficient(logits, chunk_size=32)
Contributor
Hey! If you want something aligned with trl, we have entropy_from_logits that you can import directly:

https://github.com/cyyever/trl/blob/2ddf8010881e3bdf215be28ed6fd6f3a8ae2bcf5/trl/trainer/utils.py#L1469

It does the same as the helper here.
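A sketch of what the swap could look like, assuming the chunked entropy_from_logits(logits, chunk_size=...) signature at the linked line; the exact keyword should be checked against the pinned trl version:

    from trl.trainer.utils import entropy_from_logits

    logits = logits / self.temperature

    if compute_entropy:
        # Assumes the chunked signature from the linked utils.py; some trl
        # versions expose entropy_from_logits(logits) with no chunk_size argument.
        entropy = entropy_from_logits(logits, chunk_size=32)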

Contributor Author

Oh this is perfect! Thanks for the suggestion

Contributor Author

@SamComber SamComber Aug 19, 2025

Still getting OOMs with chunk size 1 here in places where we definitely should not be. Hmm @qgallouedec, does entropy_from_logits perhaps need to run under torch.no_grad() internally?
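For context, a minimal sketch of the guard being asked about, assuming the entropy is only logged as a metric and never backpropagated through:

    import torch

    if compute_entropy:
        # If entropy is purely a logged metric, computing it under no_grad stops
        # autograd from retaining the chunked softmax intermediates for backward,
        # which is a plausible source of the OOMs even at chunk_size=1.
        with torch.no_grad():
            entropy = entropy_from_logits(logits, chunk_size=1)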
