Conversation

@qgallouedec (Member) commented on Sep 15, 2025

This PR refactors RewardTrainer

closes #3780 #3101 #4104 #2633 #2758

  • Aims for an implementation that better aligns with SFTTrainer, improving long-term maintainability
  • Streamlines usage for end-users (see the README diff)
  • Better test coverage: expanded from 8 to 29 cases
  • No regressions or breaking changes expected
  • Better documentation: expanded from 94 to 245 lines
  • Now works seamlessly with the CLI: trl reward ...
  • Now includes its own model card
  • Now supports activation offloading
  • While caution is always good, this trainer isn’t heavily used, so the risk of major disruption is minimal (and we’re 99% confident nothing’s broken)
  • The goal is to apply the same refactor to the other trainers (KTO, DPO, PRM, ...)

blue: before refactor
red: after refactor

[Screenshot: training curves comparing the two implementations]

To ensure both snippets start from the same base model, first save a shared checkpoint:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Build a sequence-classification model with a single output (the scalar reward) and save it locally
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)
tokenizer.save_pretrained("Qwen2.5-0.5B-Reward-Base")
model.save_pretrained("Qwen2.5-0.5B-Reward-Base")

# Code used before the refactor
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Reward-Base")
model = AutoModelForSequenceClassification.from_pretrained("Qwen2.5-0.5B-Reward-Base", num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:2000]")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", learning_rate=0.0001, logging_steps=10)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()  # in addition, accuracy logging had to be implemented manually
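
For reference, the "accuracy" mentioned above is pairwise preference accuracy: the fraction of preference pairs for which the model assigns a higher reward to the chosen completion than to the rejected one. A minimal sketch of such a metric (the function name and inputs are illustrative, not part of the RewardTrainer API):

import numpy as np

def pairwise_accuracy(rewards_chosen: np.ndarray, rewards_rejected: np.ndarray) -> float:
    # Count how often the chosen completion outscores the rejected one
    return float(np.mean(rewards_chosen > rewards_rejected))

# Example: 3 of 4 pairs are ranked correctly -> 0.75
pairwise_accuracy(np.array([1.2, 0.3, 2.0, -0.1]), np.array([0.5, 0.9, 1.1, -0.4]))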


# Code used after the refactor
from trl import RewardTrainer, RewardConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:2000]")

trainer = RewardTrainer(
    model="Qwen2.5-0.5B-Reward-Base",
    train_dataset=dataset,
    args=RewardConfig(output_dir="Qwen2.5-0.5B-Reward", learning_rate=0.0001, logging_steps=10),
)
trainer.train()
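
After training, the result is a regular transformers sequence-classification checkpoint, so it can be used to score completions directly. A minimal sketch, assuming the trained model has been saved to the Qwen2.5-0.5B-Reward output directory used above (the prompt and completion are illustrative):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Reward")
model = AutoModelForSequenceClassification.from_pretrained("Qwen2.5-0.5B-Reward")

def reward(prompt: str, completion: str) -> float:
    # Render the conversation with the chat template and read the single logit as the reward
    messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": completion}]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).logits[0, 0].item()

# A higher score should indicate the completion the model prefers
print(reward("What is the capital of France?", "The capital of France is Paris."))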

@qgallouedec qgallouedec changed the base branch from main to support-reward-refactor September 15, 2025 21:31
@qgallouedec qgallouedec changed the base branch from support-reward-refactor to main September 16, 2025 04:34
@qgallouedec qgallouedec changed the base branch from main to support-seq-cls-clone-chat September 16, 2025 17:30
Base automatically changed from support-seq-cls-clone-chat to main September 30, 2025 17:42
@kashif (Collaborator) left a comment

1 typo but all else is good from my side! great!

@qgallouedec qgallouedec merged commit da209f8 into main Sep 30, 2025
11 of 12 checks passed
@qgallouedec qgallouedec deleted the reward-refactor branch September 30, 2025 21:13
@qgallouedec qgallouedec linked an issue on Oct 5, 2025 that may be closed by this pull request: [Feature Request] Add support for padding-free reward modeling training in TRL Reward