diff --git a/README.md b/README.md
index 2df75017d02..b7a0668adc9 100644
--- a/README.md
+++ b/README.md
@@ -136,23 +136,13 @@ trainer.train()
 Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
 
 ```python
-from trl import RewardConfig, RewardTrainer
+from trl import RewardTrainer
 from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-model = AutoModelForSequenceClassification.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
-)
-model.config.pad_token_id = tokenizer.pad_token_id
 
 dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
 
-training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
 trainer = RewardTrainer(
-    args=training_args,
-    model=model,
-    processing_class=tokenizer,
+    model="Qwen/Qwen2.5-0.5B-Instruct",
     train_dataset=dataset,
 )
 trainer.train()
diff --git a/docs/source/clis.md b/docs/source/clis.md
index 54b7501c1aa..d6e433cada7 100644
--- a/docs/source/clis.md
+++ b/docs/source/clis.md
@@ -9,6 +9,7 @@ Currently supported commands are:
 - `trl dpo`: fine-tune a LLM with DPO
 - `trl grpo`: fine-tune a LLM with GRPO
 - `trl kto`: fine-tune a LLM with KTO
+- `trl reward`: train a Reward Model
 - `trl rloo`: fine-tune a LLM with RLOO
 - `trl sft`: fine-tune a LLM with SFT
 
@@ -41,6 +42,15 @@ trl dpo \
   --dataset_name anthropic/hh-rlhf
 ```
 
+
+
+
+```bash
+trl reward \
+  --model_name_or_path Qwen/Qwen2.5-0.5B \
+  --dataset_name trl-lib/ultrafeedback_binarized
+```
+
 
 
 
@@ -78,6 +88,21 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```
 
+
+
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 
 
 
@@ -138,6 +163,33 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```
+
+
+
+
+```bash
+trl reward \
+  --model_name_or_path Qwen/Qwen2.5-0.5B \
+  --dataset_name trl-lib/ultrafeedback_binarized \
+  --num_processes 4
+```
+
+
+
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+num_processes: 4
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 
 
 
@@ -217,6 +269,33 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```
+
+
+
+
+```bash
+trl reward \
+  --model_name_or_path Qwen/Qwen2.5-0.5B \
+  --dataset_name trl-lib/ultrafeedback_binarized \
+  --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
+```
+
+
+
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 
 
 
@@ -224,7 +303,7 @@ trl dpo --config dpo_config.yaml
 
 You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
 
-
+
 
 
 ```yaml
@@ -258,6 +337,23 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```
 
+
+
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+datasets:
+  - path: trl-lib/tldr-preference
+  - path: trl-lib/lm-human-preferences-sentiment
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 
 
 
diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md
index 3ddfe06f72f..d4b69cb44a3 100644
--- a/docs/source/paper_index.md
+++ b/docs/source/paper_index.md
@@ -533,3 +533,53 @@ training_args = CPOConfig(
     ...
 )
 ```
+
+## Reward Modeling
+
+Papers relating to the [`RewardTrainer`]
+
+### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
+
+**š Paper**: https://huggingface.co/papers/2312.09244
+
+This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
+$$
+
+To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
+
+```python
+from trl import RewardConfig
+
+training_args = RewardConfig(
+    center_rewards_coefficient=0.01,  # Ī· in the paper
+    ...
+)
+```
+
+### Llama 2: Open Foundation and Fine-Tuned Chat Models
+
+**š Paper**: https://huggingface.co/papers/2307.09288
+
+In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
+$$
+
+You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up a the "Margin Small" setting of the paper.
+
+```python
+def add_margin(example):
+    preference_to_margin = {
+        "significantly better": 1.0,
+        "better": 2.0/3.0,
+        "slightly better": 1.0/3.0,
+        "negligibly better / unsure": 0.0,
+    }
+    return {"margin": preference_to_margin[example["preference_label"]]}
+
+dataset = dataset.map(add_margin)
+```
diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md
index 908230131e0..4898019bd9c 100644
--- a/docs/source/quickstart.md
+++ b/docs/source/quickstart.md
@@ -1,6 +1,6 @@
 # Quickstart
 
-TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO),  Direct Preference Optimization (DPO).
+TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).
 
 ## Quick Examples
 
@@ -51,6 +51,21 @@ trainer = DPOTrainer(
 trainer.train()
 ```
 
+### Reward Modeling
+
+```python
+from trl import RewardTrainer
+from datasets import load_dataset
+
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+trainer = RewardTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    train_dataset=dataset,
+)
+trainer.train()
+```
+
 ## Command Line Interface
 
 Skip the code entirely - train directly from your terminal:
@@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
 # DPO: Align with preferences  
 trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
     --dataset_name trl-lib/ultrafeedback_binarized
+
+# Reward: Train a reward model
+trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --dataset_name trl-lib/ultrafeedback_binarized
 ```
 
 ## What's Next?
diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md
index 972c45b0c88..1179ab220f8 100644
--- a/docs/source/reward_trainer.md
+++ b/docs/source/reward_trainer.md
@@ -2,84 +2,225 @@
 
 [](https://huggingface.co/models?other=reward-trainer,trl)
 
-TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
+## Overview
 
-Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
+TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.
 
-## Expected dataset type
+This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).
 
-The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
-The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+## Quick start
 
-You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
+This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), large-scale, fine-grained, diverse preference dataset.
 
-## Using the `RewardTrainer`
+```python
+from trl import RewardTrainer
+from datasets import load_dataset
 
-After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from š¤ Transformers.
-You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
+trainer = RewardTrainer(
+    model="Qwen/Qwen3-0.6B",
+    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
+)
+trainer.train()
+```
 
-### Leveraging š¤ PEFT to train a reward model
+
 
-Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
+## Expected dataset type and format
+
+[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets type (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
 ```python
-from peft import LoraConfig, TaskType
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from trl import RewardTrainer, RewardConfig
-
-model = AutoModelForSequenceClassification.from_pretrained("gpt2")
-peft_config = LoraConfig(
-    task_type=TaskType.SEQ_CLS,
-    inference_mode=False,
-    r=8,
-    lora_alpha=32,
-    lora_dropout=0.1,
-)
+# Standard preference (implicit prompt)
+{"chosen": "The sky is blue.",
+ "rejected": "The sky is green."}
+
+# Conversational preference (implicit prompt)
+{"chosen": [{"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is green."}]}
+
+# Standard preference (explicit prompt)
+{"prompt": "The sky is",
+ "chosen": " blue.",
+ "rejected": " green."}
+
+# Conversational preference (explicit prompt)
+{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "assistant", "content": "It is green."}]}
+```
 
-...
+If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset:
 
-trainer = RewardTrainer(
-    model=model,
-    args=training_args,
-    processing_class=tokenizer,
-    train_dataset=dataset,
-    peft_config=peft_config,
-)
+```python
+from datasets import load_dataset
+import json
 
-trainer.train()
+dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
+
+# Filter out ties
+dataset = dataset.filter(lambda example: example["winner_tie"] == 0)
+
+# Create 'chosen' and 'rejected' fields based on the winner column
+def response_a_b_to_chosen_rejected(example):
+    if example["winner_model_a"] == 1:
+        example["chosen"] = example["response_a"]
+        example["rejected"] = example["response_b"]
+    else:
+        example["chosen"] = example["response_b"]
+        example["rejected"] = example["response_a"]
+    return example
+
+dataset = dataset.map(response_a_b_to_chosen_rejected)
 
+# Convert to conversational format
+def make_conversation(example):
+    prompt = json.loads(example["prompt"])[0]  # '["What color is the sky?"]' -> "What color is the sky?"
+    chosen = json.loads(example["chosen"])[0]
+    rejected = json.loads(example["rejected"])[0]
+    return {
+        "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}],
+        "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}],
+    }
+
+
+dataset = dataset.map(make_conversation)
+
+# Keep only necessary columns
+dataset = dataset.select_columns(["chosen", "rejected"])
+
+print(next(iter(dataset["train"])))
 ```
 
-### Adding a margin to the loss
+```json
+{
+    "chosen": [
+        {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
+        {"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."},
+    ],
+    "rejected": [
+        {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
+        {"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."},
+    ],
+}
+```
+
+## Looking deeper into the training method
+
+Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.
+
+This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**.
+
+### Preprocessing and tokenization
+
+During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference).
+The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.
+
+### Computing the loss
+
+Let  \\( x \\) be the input sequence (prompt) and  \\( y^+ \\) and  \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that  \\( y^+ \\) is preferred over  \\( y^- \\) given a reward function  \\( r \\) is  \\( p(y^+ ā» y^- |x) = \sigma(r(x, y^+)ār(x, y^-)) \\), where  \\( Ļ \\) is the sigmoid function.
+
+The reward model  \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses  \\( y^+ \\) over non-preferred ones  \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
+$$
+
+> [!TIP]
+> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
+
+## Logged metrics
 
-As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
+While training and evaluating we record the following reward metrics:
+
+* `global_step`: The total number of optimizer steps taken so far.
+* `epoch`: The current epoch number, based on dataset iteration.
+* `num_tokens`: The total number of tokens processed so far.
+* `loss`: The average loss over the last logging interval.
+* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
+* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
+* `mean_reward`: The average reward score assigned by the model over the last logging interval.
+* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
+* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
+* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
+* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
+
+## Customization
+
+### Model initialization
+
+You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
 
 ```python
-def add_margin(row):
-    # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
-    return {'margin': row['score_chosen'] - row['score_rejected']}
+model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
+```
 
-dataset = dataset.map(add_margin)
+you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
+
+```python
+from trl import RewardConfig
+
+training_args = RewardConfig(
+    model_init_kwargs={"dtype": torch.bfloat16},
+)
 ```
 
-### Centering rewards
+Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
+
+### Train adapters with PEFT
 
-In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
+We support tight integration with š¤ PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
 
-[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
+```python
+from datasets import load_dataset
+from trl import RewardTrainer
+from peft import LoraConfig
 
-$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
 
-This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
+trainer = RewardTrainer(
+    "Qwen/Qwen3-4B",
+    train_dataset=dataset,
+    peft_config=LoraConfig(modules_to_save=["score"])  # important to include the score head when base model is not a sequence classification model
+)
+
+trainer.train()
+```
+
+You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed.
 
 ```python
-training_args = RewardConfig(
-    center_rewards_coefficient=0.01,
-    ...
+from datasets import load_dataset
+from trl import RewardTrainer
+from peft import AutoPeftModelForCausalLM
+
+model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
+dataset = load_dataset("trl-lib/Capybara", split="train")
+
+trainer = RewardTrainer(
+    model=model,
+    train_dataset=dataset,
 )
+
+trainer.train()
 ```
 
-For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
+> [!TIP]
+> When training adapters, you typically use a higher learning rate (ā1eā3) since only new parameters are being learned.
+>
+> ```python
+> RewardConfig(learning_rate=1e-3, ...)
+> ```
+
+## Tool Calling with Reward Modeling
+
+The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
+
+* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
+* The list of available tools in the `tools` column, typically provided as JSON schemas
+
+For details on the expected dataset structure, see the [Dataset Format ā Tool Calling](dataset_formats#tool-calling) section.
 
 ## RewardTrainer
 
@@ -91,3 +232,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/tr
 ## RewardConfig
 
 [[autodoc]] RewardConfig
+
+## DataCollatoForPreference
+
+[[autodoc]] trainer.reward_trainer.DataCollatorForPreference
diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index cc292d65420..e855bd02b02 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -23,7 +23,7 @@ trainer = SFTTrainer(
 trainer.train()
 ```
 
-
+
 
 ## Expected dataset type and format
 
diff --git a/docs/source/trackio_integration.md b/docs/source/trackio_integration.md
index 9f5fe693a9f..4e93120fe19 100644
--- a/docs/source/trackio_integration.md
+++ b/docs/source/trackio_integration.md
@@ -64,4 +64,4 @@ trainer.train()
 
 will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.
 
-
+
diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py
index 7a3295d2d20..6c9f09b2834 100644
--- a/scripts/generate_tiny_models.py
+++ b/scripts/generate_tiny_models.py
@@ -43,6 +43,7 @@
     GPT2LMHeadModel,
     GPTNeoXConfig,
     GPTNeoXForCausalLM,
+    GPTNeoXForSequenceClassification,
     GptOssConfig,
     GptOssForCausalLM,
     Idefics2Config,
@@ -73,6 +74,7 @@
     Qwen3ForSequenceClassification,
     Qwen3MoeConfig,
     Qwen3MoeForCausalLM,
+    Qwen3MoeForSequenceClassification,
     SmolVLMForConditionalGeneration,
     T5ForConditionalGeneration,
 )
@@ -234,22 +236,46 @@ def init_weights_tiny_model(model):
 push_to_hub(model, tokenizer, "small")
 
 # Reward models
-for model_id, config_class, model_class, suffix in [
-    ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"),
-    ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"),
-    ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None),
+for model_id, model_class, suffix in [
+    ("EleutherAI/pythia-14m", GPTNeoXForSequenceClassification, None),
+    ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"),
+    ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"),
+    ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None),
 ]:
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    config = config_class(
-        vocab_size=len(tokenizer.vocab),
-        hidden_size=8,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        num_hidden_layers=2,
-        intermediate_size=32,
-        num_labels=1,
-    )
-    model = model_class(config)
+    kwargs = {
+        "num_labels": 1,
+        "hidden_size": 16,
+        "num_attention_heads": 4,
+        "num_key_value_heads": 2,
+        "num_hidden_layers": 2,
+        "intermediate_size": 32,
+    }
+    config = AutoConfig.from_pretrained(model_id, **kwargs)
+    # Bug in transformers: it ignores num_hidden_layers to build layer_types
+    if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"):
+        config.layer_types = config.layer_types[:2]
+    model = model_class(config).to(dtype=torch.bfloat16)
+    init_weights_tiny_model(model)
+    push_to_hub(model, tokenizer, "tiny", suffix)
+
+# MoE Reward models
+for model_id, model_class, suffix in [
+    ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None),
+]:
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    kwargs = {
+        "num_labels": 1,
+        "hidden_size": 16,
+        "num_attention_heads": 4,
+        "num_key_value_heads": 2,
+        "num_hidden_layers": 2,
+        "intermediate_size": 32,
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+    }
+    config = AutoConfig.from_pretrained(model_id, **kwargs)
+    model = model_class(config).to(dtype=torch.bfloat16)
     push_to_hub(model, tokenizer, "tiny", suffix)
 
 
@@ -315,7 +341,5 @@ def init_weights_tiny_model(model):
         kwargs["perceiver_config"] = {"hidden_size": 16}
 
     config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
-
     model = model_class(config).to(dtype=torch.bfloat16)
-
     push_to_hub(model, processor, "tiny")
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 2f8891c7333..23b5d6bcff7 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -67,6 +67,13 @@ def test_kto(self):
         with patch("sys.argv", command.split(" ")):
             main()
 
+    def test_reward(self):
+        from trl.cli import main
+
+        command = f"trl reward --output_dir {self.tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_implicit_prompt_preference --report_to none"
+        with patch("sys.argv", command.split(" ")):
+            main()
+
     def test_rloo(self):
         from trl.cli import main
 
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index a92c2c8cb0f..b4d53e16941 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -12,217 +12,823 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pathlib
+import unittest
 
 import torch
-from datasets import Dataset, load_dataset
+from datasets import load_dataset
+from parameterized import parameterized
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
 
-from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
-from trl.trainer.reward_trainer import _tokenize
+from trl import RewardConfig, RewardTrainer
+from trl.trainer.reward_trainer import DataCollatorForPreference
 
 from .testing_utils import TrlTestCase
 
 
 if is_peft_available():
-    from peft import LoraConfig, TaskType
+    from peft import LoraConfig, PeftModel, get_peft_model
+
+
+class TestDataCollatorForPreference(TrlTestCase):
+    def test_basic_padding(self):
+        """Test basic padding functionality without completion masks."""
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+
+    def test_pad_to_multiple_of(self):
+        """Test padding to multiple of specified value."""
+        collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = collator(examples)
+
+        torch.testing.assert_close(
+            result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]])
+        )
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
+        )
+
+    def test_single_example(self):
+        """Test collator with a single example."""
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
+        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
+
+    def test_different_pad_token_id(self):
+        """Test with different pad token ID."""
+        collator = DataCollatorForPreference(pad_token_id=999)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = collator(examples)
+
+        torch.testing.assert_close(
+            result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]])
+        )
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+
+    def test_collate_with_margin(self):
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.2},
+        ]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+        torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2]))
 
 
 class RewardTrainerTester(TrlTestCase):
-    def setUp(self):
-        super().setUp()
-        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
-        self.model.config.pad_token_id = self.tokenizer.pad_token_id
-
-    def test_preprocessing_conversational(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+    @parameterized.expand(
+        [
+            ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",),
+            ("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification",),
+            ("trl-internal-testing/tiny-LlamaForSequenceClassification-3.2",),
+        ]
+    )
+    def test_train(self, model_id):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @parameterized.expand(
+        [
+            ("standard_preference",),
+            ("conversational_preference",),
+            ("standard_implicit_prompt_preference",),
+            ("conversational_implicit_prompt_preference",),
+        ]
+    )
+    def test_train_dataset_types(self, config_name):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")
+
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
         )
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
-        self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
 
-    def test_preprocessing_standard(self):
-        # No chat template, so we load a fresh tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_model(self):
+        # Instantiate the model
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        )
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_from_causal_lm(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
         )
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
-        self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
 
-    def test_train_full(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_model_dtype(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            model_init_kwargs={"dtype": torch.float16},
+            learning_rate=0.1,
+            report_to="none",
+        )
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
         )
+
+        # Save the initial parameters to compare them later
         previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
 
+        # Check that the training loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-        # Check that the parameters have changed
+
+        # Check the params have changed
         for n, param in previous_trainable_params.items():
+            # For some reasonn model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
+            # locally. We ignore this parameter for now
+            if "layernorm" in n:
+                continue
             new_param = trainer.model.get_parameter(n)
-            if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+            # Check the torch dtype
+            self.assertEqual(new_param.dtype, torch.float16)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @require_peft
+    def test_train_dense_with_peft_config(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
 
-    def test_train_full_pretokenized(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
         )
+
+        # Save the initial parameters to compare them later
         previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
 
+        # Check that the training loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-        # Check that the parameters have changed
+
+        # Check the peft params have changed and the base model params have not changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            if param.sum() != 0:  # ignore 0 biases
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
 
     @require_peft
-    def test_train_lora(self):
-        peft_config = LoraConfig(
-            task_type=TaskType.SEQ_CLS,
-            inference_mode=False,
-            r=8,
-            lora_alpha=32,
-            lora_dropout=0.1,
-        )
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
+    def test_train_moe_with_peft_config(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+
         trainer = RewardTrainer(
-            model=self.model,
+            model=model_id,
             args=training_args,
-            processing_class=self.tokenizer,
-            train_dataset=dummy_dataset,
-            peft_config=peft_config,
+            train_dataset=dataset,
+            peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
         )
-        previous_trainable_params = {}
-        previous_non_trainable_params = {}
 
-        # due to a change in the way the modules to save are dealt in PEFT.
-        trainable_params_name = ["lora", "modules_to_save"]
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @require_peft
+    def test_train_peft_model(self):
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+
+        # Get the base model parameter names
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Turn the model into a peft model
+        lora_config = LoraConfig()
+        model = get_peft_model(model, lora_config)
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @require_peft
+    def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
 
-        # check gradients are not None
-        for n, param in trainer.model.named_parameters():
-            if any(t in n for t in trainable_params_name):
-                previous_trainable_params[n] = param.clone()
-            else:
-                previous_non_trainable_params[n] = param.clone()
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
 
+        # Train the model
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-        # Check that the parameters have changed
+        # Check the peft params have changed and the base model params have not changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @require_peft
+    def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
 
-        # Check that the non trainable parameters have not changed
-        for n, param in previous_non_trainable_params.items():
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
 
     @require_peft
-    def test_train_lora_pretokenized(self):
-        peft_config = LoraConfig(
-            task_type=TaskType.SEQ_CLS,
-            inference_mode=False,
-            r=8,
-            lora_alpha=32,
-            lora_dropout=0.1,
-        )
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+    def test_train_with_peft_model_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        model = get_peft_model(model, LoraConfig())
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Verify model is a PeftModel
+        self.assertIsInstance(trainer.model, PeftModel)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_pretokenized_data(self):
+        # Get the dataset
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        def tokenize_example(example):
+            return {
+                "chosen_input_ids": tokenizer(example["chosen"]).input_ids,
+                "rejected_input_ids": tokenizer(example["rejected"]).input_ids,
+            }
+
+        # Apply tokenization
+        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"])
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_iterable_dataset(self):
+        # Get the dataset
+        dataset = load_dataset(
+            "trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train", streaming=True
+        )
+
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
         trainer = RewardTrainer(
-            model=self.model,
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_chat_template_kwargs(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5")
+        # The following template is a simplified version of the Qwen chat template, where an additional argument
+        # `role_capital` is used to control the capitalization of roles.
+        tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%}    {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%}    {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%}    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%}        {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }}    {%- elif message.role == "assistant" -%}        {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }}        {%- if message.content -%}            {{ "\\n" + message.content }}        {%- endif -%}        {{ "<|im_end|>\\n" }}    {%- elif message.role == "tool" -%}        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%}            {{ "<|im_start|>" + ("USER" if role_capital else "user") }}        {%- endif -%}        {{ "\\n\\n" + message.content + "\\n" }}        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%}            {{ "<|im_end|>\\n" }}        {%- endif -%}    {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}    {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}'
+
+        dataset.add_column("chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))])
+
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
             args=training_args,
-            processing_class=self.tokenizer,
-            train_dataset=dummy_dataset,
-            peft_config=peft_config,
+            train_dataset=dataset,
         )
-        previous_trainable_params = {}
-        previous_non_trainable_params = {}
 
-        # due to a change in the way the modules to save are dealt in PEFT.
-        trainable_params_name = ["lora", "modules_to_save"]
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
 
-        # check gradients are not None
-        for n, param in trainer.model.named_parameters():
-            if any(t in n for t in trainable_params_name):
-                previous_trainable_params[n] = param.clone()
-            else:
-                previous_non_trainable_params[n] = param.clone()
+    def test_train_with_set_chat_template_from_model(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
 
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
+        # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
 
-        self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
 
-        # Check that the parameters have changed
+        # Check the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
+            # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
+            # this parameter.
+            if n == "gpt_neox.final_layer_norm.bias":
+                continue
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_set_chat_template_from_path(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
+            report_to="none",
+        )
+        # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
+            args=training_args,
+            train_dataset=dataset,
+        )
 
-        # Check that the non trainable parameters have not changed
-        for n, param in previous_non_trainable_params.items():
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
+            # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
+            # this parameter.
+            if n == "gpt_neox.final_layer_norm.bias":
+                continue
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+        # Check that the template saved in the output directory is the same as the one used for training
+        template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
+        self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")
+
+        with open(template_path) as f:
+            template_content = f.read()
+        with open(training_args.chat_template_path) as f:
+            original_template_content = f.read()
+        self.assertEqual(
+            template_content, original_template_content, "Chat template content does not match the original"
+        )
+
+    @unittest.skip("Skipping until we have a dataset with tool calls")
+    def test_train_toolcall_data(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/toolcall", split="train")
 
-    def test_margin(self):
-        dummy_dataset_dict = {
-            "input_ids_chosen": [
-                torch.LongTensor([0, 1, 2]),
-            ],
-            "attention_mask_chosen": [
-                torch.LongTensor([1, 1, 1]),
-            ],
-            "input_ids_rejected": [
-                torch.LongTensor([0, 2]),
-            ],
-            "attention_mask_rejected": [
-                torch.LongTensor([1, 1]),
-            ],
-            "margin": [
-                torch.FloatTensor([1.0]),
-            ],
-        }
-        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_eval(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["test"],
+        )
+
+        # Train the model
+        trainer.train()
+
+        # Check that the eval loss is not None
+        self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
+
+    def test_train_with_multiple_eval_dataset(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset["train"],
+            eval_dataset={"data1": dataset["test"], "data2": dataset["test"]},
+        )
+        # Train the model
+        trainer.train()
+
+        # Check that the eval losses are not None
+        self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
+        self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])
+
+    def test_train_with_gradient_checkpointing(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_tag_added(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            train_dataset=dataset,
+        )
+
+        for tag in ["reward-trainer", "trl"]:
+            self.assertIn(tag, trainer.model.model_tags)
+
+    @require_peft
+    def test_tag_added_peft(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
         )
 
-        batch = [dummy_dataset[0]]
-        batch = trainer.data_collator(batch)
-        batch = {k: v.to(trainer.model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
-        loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True)
+        for tag in ["reward-trainer", "trl"]:
+            self.assertIn(tag, trainer.model.model_tags)
+
+    def test_train_with_margin(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
 
-        l_val = -torch.nn.functional.logsigmoid(
-            outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
-        ).mean()
+        def add_margin(example):
+            # dummy margin based on the length of the chosen summary
+            return {"margin": len(example["chosen"])}
 
-        self.assertLess(abs(loss - l_val), 1e-6)
+        dataset = dataset.map(add_margin)
 
-    def test_tags(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
         )
-        self.assertEqual(trainer.model.model_tags, trainer._tag_names)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_center_rewards_coefficient(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
diff --git a/trl/cli.py b/trl/cli.py
index 9245f31dabd..199b1c26703 100644
--- a/trl/cli.py
+++ b/trl/cli.py
@@ -24,6 +24,7 @@
 from .scripts.env import print_env
 from .scripts.grpo import make_parser as make_grpo_parser
 from .scripts.kto import make_parser as make_kto_parser
+from .scripts.reward import make_parser as make_reward_parser
 from .scripts.rloo import make_parser as make_rloo_parser
 from .scripts.sft import make_parser as make_sft_parser
 from .scripts.utils import TrlParser
@@ -45,6 +46,7 @@ def main():
     subparsers.add_parser("env", help="Print the environment information")
     make_grpo_parser(subparsers)
     make_kto_parser(subparsers)
+    make_reward_parser(subparsers)
     make_rloo_parser(subparsers)
     make_sft_parser(subparsers)
     make_vllm_serve_parser(subparsers)
@@ -111,6 +113,15 @@ def main():
         args.training_script_args = sys.argv[2:]  # remove "trl" and "kto"
         launch_command(args)  # launch training
 
+    elif args.command == "reward":
+        # Get the default args for the launch command
+        reward_training_script = resources.files("trl.scripts").joinpath("reward.py")
+        args = launch_command_parser().parse_args([str(reward_training_script)])
+
+        # Feed the args to the launch command
+        args.training_script_args = sys.argv[2:]  # remove "trl" and "reward"
+        launch_command(args)  # launch training
+
     elif args.command == "rloo":
         # Get the default args for the launch command
         rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py")
diff --git a/trl/models/utils.py b/trl/models/utils.py
index 22ea572ca06..efdba75fbda 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -94,11 +94,8 @@ def setup_chat_format(
     Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
     embedding layer of the model based on the new special tokens.
 
-    
-
-    This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
-
-    
+    > [!WARNING]
+    > This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
 
     If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
     `tokenizer.chat_template` to `None`.
diff --git a/trl/scripts/reward.py b/trl/scripts/reward.py
new file mode 100644
index 00000000000..f34b04e80ee
--- /dev/null
+++ b/trl/scripts/reward.py
@@ -0,0 +1,109 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+#     "trl",
+#     "peft",
+#     "trackio",
+#     "kernels",
+# ]
+# ///
+
+import argparse
+import os
+from typing import Optional
+
+from accelerate import logging
+from datasets import load_dataset
+
+from trl import (
+    DatasetMixtureConfig,
+    ModelConfig,
+    RewardConfig,
+    RewardTrainer,
+    ScriptArguments,
+    TrlParser,
+    get_dataset,
+    get_peft_config,
+)
+
+
+logger = logging.get_logger(__name__)
+
+# Enable logging in a Hugging Face Space
+os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
+
+
+def main(script_args, training_args, model_args, dataset_args):
+    # Load the dataset
+    if dataset_args.datasets and script_args.dataset_name:
+        logger.warning(
+            "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
+            "dataset and `dataset_name` will be ignored."
+        )
+        dataset = get_dataset(dataset_args)
+    elif dataset_args.datasets and not script_args.dataset_name:
+        dataset = get_dataset(dataset_args)
+    elif not dataset_args.datasets and script_args.dataset_name:
+        dataset = load_dataset(
+            script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
+        )
+    else:
+        raise ValueError("Either `datasets` or `dataset_name` must be provided.")
+
+    # Initialize the RewardTrainer
+    trainer = RewardTrainer(
+        model=model_args.model_name_or_path,
+        args=training_args,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+        peft_config=get_peft_config(model_args),
+    )
+
+    # Train the model
+    trainer.train()
+
+    # Log training complete
+    trainer.accelerator.print("ā
 Training completed.")
+
+    # Save and push to Hub
+    trainer.save_model(training_args.output_dir)
+    trainer.accelerator.print(f"š¾ Model saved to {training_args.output_dir}.")
+
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
+        trainer.accelerator.print(f"š¤ Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
+
+
+def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
+    dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig)
+    if subparsers is not None:
+        parser = subparsers.add_parser(
+            "reward", help="Run the reward training script", dataclass_types=dataclass_types
+        )
+    else:
+        parser = TrlParser(dataclass_types)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = make_parser()
+    # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
+    # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
+    # `return_remaining_strings=True`, then ignore the remaining strings.
+    script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
+        return_remaining_strings=True
+    )
+    main(script_args, training_args, model_args, dataset_args)
diff --git a/trl/templates/rm_model_card.md b/trl/templates/rm_model_card.md
new file mode 100644
index 00000000000..685ed007bd5
--- /dev/null
+++ b/trl/templates/rm_model_card.md
@@ -0,0 +1,55 @@
+---
+{{ card_data }}
+---
+
+# Model Card for {{ model_name }}
+
+This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
+It has been trained using [TRL](https://github.com/huggingface/trl).
+
+## Quick start
+
+```python
+from transformers import pipeline
+
+text = "The capital of France is Paris."
+rewarder = pipeline(model="{{ hub_model_id }}", device="cuda")
+output = rewarder(text)[0]
+print(output["score"])
+```
+
+## Training procedure
+
+{% if wandb_url %}[ ]({{ wandb_url }}){% endif %} 
+{% if comet_url %}[
]({{ wandb_url }}){% endif %} 
+{% if comet_url %}[ ]({{ comet_url }}){% endif %}
+
+This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
+
+### Framework versions
+
+- TRL: {{ trl_version }}
+- Transformers: {{ transformers_version }}
+- Pytorch: {{ pytorch_version }}
+- Datasets: {{ datasets_version }}
+- Tokenizers: {{ tokenizers_version }}
+
+## Citations
+
+{% if trainer_citation %}Cite {{ trainer_name }} as:
+
+```bibtex
+{{ trainer_citation }}
+```{% endif %}
+
+Cite TRL as:
+    
+```bibtex
+{% raw %}@misc{vonwerra2022trl,
+	title        = {{TRL: Transformer Reinforcement Learning}},
+	author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
+	year         = 2020,
+	journal      = {GitHub repository},
+	publisher    = {GitHub},
+	howpublished = {\url{https://github.com/huggingface/trl}}
+}{% endraw %}
+```
diff --git a/trl/trainer/base_trainer.py b/trl/trainer/base_trainer.py
index e7cb05def71..bb88cbfc934 100644
--- a/trl/trainer/base_trainer.py
+++ b/trl/trainer/base_trainer.py
@@ -28,6 +28,7 @@ class BaseTrainer(Trainer):
     _tag_names = []
     _name = "Base"
     _paper = {}
+    _template_file = None
 
     def create_model_card(
         self,
@@ -78,6 +79,7 @@ def create_model_card(
             comet_url=get_comet_experiment_url(),
             trainer_name=self._name,
             trainer_citation=self._paper.get("citation"),
+            template_file=self._template_file,
             paper_title=self._paper.get("title"),
             paper_id=self._paper.get("id"),
         )
diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
index 9a3aabc39ee..33d248e635e 100644
--- a/trl/trainer/reward_config.py
+++ b/trl/trainer/reward_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Optional
 
 from transformers import TrainingArguments
 
@@ -32,22 +32,53 @@ class may differ from those in [`~transformers.TrainingArguments`].
     command line.
 
     Parameters:
-        max_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
-            limit. This argument is required if you want to use the default data collator.
+        > Parameters that control the model
+
+        model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+            argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
+            to include the load balancing/auxilliary loss as a part of the final loss, remember to set
+            `output_router_logits=True` in this dictionary.
+        chat_template_path (`str`, *optional*):
+            If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
+            or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
+            ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
+            embedding layer is resized accordingly.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model.
+
+        > Parameters that control the data preprocessing
+
         dataset_num_proc (`int`, *optional*):
             Number of processes to use for processing the dataset.
+        eos_token (`str`, *optional*):
+            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
+            `processing_class.eos_token`.
+        pad_token (`str`, *optional*):
+            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+            it falls back to `processing_class.eos_token`.
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
+            exceeds this value. If `None`, no filtering is applied.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+
+        > Parameters that control the training
+
         center_rewards_coefficient (`float`, *optional*):
             Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
             https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
-        remove_unused_columns (`bool`, *optional*, defaults to `False`):
-            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the
-            dataset is pretokenized.
+        activation_offloading (`bool`, *optional*, defaults to `False`):
+            Whether to offload the activations to the CPU.
     """
 
+    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
+
     # Parameters whose default values are overridden from TrainingArguments
+    learning_rate: float = field(
+        default=1e-4,
+        metadata={"help": "The initial learning rate for AdamW."},
+    )
     logging_steps: float = field(
         default=10,
         metadata={
@@ -70,21 +101,59 @@ class may differ from those in [`~transformers.TrainingArguments`].
         },
     )
 
-    max_length: Optional[int] = field(
-        default=1024,
+    # Parameters that control the model
+    model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
         metadata={
-            "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that "
-            "exceed the limit. This argument is required if you want to use the default data collator."
+            "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
+            "the `RewardTrainer` is provided as a string."
+        },
+    )
+    chat_template_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
+            "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
+            "you must ensure that any special tokens referenced in the template are added to the tokenizer and "
+            "that the model's embedding layer is resized accordingly."
         },
     )
     disable_dropout: bool = field(
         default=True,
-        metadata={"help": "Whether to disable dropout in the model and reference model."},
+        metadata={"help": "Whether to disable dropout in the model."},
     )
+
+    # Parameters that control the data preprocessing
     dataset_num_proc: Optional[int] = field(
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
+    eos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
+        },
+    )
+    pad_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
+            "is also `None`, it falls back to `processing_class.eos_token`."
+        },
+    )
+    max_length: Optional[int] = field(
+        default=1024,
+        metadata={
+            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
+            "the right. If `None`, no truncation is applied."
+        },
+    )
+    pad_to_multiple_of: Optional[int] = field(
+        default=None,
+        metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
+    )
+
+    # Parameters that control the training
     center_rewards_coefficient: Optional[float] = field(
         default=None,
         metadata={
@@ -92,15 +161,11 @@ class may differ from those in [`~transformers.TrainingArguments`].
             "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
         },
     )
-    remove_unused_columns: bool = field(
+    activation_offloading: bool = field(
         default=False,
-        metadata={
-            "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
-            "if the dataset is pretokenized."
-        },
+        metadata={"help": "Whether to offload the activations to the CPU."},
     )
 
     def __post_init__(self):
         self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
-
         super().__post_init__()
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 79570eff582..5408db49967 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -12,132 +12,348 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import logging
+import os
+import re
 from collections import defaultdict
-from dataclasses import FrozenInstanceError, replace
+from contextlib import contextmanager
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
 
-import pandas as pd
 import torch
 import torch.nn as nn
-from accelerate import PartialState, logging
-from accelerate.utils import gather_object
-from datasets import Dataset
+import transformers
+from accelerate import PartialState
+from accelerate.logging import get_logger
+from datasets import Dataset, IterableDataset
 from transformers import (
-    BaseImageProcessor,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     DataCollator,
-    FeatureExtractionMixin,
     PreTrainedModel,
     PreTrainedTokenizerBase,
-    ProcessorMixin,
 )
+from transformers.data.data_collator import DataCollatorMixin
 from transformers.trainer_callback import TrainerCallback
-from transformers.trainer_pt_utils import nested_detach
 from transformers.trainer_utils import EvalPrediction
-from transformers.utils import is_peft_available, is_rich_available
+from transformers.utils import is_peft_available
 
-from ..data_utils import maybe_apply_chat_template
-from ..models import prepare_peft_model
+from ..data_utils import is_conversational
+from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
 from .base_trainer import BaseTrainer
 from .reward_config import RewardConfig
-from .utils import (
-    RewardDataCollatorWithPadding,
-    compute_accuracy,
-    decode_and_strip_padding,
-    disable_dropout_in_model,
-    log_table_to_comet_experiment,
-    print_rich_table,
-)
+from .utils import disable_dropout_in_model, pad, remove_none_values
 
 
 if is_peft_available():
-    from peft import PeftModel
+    from peft import PeftConfig, PeftModel
+
+
+logger = get_logger(__name__)
 
 
-logger = logging.get_logger(__name__)
+# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
+# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
+# avoid confusing users.
+@contextmanager
+def suppress_from_pretrained_warning(logger: logging.Logger):
+    pattern = re.compile(
+        r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
+        r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
+        r"inference\.$"
+    )
 
+    class _Filter(logging.Filter):
+        def filter(self, record: logging.LogRecord) -> bool:
+            return not pattern.search(record.getMessage())
 
-def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]:
-    """Tokenize a batch from a reward modelling dataset."""
-    new_examples = {
-        "input_ids_chosen": [],
-        "attention_mask_chosen": [],
-        "input_ids_rejected": [],
-        "attention_mask_rejected": [],
-    }
-    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
-        tokenized_chosen = tokenizer(chosen)
-        tokenized_rejected = tokenizer(rejected)
-        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
-        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
-        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
-        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
+    f = _Filter()
+    logger.addFilter(f)
+    try:
+        yield
+    finally:
+        logger.removeFilter(f)
+
+
+@dataclass
+class DataCollatorForPreference(DataCollatorMixin):
+    """
+    Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
 
-    return new_examples
+    This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and
+    `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys:
+    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
+        corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`.
+    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
+
+    Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a
+    `"margin"` key with a tensor of margins.
+
+    Args:
+        pad_token_id (`int`):
+            Token ID to use for padding.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            Type of Tensor to return. Only `"pt"` is currently supported.
+
+    Examples:
+    ```python
+    >>> from trl.trainer.reward_trainer import DataCollatorForPreference
+
+    >>> collator = DataCollatorForPreference(pad_token_id=0)
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]])}
+
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]]),
+     'margin': tensor([0.5, 0.0])}
+    ```
+    """
+
+    pad_token_id: int
+    pad_to_multiple_of: Optional[int] = None
+    return_tensors: str = "pt"
+
+    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
+        # Convert to tensor
+        chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples]
+        rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples]
+        if "margin" in examples[0]:
+            margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float)
+        input_ids = chosen_input_ids + rejected_input_ids
+        attention_mask = [torch.ones_like(ids) for ids in input_ids]
+
+        output = {}
+
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        output["attention_mask"] = pad(
+            attention_mask,
+            padding_value=0,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if "margin" in examples[0]:
+            output["margin"] = margins
+        return output
 
 
 class RewardTrainer(BaseTrainer):
     """
-    Trainer for custom reward.
+    Trainer for Outcome-supervised Reward Models (ORM).
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+
+    Example:
+
+    ```python
+    from trl import RewardTrainer
+    from datasets import load_dataset
+
+    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+    trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
+    trainer.train()
+    ```
 
     Args:
-        model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*):
-            Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`].
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
+              `args.model_init_kwargs`.
+            - A sequence classification [`~transformers.PreTrainedModel`] object.
         args ([`RewardConfig`], *optional*):
-            Training arguments.
+            Configuration for this trainer. If `None`, a default configuration is used.
         data_collator ([`~transformers.DataCollator`], *optional*):
-            The data collator to use for training. If None is specified, the default data collator
-            [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum
-            length of the sequences in the batch, given a dataset of paired sequences.
-        train_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for training.
-        eval_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for evaluation.
-        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
-            Processing class used to process the data. If provided, will be used to automatically process the inputs
-            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-            reuse the fine-tuned model.
-        model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*):
-            The model initializer to use for training. If None is specified, the default model initializer will be
-            used.
-        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]):
-            Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
-            dictionary string to float.
-        callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
-            Callbacks to use during training.
-        optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
-            Tuple containing the optimizer and the learning rate scheduler to use for training.
+            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
+            Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
+            explicit prompt). The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
+            `rejected_input_ids` fields.
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
+            [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
+            set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
+            default.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a
+            [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
+            [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
+            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
+            function needs to calculate and return the global summary statistics rather than accumulating the
+            batch-level statistics.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
+            `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
+            initializing the Trainer.
         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
-            Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
-            return the logits to be used for metrics computation.
-        peft_config (`dict`, *optional*):
-            PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
-            wrapped with the specified PEFT adapter.
+            A function that preprocess the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+        peft_config ([`~peft.PeftConfig`], *optional*):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
+            model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
+            to ensure that the reward head is properly trained.
     """
 
     _tag_names = ["trl", "reward-trainer"]
     _name = "Reward"
+    _template_file = "rm_model_card.md"
 
     def __init__(
         self,
-        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+        model: Union[str, PreTrainedModel],
         args: Optional[RewardConfig] = None,
         data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Dataset] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-        processing_class: Optional[
-            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-        ] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
         callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-            None,
-            None,
-        ),
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-        peft_config: Optional[dict] = None,
+        peft_config: Optional["PeftConfig"] = None,
     ):
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = RewardConfig(f"{model_name}-Reward")
+
+        # Model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            dtype = model_init_kwargs.get("dtype")
+            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
+                pass  # dtype is already a torch.dtype or "auto" or None
+            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
+                model_init_kwargs["dtype"] = getattr(torch, dtype)
+            else:
+                raise ValueError(
+                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
+                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
+                )
+            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
+                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                logger.warning(
+                    "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
+                    "The `model_init_kwargs` will be ignored."
+                )
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(model_id)
+
+        # Handle pad token for processors or tokenizers
+        if args.eos_token is not None:
+            eos_token = args.eos_token
+            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
+            if eos_token_id is None:
+                raise ValueError(
+                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                    "in the vocabulary before using it as an EOS token."
+                )
+            processing_class.eos_token_id = eos_token_id
+
+        if args.chat_template_path is not None:
+            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                    processing_class.chat_template = chat_template_file.read()
+                added_tokens = []
+            else:
+                model, processing_class, added_tokens = clone_chat_template(
+                    model, processing_class, args.chat_template_path
+                )
+        else:
+            added_tokens = []
+
+        # PEFT configuration and model wrapping
+        if peft_config is not None:
+            if added_tokens:
+                # Ensure that the added tokens are trainable
+                if peft_config.trainable_token_indices is None:
+                    peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
+                elif "embed_tokens" not in peft_config.trainable_token_indices:
+                    peft_config.trainable_token_indices["embed_tokens"] = added_tokens
+                else:
+                    peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
+
+                # Ensure that the lm_head is trainable
+                if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
+                    logger.warning(
+                        "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
+                        "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
+                        "tokens, leading to degraded generation quality. To fix this, add "
+                        "`modules_to_save=['lm_head']` to your PEFT configuration."
+                    )
+
+                    if peft_config.modules_to_save is None:
+                        peft_config.modules_to_save = ["lm_head"]
+                    else:
+                        peft_config.modules_to_save.append("lm_head")
+
         if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
             model = prepare_peft_model(model, peft_config, args)
 
@@ -145,78 +361,47 @@ def __init__(
         if args.disable_dropout:
             disable_dropout_in_model(model)
 
-        if compute_metrics is None:
-            compute_metrics = compute_accuracy
+        # Pad token (needed for SequenceClassification models)
+        # If not provided, use the one from the processing class or the eos token if the processing class does not have
+        # a pad token.
+        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
+        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+        if pad_token_id is None:
+            raise ValueError(
+                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                "in the vocabulary before using it as a padding token."
+            )
+        model.config.pad_token_id = pad_token_id
+        processing_class.pad_token_id = pad_token_id
 
+        # Data collator
         if data_collator is None:
-            if processing_class is None:
-                raise ValueError(
-                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
-                )
+            data_collator = DataCollatorForPreference(
+                pad_token_id=pad_token_id,
+                pad_to_multiple_of=args.pad_to_multiple_of,
+            )
 
-            max_length = args.max_length
+        # Dataset
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
 
-            data_collator = RewardDataCollatorWithPadding(processing_class)
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
 
-            if args.remove_unused_columns:
-                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
-                    args.remove_unused_columns = False
-                except FrozenInstanceError:
-                    args = replace(args, remove_unused_columns=False)
-                # warn users
-                logger.warning(
-                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
-                    " we have set it for you, but you should do it yourself in the future.",
-                )
-
-            self.use_reward_data_collator = True
-        else:
-            self.use_reward_data_collator = False
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-        # issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        if "input_ids_chosen" not in train_dataset.column_names:
-            with PartialState().main_process_first():
-                fn_kwargs = {"tokenizer": processing_class}
-                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                train_dataset = train_dataset.map(
-                    _tokenize,
-                    batched=True,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=args.dataset_num_proc,
-                )
-                # This filter is important because otherwise you get samples that exceed the model's context length and
-                # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                # user might get surprised if N samples are missing from training.
-                train_dataset = train_dataset.filter(
-                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                    num_proc=args.dataset_num_proc,
-                )
-                if eval_dataset is not None:
-                    eval_dataset = eval_dataset.map(
-                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                    )
-                    eval_dataset = eval_dataset.map(
-                        _tokenize,
-                        fn_kwargs=fn_kwargs,
-                        batched=True,
-                        num_proc=args.dataset_num_proc,
-                    )
-                    # This filter is important because otherwise you get samples that exceed the model's context length and
-                    # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                    # user might get surprised if N samples are missing from training.
-                    eval_dataset = eval_dataset.filter(
-                        lambda x: len(x["input_ids_chosen"]) <= max_length
-                        and len(x["input_ids_rejected"]) <= max_length,
-                        num_proc=args.dataset_num_proc,
-                    )
+        # Initialize the Trainer. Parent class will handle:
+        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
+        # - FSDP setup
+        # - Distributed training setup
+        # - Optimizer and scheduler creation
 
         super().__init__(
             model=model,
@@ -225,35 +410,140 @@ def __init__(
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
-            model_init=model_init,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
             optimizers=optimizers,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
+        self.can_return_loss = True
+        self.label_names = []
+
+        # Initialize activation offloading context
+        if self.args.activation_offloading:
+            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
+        else:
+            self.maybe_activation_offload_context = contextlib.nullcontext()
+
         # Add tags for models that have been loaded with the correct transformers version
         if hasattr(self.model, "add_model_tags"):
             self.model.add_model_tags(self._tag_names)
 
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: PreTrainedTokenizerBase,
+        args: RewardConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
+        # sampled data.
+        if isinstance(dataset, Dataset):  # IterableDataset does not support `with_transform`
+            dataset = dataset.with_transform(remove_none_values)
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            if not is_processed:
+                # Add EOS token to the end of the sequences if needed
+                first_example = next(iter(dataset))
+                if not is_conversational(first_example):
+                    if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                        map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
+
+                    def add_eos(example, eos_token):
+                        if not example["chosen"].endswith(eos_token):
+                            example["chosen"] = example["chosen"] + eos_token
+                        if "rejected" in example and not example["rejected"].endswith(eos_token):
+                            example["rejected"] = example["rejected"] + eos_token
+                        return example
+
+                    dataset = dataset.map(
+                        add_eos,
+                        fn_kwargs={"eos_token": processing_class.eos_token},
+                        **map_kwargs,
+                    )
+
+                # Tokenize the dataset
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize_fn(example, processing_class):
+                    if "prompt" in example:  # explicit prompt case
+                        example["chosen"] = example["prompt"] + example["chosen"]
+                        example["rejected"] = example["prompt"] + example["rejected"]
+
+                    if is_conversational(example):
+                        chosen_input_ids = processing_class.apply_chat_template(
+                            example["chosen"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        rejected_input_ids = processing_class.apply_chat_template(
+                            example["rejected"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
+                    else:
+                        output = {
+                            "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
+                            "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
+                        }
+                    return output
+
+                dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
+
+            # Filter samples that are longer than `max_length`
+            if args.max_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
+                dataset = dataset.filter(
+                    lambda example: len(example["chosen_input_ids"]) <= args.max_length
+                    and len(example["rejected_input_ids"]) <= args.max_length,
+                    **map_kwargs,
+                )
+
+        return dataset
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+        # and "attention_mask").
+        if self._signature_columns is None:
+            self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
+
     def compute_loss(
         self,
-        model: Union[PreTrainedModel, nn.Module],
+        model: nn.Module,
         inputs: dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-        num_items_in_batch=None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        rewards_chosen = model(
-            input_ids=inputs["input_ids_chosen"],
-            attention_mask=inputs["attention_mask_chosen"],
-            return_dict=True,
-        )["logits"]
-        rewards_rejected = model(
-            input_ids=inputs["input_ids_rejected"],
-            attention_mask=inputs["attention_mask_rejected"],
-            return_dict=True,
-        )["logits"]
-        # calculate loss, optionally modulate with margin
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        Compute training loss and additionally compute token accuracies
+        """
+        mode = "train" if self.model.training else "eval"
+
+        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
+        inputs["use_cache"] = False
+        outputs = model(**inputs)
+
+        # Split the rewards into chosen and rejected
+        rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
+
+        # Calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
         else:
@@ -262,86 +552,45 @@ def compute_loss(
         if self.args.center_rewards_coefficient is not None:
             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
 
-        if return_outputs:
-            return loss, {
-                "rewards_chosen": rewards_chosen,
-                "rewards_rejected": rewards_rejected,
-            }
-        return loss
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[list[str]] = None,
-    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        inputs = self._prepare_inputs(inputs)
-        if ignore_keys is None:
-            if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
+        if mode == "train":
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+            self._total_train_tokens += num_tokens_in_batch
+        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
+        # Compute min, mean, max, accuracy and margin
         with torch.no_grad():
-            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-        if prediction_loss_only:
-            return (loss, None, None)
-
-        loss = loss.detach()
-        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = nested_detach(logits)
-        # Stack accepted against rejected, mean over logits
-        # and softmax to get preferences between accepted and rejected to sum to 1
-        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-        labels = torch.zeros(logits.shape[0])
-        labels = self._prepare_inputs(labels)
-
-        return loss, logits, labels
-
-    def evaluate(self, *args, **kwargs):
-        num_print_samples = kwargs.pop("num_print_samples", 4)
-        self.visualize_samples(num_print_samples)
-        return super().evaluate(*args, **kwargs)
-
-    def visualize_samples(self, num_print_samples: int):
-        """
-        Visualize the reward model logits prediction
-
-        Args:
-            num_print_samples (`int`, defaults to `4`):
-                The number of samples to print. Set to `-1` to print all samples.
-        """
-        eval_dataloader = self.get_eval_dataloader()
-        table = defaultdict(list)
-        for _, inputs in enumerate(eval_dataloader):
-            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-            table["chosen_text"].extend(gather_object(chosen_text))
-            table["rejected_text"].extend(gather_object(rejected_text))
-            table["logits"].extend(
-                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-            )
-            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                break
-        df = pd.DataFrame(table)
-        if self.accelerator.process_index == 0:
-            if is_rich_available():
-                print_rich_table(df[:num_print_samples])
-            if "wandb" in self.args.report_to:
-                import wandb
-
-                if wandb.run is not None:
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
-
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="completions.csv",
-                    table=df,
-                )
+            all_rewards = self.accelerator.gather(outputs.logits)
+            self._metrics[mode]["min_reward"].append(all_rewards.min().item())
+            self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
+            self._metrics[mode]["max_reward"].append(all_rewards.max().item())
+
+            mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
+            mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
+            self._metrics[mode]["accuracy"].append(mean_accuracy)
+
+            mean_margin = (rewards_chosen - rewards_rejected).mean()
+            mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
+            self._metrics[mode]["margin"].append(mean_margin.item())
+
+        return (loss, outputs) if return_outputs else loss
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
 
     # Ensure the model card is saved along with the checkpoint
     def _save_checkpoint(self, model, trial):
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 6f5db07ba50..c0f26bc9373 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -15,10 +15,9 @@
 import contextlib
 import os
 from collections import defaultdict
-from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -56,6 +55,7 @@
     entropy_from_logits,
     flush_left,
     pad,
+    remove_none_values,
     selective_log_softmax,
 )
 
@@ -66,7 +66,6 @@
 
 logger = logging.get_logger(__name__)
 
-TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
 
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
@@ -77,38 +76,6 @@
 }
 
 
-def remove_none_values(example: TListOrMapping) -> TListOrMapping:
-    """
-    Recursively removes entries with `None` values from a nested structure (list or dictionary).
-
-    Args:
-        example (`list` or `Mapping`):
-            Input nested structure (list or dictionary) from which to remove `None`.
-
-    Example:
-    ```python
-    >>> [
-    ...     {
-    ...         "a": {"aa": None, "ab": 1},
-    ...         "b": "my_string",
-    ...     }
-    ... ]
-    >>> remove_none_values(example)
-    [{'a': {'ab': 1}, 'b': 'my_string'}]
-    ```
-    """
-    if isinstance(example, list):
-        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
-    elif isinstance(example, Mapping):
-        return {
-            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
-            for key, value in example.items()
-            if value is not None
-        }
-    else:
-        raise TypeError("Input must be a list or a dictionary.")
-
-
 def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
 
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 1ef75bd0da2..a12fdec7b44 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -18,11 +18,12 @@
 import os
 import random
 import socket
-from collections.abc import Sequence, Sized
+import warnings
+from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
 from itertools import accumulate
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeVar, Union
 
 import numpy as np
 import pandas as pd
@@ -219,6 +220,10 @@ class RewardDataCollatorWithPadding:
     r"""
     Reward DataCollator class that pads the inputs to the maximum length of the batch.
 
+    > [!WARNING]
+    > This class is deprecated and will be removed in version 0.27.0. Please use
+    `trl.trainer.reward_trainer.DataCollatorForPreference` instead.
+
     Args:
         tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used for encoding the data.
@@ -235,6 +240,14 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"
 
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
+            "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_chosen = []
         features_rejected = []
@@ -1241,6 +1254,10 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
     """
     Decodes the input tensor and strips the padding tokens.
 
+    > [!WARNING]
+    > This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy
+    > the code into your codebase and use it from there.
+
     Args:
         inputs (`torch.Tensor`):
             The input tensor to be decoded.
@@ -1251,6 +1268,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
         `list[str]`:
             The list of decoded strings with padding tokens stripped.
     """
+    warnings.warn(
+        "The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want "
+        "to keep using it, please copy the code into your codebase and use it from there.",
+        DeprecationWarning,
+    )
     decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
     return [d.replace(tokenizer.pad_token, "") for d in decoded]
 
@@ -1264,6 +1286,7 @@ def generate_model_card(
     wandb_url: Optional[str],
     trainer_name: str,
     trainer_citation: Optional[str] = None,
+    template_file: Optional[str] = None,
     paper_title: Optional[str] = None,
     paper_id: Optional[str] = None,
     comet_url: Optional[str] = None,
@@ -1290,6 +1313,8 @@ def generate_model_card(
             Trainer name.
         trainer_citation (`str` or `None`, defaults to `None`):
             Trainer citation as a BibTeX entry.
+        template_file (`str` *optional*):
+            Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`.
         paper_title (`str` or `None`, defaults to `None`):
             Paper title.
         paper_id (`str` or `None`, defaults to `None`):
@@ -1307,9 +1332,10 @@ def generate_model_card(
         model_name=model_name,
         tags=["generated_from_trainer", *tags],
     )
+    template_file = template_file or "lm_model_card.md"
     card = ModelCard.from_template(
         card_data,
-        template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")),
+        template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")),
         base_model=base_model,
         model_name=model_name,
         hub_model_id=hub_model_id,
@@ -1956,6 +1982,41 @@ def process_sequence(ids, mask):
     return torch.stack(truncated_seq), torch.stack(truncated_mask)
 
 
+TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
+
+
+def remove_none_values(example: TListOrMapping) -> TListOrMapping:
+    """
+    Recursively removes entries with `None` values from a nested structure (list or dictionary).
+
+    Args:
+        example (`list` or `Mapping`):
+            Input nested structure (list or dictionary) from which to remove `None`.
+
+    Example:
+    ```python
+    >>> [
+    ...     {
+    ...         "a": {"aa": None, "ab": 1},
+    ...         "b": "my_string",
+    ...     }
+    ... ]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")
+
+
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     """
     Create a model from a given path using the specified initialization arguments.
]({{ comet_url }}){% endif %}
+
+This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
+
+### Framework versions
+
+- TRL: {{ trl_version }}
+- Transformers: {{ transformers_version }}
+- Pytorch: {{ pytorch_version }}
+- Datasets: {{ datasets_version }}
+- Tokenizers: {{ tokenizers_version }}
+
+## Citations
+
+{% if trainer_citation %}Cite {{ trainer_name }} as:
+
+```bibtex
+{{ trainer_citation }}
+```{% endif %}
+
+Cite TRL as:
+    
+```bibtex
+{% raw %}@misc{vonwerra2022trl,
+	title        = {{TRL: Transformer Reinforcement Learning}},
+	author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
+	year         = 2020,
+	journal      = {GitHub repository},
+	publisher    = {GitHub},
+	howpublished = {\url{https://github.com/huggingface/trl}}
+}{% endraw %}
+```
diff --git a/trl/trainer/base_trainer.py b/trl/trainer/base_trainer.py
index e7cb05def71..bb88cbfc934 100644
--- a/trl/trainer/base_trainer.py
+++ b/trl/trainer/base_trainer.py
@@ -28,6 +28,7 @@ class BaseTrainer(Trainer):
     _tag_names = []
     _name = "Base"
     _paper = {}
+    _template_file = None
 
     def create_model_card(
         self,
@@ -78,6 +79,7 @@ def create_model_card(
             comet_url=get_comet_experiment_url(),
             trainer_name=self._name,
             trainer_citation=self._paper.get("citation"),
+            template_file=self._template_file,
             paper_title=self._paper.get("title"),
             paper_id=self._paper.get("id"),
         )
diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
index 9a3aabc39ee..33d248e635e 100644
--- a/trl/trainer/reward_config.py
+++ b/trl/trainer/reward_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Optional
 
 from transformers import TrainingArguments
 
@@ -32,22 +32,53 @@ class may differ from those in [`~transformers.TrainingArguments`].
     command line.
 
     Parameters:
-        max_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
-            limit. This argument is required if you want to use the default data collator.
+        > Parameters that control the model
+
+        model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+            argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
+            to include the load balancing/auxilliary loss as a part of the final loss, remember to set
+            `output_router_logits=True` in this dictionary.
+        chat_template_path (`str`, *optional*):
+            If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
+            or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
+            ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
+            embedding layer is resized accordingly.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model.
+
+        > Parameters that control the data preprocessing
+
         dataset_num_proc (`int`, *optional*):
             Number of processes to use for processing the dataset.
+        eos_token (`str`, *optional*):
+            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
+            `processing_class.eos_token`.
+        pad_token (`str`, *optional*):
+            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+            it falls back to `processing_class.eos_token`.
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
+            exceeds this value. If `None`, no filtering is applied.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+
+        > Parameters that control the training
+
         center_rewards_coefficient (`float`, *optional*):
             Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
             https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
-        remove_unused_columns (`bool`, *optional*, defaults to `False`):
-            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the
-            dataset is pretokenized.
+        activation_offloading (`bool`, *optional*, defaults to `False`):
+            Whether to offload the activations to the CPU.
     """
 
+    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
+
     # Parameters whose default values are overridden from TrainingArguments
+    learning_rate: float = field(
+        default=1e-4,
+        metadata={"help": "The initial learning rate for AdamW."},
+    )
     logging_steps: float = field(
         default=10,
         metadata={
@@ -70,21 +101,59 @@ class may differ from those in [`~transformers.TrainingArguments`].
         },
     )
 
-    max_length: Optional[int] = field(
-        default=1024,
+    # Parameters that control the model
+    model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
         metadata={
-            "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that "
-            "exceed the limit. This argument is required if you want to use the default data collator."
+            "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
+            "the `RewardTrainer` is provided as a string."
+        },
+    )
+    chat_template_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
+            "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
+            "you must ensure that any special tokens referenced in the template are added to the tokenizer and "
+            "that the model's embedding layer is resized accordingly."
         },
     )
     disable_dropout: bool = field(
         default=True,
-        metadata={"help": "Whether to disable dropout in the model and reference model."},
+        metadata={"help": "Whether to disable dropout in the model."},
     )
+
+    # Parameters that control the data preprocessing
     dataset_num_proc: Optional[int] = field(
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
+    eos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
+        },
+    )
+    pad_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
+            "is also `None`, it falls back to `processing_class.eos_token`."
+        },
+    )
+    max_length: Optional[int] = field(
+        default=1024,
+        metadata={
+            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
+            "the right. If `None`, no truncation is applied."
+        },
+    )
+    pad_to_multiple_of: Optional[int] = field(
+        default=None,
+        metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
+    )
+
+    # Parameters that control the training
     center_rewards_coefficient: Optional[float] = field(
         default=None,
         metadata={
@@ -92,15 +161,11 @@ class may differ from those in [`~transformers.TrainingArguments`].
             "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
         },
     )
-    remove_unused_columns: bool = field(
+    activation_offloading: bool = field(
         default=False,
-        metadata={
-            "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
-            "if the dataset is pretokenized."
-        },
+        metadata={"help": "Whether to offload the activations to the CPU."},
     )
 
     def __post_init__(self):
         self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
-
         super().__post_init__()
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 79570eff582..5408db49967 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -12,132 +12,348 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import logging
+import os
+import re
 from collections import defaultdict
-from dataclasses import FrozenInstanceError, replace
+from contextlib import contextmanager
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
 
-import pandas as pd
 import torch
 import torch.nn as nn
-from accelerate import PartialState, logging
-from accelerate.utils import gather_object
-from datasets import Dataset
+import transformers
+from accelerate import PartialState
+from accelerate.logging import get_logger
+from datasets import Dataset, IterableDataset
 from transformers import (
-    BaseImageProcessor,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     DataCollator,
-    FeatureExtractionMixin,
     PreTrainedModel,
     PreTrainedTokenizerBase,
-    ProcessorMixin,
 )
+from transformers.data.data_collator import DataCollatorMixin
 from transformers.trainer_callback import TrainerCallback
-from transformers.trainer_pt_utils import nested_detach
 from transformers.trainer_utils import EvalPrediction
-from transformers.utils import is_peft_available, is_rich_available
+from transformers.utils import is_peft_available
 
-from ..data_utils import maybe_apply_chat_template
-from ..models import prepare_peft_model
+from ..data_utils import is_conversational
+from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
 from .base_trainer import BaseTrainer
 from .reward_config import RewardConfig
-from .utils import (
-    RewardDataCollatorWithPadding,
-    compute_accuracy,
-    decode_and_strip_padding,
-    disable_dropout_in_model,
-    log_table_to_comet_experiment,
-    print_rich_table,
-)
+from .utils import disable_dropout_in_model, pad, remove_none_values
 
 
 if is_peft_available():
-    from peft import PeftModel
+    from peft import PeftConfig, PeftModel
+
+
+logger = get_logger(__name__)
 
 
-logger = logging.get_logger(__name__)
+# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
+# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
+# avoid confusing users.
+@contextmanager
+def suppress_from_pretrained_warning(logger: logging.Logger):
+    pattern = re.compile(
+        r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
+        r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
+        r"inference\.$"
+    )
 
+    class _Filter(logging.Filter):
+        def filter(self, record: logging.LogRecord) -> bool:
+            return not pattern.search(record.getMessage())
 
-def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]:
-    """Tokenize a batch from a reward modelling dataset."""
-    new_examples = {
-        "input_ids_chosen": [],
-        "attention_mask_chosen": [],
-        "input_ids_rejected": [],
-        "attention_mask_rejected": [],
-    }
-    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
-        tokenized_chosen = tokenizer(chosen)
-        tokenized_rejected = tokenizer(rejected)
-        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
-        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
-        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
-        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
+    f = _Filter()
+    logger.addFilter(f)
+    try:
+        yield
+    finally:
+        logger.removeFilter(f)
+
+
+@dataclass
+class DataCollatorForPreference(DataCollatorMixin):
+    """
+    Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
 
-    return new_examples
+    This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and
+    `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys:
+    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
+        corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`.
+    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
+
+    Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a
+    `"margin"` key with a tensor of margins.
+
+    Args:
+        pad_token_id (`int`):
+            Token ID to use for padding.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            Type of Tensor to return. Only `"pt"` is currently supported.
+
+    Examples:
+    ```python
+    >>> from trl.trainer.reward_trainer import DataCollatorForPreference
+
+    >>> collator = DataCollatorForPreference(pad_token_id=0)
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]])}
+
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]]),
+     'margin': tensor([0.5, 0.0])}
+    ```
+    """
+
+    pad_token_id: int
+    pad_to_multiple_of: Optional[int] = None
+    return_tensors: str = "pt"
+
+    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
+        # Convert to tensor
+        chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples]
+        rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples]
+        if "margin" in examples[0]:
+            margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float)
+        input_ids = chosen_input_ids + rejected_input_ids
+        attention_mask = [torch.ones_like(ids) for ids in input_ids]
+
+        output = {}
+
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        output["attention_mask"] = pad(
+            attention_mask,
+            padding_value=0,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if "margin" in examples[0]:
+            output["margin"] = margins
+        return output
 
 
 class RewardTrainer(BaseTrainer):
     """
-    Trainer for custom reward.
+    Trainer for Outcome-supervised Reward Models (ORM).
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+
+    Example:
+
+    ```python
+    from trl import RewardTrainer
+    from datasets import load_dataset
+
+    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+    trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
+    trainer.train()
+    ```
 
     Args:
-        model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*):
-            Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`].
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
+              `args.model_init_kwargs`.
+            - A sequence classification [`~transformers.PreTrainedModel`] object.
         args ([`RewardConfig`], *optional*):
-            Training arguments.
+            Configuration for this trainer. If `None`, a default configuration is used.
         data_collator ([`~transformers.DataCollator`], *optional*):
-            The data collator to use for training. If None is specified, the default data collator
-            [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum
-            length of the sequences in the batch, given a dataset of paired sequences.
-        train_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for training.
-        eval_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for evaluation.
-        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
-            Processing class used to process the data. If provided, will be used to automatically process the inputs
-            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-            reuse the fine-tuned model.
-        model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*):
-            The model initializer to use for training. If None is specified, the default model initializer will be
-            used.
-        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]):
-            Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
-            dictionary string to float.
-        callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
-            Callbacks to use during training.
-        optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
-            Tuple containing the optimizer and the learning rate scheduler to use for training.
+            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
+            Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
+            explicit prompt). The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
+            `rejected_input_ids` fields.
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
+            [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
+            set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
+            default.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a
+            [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
+            [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
+            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
+            function needs to calculate and return the global summary statistics rather than accumulating the
+            batch-level statistics.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
+            `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
+            initializing the Trainer.
         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
-            Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
-            return the logits to be used for metrics computation.
-        peft_config (`dict`, *optional*):
-            PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
-            wrapped with the specified PEFT adapter.
+            A function that preprocess the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+        peft_config ([`~peft.PeftConfig`], *optional*):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
+            model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
+            to ensure that the reward head is properly trained.
     """
 
     _tag_names = ["trl", "reward-trainer"]
     _name = "Reward"
+    _template_file = "rm_model_card.md"
 
     def __init__(
         self,
-        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+        model: Union[str, PreTrainedModel],
         args: Optional[RewardConfig] = None,
         data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Dataset] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-        processing_class: Optional[
-            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-        ] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
         callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-            None,
-            None,
-        ),
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-        peft_config: Optional[dict] = None,
+        peft_config: Optional["PeftConfig"] = None,
     ):
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = RewardConfig(f"{model_name}-Reward")
+
+        # Model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            dtype = model_init_kwargs.get("dtype")
+            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
+                pass  # dtype is already a torch.dtype or "auto" or None
+            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
+                model_init_kwargs["dtype"] = getattr(torch, dtype)
+            else:
+                raise ValueError(
+                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
+                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
+                )
+            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
+                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                logger.warning(
+                    "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
+                    "The `model_init_kwargs` will be ignored."
+                )
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(model_id)
+
+        # Handle pad token for processors or tokenizers
+        if args.eos_token is not None:
+            eos_token = args.eos_token
+            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
+            if eos_token_id is None:
+                raise ValueError(
+                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                    "in the vocabulary before using it as an EOS token."
+                )
+            processing_class.eos_token_id = eos_token_id
+
+        if args.chat_template_path is not None:
+            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                    processing_class.chat_template = chat_template_file.read()
+                added_tokens = []
+            else:
+                model, processing_class, added_tokens = clone_chat_template(
+                    model, processing_class, args.chat_template_path
+                )
+        else:
+            added_tokens = []
+
+        # PEFT configuration and model wrapping
+        if peft_config is not None:
+            if added_tokens:
+                # Ensure that the added tokens are trainable
+                if peft_config.trainable_token_indices is None:
+                    peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
+                elif "embed_tokens" not in peft_config.trainable_token_indices:
+                    peft_config.trainable_token_indices["embed_tokens"] = added_tokens
+                else:
+                    peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
+
+                # Ensure that the lm_head is trainable
+                if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
+                    logger.warning(
+                        "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
+                        "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
+                        "tokens, leading to degraded generation quality. To fix this, add "
+                        "`modules_to_save=['lm_head']` to your PEFT configuration."
+                    )
+
+                    if peft_config.modules_to_save is None:
+                        peft_config.modules_to_save = ["lm_head"]
+                    else:
+                        peft_config.modules_to_save.append("lm_head")
+
         if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
             model = prepare_peft_model(model, peft_config, args)
 
@@ -145,78 +361,47 @@ def __init__(
         if args.disable_dropout:
             disable_dropout_in_model(model)
 
-        if compute_metrics is None:
-            compute_metrics = compute_accuracy
+        # Pad token (needed for SequenceClassification models)
+        # If not provided, use the one from the processing class or the eos token if the processing class does not have
+        # a pad token.
+        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
+        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+        if pad_token_id is None:
+            raise ValueError(
+                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                "in the vocabulary before using it as a padding token."
+            )
+        model.config.pad_token_id = pad_token_id
+        processing_class.pad_token_id = pad_token_id
 
+        # Data collator
         if data_collator is None:
-            if processing_class is None:
-                raise ValueError(
-                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
-                )
+            data_collator = DataCollatorForPreference(
+                pad_token_id=pad_token_id,
+                pad_to_multiple_of=args.pad_to_multiple_of,
+            )
 
-            max_length = args.max_length
+        # Dataset
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
 
-            data_collator = RewardDataCollatorWithPadding(processing_class)
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
 
-            if args.remove_unused_columns:
-                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
-                    args.remove_unused_columns = False
-                except FrozenInstanceError:
-                    args = replace(args, remove_unused_columns=False)
-                # warn users
-                logger.warning(
-                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
-                    " we have set it for you, but you should do it yourself in the future.",
-                )
-
-            self.use_reward_data_collator = True
-        else:
-            self.use_reward_data_collator = False
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-        # issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        if "input_ids_chosen" not in train_dataset.column_names:
-            with PartialState().main_process_first():
-                fn_kwargs = {"tokenizer": processing_class}
-                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                train_dataset = train_dataset.map(
-                    _tokenize,
-                    batched=True,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=args.dataset_num_proc,
-                )
-                # This filter is important because otherwise you get samples that exceed the model's context length and
-                # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                # user might get surprised if N samples are missing from training.
-                train_dataset = train_dataset.filter(
-                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                    num_proc=args.dataset_num_proc,
-                )
-                if eval_dataset is not None:
-                    eval_dataset = eval_dataset.map(
-                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                    )
-                    eval_dataset = eval_dataset.map(
-                        _tokenize,
-                        fn_kwargs=fn_kwargs,
-                        batched=True,
-                        num_proc=args.dataset_num_proc,
-                    )
-                    # This filter is important because otherwise you get samples that exceed the model's context length and
-                    # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                    # user might get surprised if N samples are missing from training.
-                    eval_dataset = eval_dataset.filter(
-                        lambda x: len(x["input_ids_chosen"]) <= max_length
-                        and len(x["input_ids_rejected"]) <= max_length,
-                        num_proc=args.dataset_num_proc,
-                    )
+        # Initialize the Trainer. Parent class will handle:
+        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
+        # - FSDP setup
+        # - Distributed training setup
+        # - Optimizer and scheduler creation
 
         super().__init__(
             model=model,
@@ -225,35 +410,140 @@ def __init__(
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
-            model_init=model_init,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
             optimizers=optimizers,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
+        self.can_return_loss = True
+        self.label_names = []
+
+        # Initialize activation offloading context
+        if self.args.activation_offloading:
+            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
+        else:
+            self.maybe_activation_offload_context = contextlib.nullcontext()
+
         # Add tags for models that have been loaded with the correct transformers version
         if hasattr(self.model, "add_model_tags"):
             self.model.add_model_tags(self._tag_names)
 
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: PreTrainedTokenizerBase,
+        args: RewardConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
+        # sampled data.
+        if isinstance(dataset, Dataset):  # IterableDataset does not support `with_transform`
+            dataset = dataset.with_transform(remove_none_values)
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            if not is_processed:
+                # Add EOS token to the end of the sequences if needed
+                first_example = next(iter(dataset))
+                if not is_conversational(first_example):
+                    if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                        map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
+
+                    def add_eos(example, eos_token):
+                        if not example["chosen"].endswith(eos_token):
+                            example["chosen"] = example["chosen"] + eos_token
+                        if "rejected" in example and not example["rejected"].endswith(eos_token):
+                            example["rejected"] = example["rejected"] + eos_token
+                        return example
+
+                    dataset = dataset.map(
+                        add_eos,
+                        fn_kwargs={"eos_token": processing_class.eos_token},
+                        **map_kwargs,
+                    )
+
+                # Tokenize the dataset
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize_fn(example, processing_class):
+                    if "prompt" in example:  # explicit prompt case
+                        example["chosen"] = example["prompt"] + example["chosen"]
+                        example["rejected"] = example["prompt"] + example["rejected"]
+
+                    if is_conversational(example):
+                        chosen_input_ids = processing_class.apply_chat_template(
+                            example["chosen"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        rejected_input_ids = processing_class.apply_chat_template(
+                            example["rejected"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
+                    else:
+                        output = {
+                            "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
+                            "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
+                        }
+                    return output
+
+                dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
+
+            # Filter samples that are longer than `max_length`
+            if args.max_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
+                dataset = dataset.filter(
+                    lambda example: len(example["chosen_input_ids"]) <= args.max_length
+                    and len(example["rejected_input_ids"]) <= args.max_length,
+                    **map_kwargs,
+                )
+
+        return dataset
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+        # and "attention_mask").
+        if self._signature_columns is None:
+            self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
+
     def compute_loss(
         self,
-        model: Union[PreTrainedModel, nn.Module],
+        model: nn.Module,
         inputs: dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-        num_items_in_batch=None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        rewards_chosen = model(
-            input_ids=inputs["input_ids_chosen"],
-            attention_mask=inputs["attention_mask_chosen"],
-            return_dict=True,
-        )["logits"]
-        rewards_rejected = model(
-            input_ids=inputs["input_ids_rejected"],
-            attention_mask=inputs["attention_mask_rejected"],
-            return_dict=True,
-        )["logits"]
-        # calculate loss, optionally modulate with margin
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        Compute training loss and additionally compute token accuracies
+        """
+        mode = "train" if self.model.training else "eval"
+
+        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
+        inputs["use_cache"] = False
+        outputs = model(**inputs)
+
+        # Split the rewards into chosen and rejected
+        rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
+
+        # Calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
         else:
@@ -262,86 +552,45 @@ def compute_loss(
         if self.args.center_rewards_coefficient is not None:
             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
 
-        if return_outputs:
-            return loss, {
-                "rewards_chosen": rewards_chosen,
-                "rewards_rejected": rewards_rejected,
-            }
-        return loss
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[list[str]] = None,
-    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        inputs = self._prepare_inputs(inputs)
-        if ignore_keys is None:
-            if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
+        if mode == "train":
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+            self._total_train_tokens += num_tokens_in_batch
+        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
+        # Compute min, mean, max, accuracy and margin
         with torch.no_grad():
-            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-        if prediction_loss_only:
-            return (loss, None, None)
-
-        loss = loss.detach()
-        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = nested_detach(logits)
-        # Stack accepted against rejected, mean over logits
-        # and softmax to get preferences between accepted and rejected to sum to 1
-        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-        labels = torch.zeros(logits.shape[0])
-        labels = self._prepare_inputs(labels)
-
-        return loss, logits, labels
-
-    def evaluate(self, *args, **kwargs):
-        num_print_samples = kwargs.pop("num_print_samples", 4)
-        self.visualize_samples(num_print_samples)
-        return super().evaluate(*args, **kwargs)
-
-    def visualize_samples(self, num_print_samples: int):
-        """
-        Visualize the reward model logits prediction
-
-        Args:
-            num_print_samples (`int`, defaults to `4`):
-                The number of samples to print. Set to `-1` to print all samples.
-        """
-        eval_dataloader = self.get_eval_dataloader()
-        table = defaultdict(list)
-        for _, inputs in enumerate(eval_dataloader):
-            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-            table["chosen_text"].extend(gather_object(chosen_text))
-            table["rejected_text"].extend(gather_object(rejected_text))
-            table["logits"].extend(
-                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-            )
-            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                break
-        df = pd.DataFrame(table)
-        if self.accelerator.process_index == 0:
-            if is_rich_available():
-                print_rich_table(df[:num_print_samples])
-            if "wandb" in self.args.report_to:
-                import wandb
-
-                if wandb.run is not None:
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
-
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="completions.csv",
-                    table=df,
-                )
+            all_rewards = self.accelerator.gather(outputs.logits)
+            self._metrics[mode]["min_reward"].append(all_rewards.min().item())
+            self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
+            self._metrics[mode]["max_reward"].append(all_rewards.max().item())
+
+            mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
+            mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
+            self._metrics[mode]["accuracy"].append(mean_accuracy)
+
+            mean_margin = (rewards_chosen - rewards_rejected).mean()
+            mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
+            self._metrics[mode]["margin"].append(mean_margin.item())
+
+        return (loss, outputs) if return_outputs else loss
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
 
     # Ensure the model card is saved along with the checkpoint
     def _save_checkpoint(self, model, trial):
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 6f5db07ba50..c0f26bc9373 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -15,10 +15,9 @@
 import contextlib
 import os
 from collections import defaultdict
-from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -56,6 +55,7 @@
     entropy_from_logits,
     flush_left,
     pad,
+    remove_none_values,
     selective_log_softmax,
 )
 
@@ -66,7 +66,6 @@
 
 logger = logging.get_logger(__name__)
 
-TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
 
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
@@ -77,38 +76,6 @@
 }
 
 
-def remove_none_values(example: TListOrMapping) -> TListOrMapping:
-    """
-    Recursively removes entries with `None` values from a nested structure (list or dictionary).
-
-    Args:
-        example (`list` or `Mapping`):
-            Input nested structure (list or dictionary) from which to remove `None`.
-
-    Example:
-    ```python
-    >>> [
-    ...     {
-    ...         "a": {"aa": None, "ab": 1},
-    ...         "b": "my_string",
-    ...     }
-    ... ]
-    >>> remove_none_values(example)
-    [{'a': {'ab': 1}, 'b': 'my_string'}]
-    ```
-    """
-    if isinstance(example, list):
-        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
-    elif isinstance(example, Mapping):
-        return {
-            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
-            for key, value in example.items()
-            if value is not None
-        }
-    else:
-        raise TypeError("Input must be a list or a dictionary.")
-
-
 def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
 
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 1ef75bd0da2..a12fdec7b44 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -18,11 +18,12 @@
 import os
 import random
 import socket
-from collections.abc import Sequence, Sized
+import warnings
+from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
 from itertools import accumulate
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeVar, Union
 
 import numpy as np
 import pandas as pd
@@ -219,6 +220,10 @@ class RewardDataCollatorWithPadding:
     r"""
     Reward DataCollator class that pads the inputs to the maximum length of the batch.
 
+    > [!WARNING]
+    > This class is deprecated and will be removed in version 0.27.0. Please use
+    `trl.trainer.reward_trainer.DataCollatorForPreference` instead.
+
     Args:
         tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used for encoding the data.
@@ -235,6 +240,14 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"
 
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
+            "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_chosen = []
         features_rejected = []
@@ -1241,6 +1254,10 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
     """
     Decodes the input tensor and strips the padding tokens.
 
+    > [!WARNING]
+    > This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy
+    > the code into your codebase and use it from there.
+
     Args:
         inputs (`torch.Tensor`):
             The input tensor to be decoded.
@@ -1251,6 +1268,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
         `list[str]`:
             The list of decoded strings with padding tokens stripped.
     """
+    warnings.warn(
+        "The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want "
+        "to keep using it, please copy the code into your codebase and use it from there.",
+        DeprecationWarning,
+    )
     decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
     return [d.replace(tokenizer.pad_token, "") for d in decoded]
 
@@ -1264,6 +1286,7 @@ def generate_model_card(
     wandb_url: Optional[str],
     trainer_name: str,
     trainer_citation: Optional[str] = None,
+    template_file: Optional[str] = None,
     paper_title: Optional[str] = None,
     paper_id: Optional[str] = None,
     comet_url: Optional[str] = None,
@@ -1290,6 +1313,8 @@ def generate_model_card(
             Trainer name.
         trainer_citation (`str` or `None`, defaults to `None`):
             Trainer citation as a BibTeX entry.
+        template_file (`str` *optional*):
+            Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`.
         paper_title (`str` or `None`, defaults to `None`):
             Paper title.
         paper_id (`str` or `None`, defaults to `None`):
@@ -1307,9 +1332,10 @@ def generate_model_card(
         model_name=model_name,
         tags=["generated_from_trainer", *tags],
     )
+    template_file = template_file or "lm_model_card.md"
     card = ModelCard.from_template(
         card_data,
-        template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")),
+        template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")),
         base_model=base_model,
         model_name=model_name,
         hub_model_id=hub_model_id,
@@ -1956,6 +1982,41 @@ def process_sequence(ids, mask):
     return torch.stack(truncated_seq), torch.stack(truncated_mask)
 
 
+TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
+
+
+def remove_none_values(example: TListOrMapping) -> TListOrMapping:
+    """
+    Recursively removes entries with `None` values from a nested structure (list or dictionary).
+
+    Args:
+        example (`list` or `Mapping`):
+            Input nested structure (list or dictionary) from which to remove `None`.
+
+    Example:
+    ```python
+    >>> [
+    ...     {
+    ...         "a": {"aa": None, "ab": 1},
+    ...         "b": "my_string",
+    ...     }
+    ... ]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")
+
+
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     """
     Create a model from a given path using the specified initialization arguments.