From 562f6437d0afccfa3a7f5c46d71e45147a7b9a65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 20:28:07 +0000 Subject: [PATCH 01/62] from sft --- trl/trainer/reward_config.py | 218 +++++- trl/trainer/reward_trainer.py | 1320 +++++++++++++++++++++++++++------ 2 files changed, 1267 insertions(+), 271 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 9a3aabc39ee..a5b9f2c4267 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -13,41 +13,99 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional from transformers import TrainingArguments @dataclass -class RewardConfig(TrainingArguments): +class SFTConfig(TrainingArguments): r""" - Configuration class for the [`RewardTrainer`]. + Configuration class for the [`SFTTrainer`]. - This class includes only the parameters that are specific to Reward training. For a full list of training - arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this - class may differ from those in [`~transformers.TrainingArguments`]. + This class includes only the parameters that are specific to SFT training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. Using [`~transformers.HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line. Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the - limit. This argument is required if you want to use the default data collator. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to + include the load balancing/auxilliary loss as a part of the final loss, remember to set + `output_router_logits=True` in this dictionary. + chat_template_path (`str`, *optional*): + If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory + or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must + ensure that any special tokens referenced in the template are added to the tokenizer and that the model's + embedding layer is resized accordingly. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. + dataset_kwargs (`dict[str, Any]`, *optional*): + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` + regardless of the provided value, since preprocessing is done on the fly. dataset_num_proc (`int`, *optional*): Number of processes to use for processing the dataset. 
- center_rewards_coefficient (`float`, *optional*): - Coefficient to incentivize the reward model to output mean-zero rewards (proposed by - https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the - dataset is pretokenized. + eos_token (`str`, *optional*): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to + `processing_class.eos_token`. + pad_token (`int`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + packing (`bool`, *optional*, defaults to `False`): + Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce + padding. Uses `max_length` to define sequence length. + packing_strategy (`str`, *optional*, defaults to `"bfd"`): + Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When + packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this + parameter. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + eval_packing (`bool`, *optional*): + Whether to pack the eval dataset. If `None`, uses the same value as `packing`. + + > Parameters that control the training + + completion_only_loss (`bool`, *optional*): + Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed + only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If + `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: + loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full + sequence for [language modeling](#language-modeling) datasets. + assistant_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only + on the assistant responses, which is supported only for [conversational](#conversational) datasets. If + `False`, loss is computed on the entire sequence. + loss_type (`str`, *optional*, defaults to `"nll"`): + Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic + Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. 
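+
+    Example:
+
+        A minimal, illustrative configuration sketch; the values below are arbitrary examples rather than
+        recommended defaults, and `"my-sft-model"` is just a placeholder output directory:
+
+    ```python
+    from trl import SFTConfig
+
+    training_args = SFTConfig(
+        output_dir="my-sft-model",      # where checkpoints are written (placeholder)
+        max_length=1024,                # truncate (or, with packing, fill) sequences to 1024 tokens
+        packing=True,                   # group sequences into fixed-length blocks with the default "bfd" strategy
+        completion_only_loss=None,      # let the trainer infer it from the dataset type
+        loss_type="nll",                # standard negative log-likelihood loss
+    )
+    ```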
""" + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) logging_steps: float = field( default=10, metadata={ @@ -70,37 +128,135 @@ class may differ from those in [`~transformers.TrainingArguments`]. }, ) - max_length: Optional[int] = field( - default=1024, + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " - "exceed the limit. This argument is required if you want to use the default data collator." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string. If you're training a MoE architecture and want to include the " + "load balancing/auxilliary loss as a part of the final loss, remember to set `output_router_logits=True` " + "in this dictionary." }, ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, + chat_template_path: Optional[str] = field( + default=None, + metadata={ + "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local " + "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, " + "you must ensure that any special tokens referenced in the template are added to the tokenizer and " + "that the model's embedding layer is resized accordingly." + }, + ) + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, + ) + dataset_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`. If the model is a VLM, `skip_prepare_dataset` value is ignored. When the model " + "is a VLM, `skip_prepare_dataset` is automatically treated as `True` regardless of the provided value, " + "since preprocessing is done on the fly." + }, ) dataset_num_proc: Optional[int] = field( default=None, metadata={"help": "Number of processes to use for processing the dataset."}, ) - center_rewards_coefficient: Optional[float] = field( + eos_token: Optional[str] = field( default=None, metadata={ - "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " - "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." }, ) - remove_unused_columns: bool = field( + pad_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." 
+ }, + ) + packing: bool = field( + default=False, + metadata={ + "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency " + "and reduce padding. Uses `max_length` to define sequence length." + }, + ) + packing_strategy: str = field( + default="bfd", + metadata={ + "help": "Strategy for packing sequences. Can be either `'bfd'` (best-fit decreasing, default), or " + "`'wrapped'`." + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " + "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " + "structure. When packing is enabled with strategy `'bfd'`, padding-free is enabled, regardless of the " + "value of this parameter." + }, + ) + pad_to_multiple_of: Optional[int] = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + eval_packing: Optional[bool] = field( + default=None, + metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + completion_only_loss: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is " + "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, " + "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: " + "loss is computed on the completion for prompt-completion datasets, and on the full sequence for " + "language modeling datasets." + ) + }, + ) + assistant_only_loss: bool = field( default=False, metadata={ - "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " - "if the dataset is pretokenized." + "help": ( + "Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is " + "computed only on the assistant responses, which is supported only for conversational datasets. If `False`, " + "loss is computed on the entire sequence." + ) + }, + ) + loss_type: str = field( + default="nll", + metadata={ + "help": ( + 'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` ' + "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)." + ) }, ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - super().__post_init__() diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 804e4d8bf3f..bdaeedfdd2f 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -12,19 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import os from collections import defaultdict -from dataclasses import FrozenInstanceError, replace +from collections.abc import Mapping +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union -import pandas as pd import torch import torch.nn as nn +import transformers from accelerate import PartialState, logging -from accelerate.utils import gather_object -from datasets import Dataset +from datasets import Dataset, IterableDataset from transformers import ( + AutoConfig, + AutoProcessor, BaseImageProcessor, DataCollator, FeatureExtractionMixin, @@ -32,198 +35,837 @@ PreTrainedTokenizerBase, ProcessorMixin, Trainer, + TrainingArguments, is_wandb_available, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_callback import TrainerCallback -from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction -from transformers.utils import is_peft_available, is_rich_available +from transformers.utils import is_peft_available -from ..data_utils import maybe_apply_chat_template -from ..models import prepare_peft_model -from .reward_config import RewardConfig +from ..data_utils import ( + apply_chat_template, + is_conversational, + is_conversational_from_value, + maybe_convert_to_chatml, + pack_dataset, + prepare_multimodal_messages, + truncate_dataset, +) +from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model +from .sft_config import SFTConfig from .utils import ( - RewardDataCollatorWithPadding, - compute_accuracy, - decode_and_strip_padding, - disable_dropout_in_model, + entropy_from_logits, + flush_left, generate_model_card, get_comet_experiment_url, - log_table_to_comet_experiment, - print_rich_table, + pad, + selective_log_softmax, ) if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel if is_wandb_available(): import wandb - logger = logging.get_logger(__name__) +TListOrMapping = TypeVar("TListOrMapping", list, Mapping) + -def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]: - """Tokenize a batch from a reward modelling dataset.""" - new_examples = { - "input_ids_chosen": [], - "attention_mask_chosen": [], - "input_ids_rejected": [], - "attention_mask_rejected": [], - } - for chosen, rejected in zip(batch["chosen"], batch["rejected"]): - tokenized_chosen = tokenizer(chosen) - tokenized_rejected = tokenizer(rejected) - new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) - new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) - new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) - new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) +def remove_none_values(example: TListOrMapping) -> TListOrMapping: + """ + Recursively removes entries with `None` values from a nested structure (list or dictionary). + + Args: + example (`list` or `Mapping`): + Input nested structure (list or dictionary) from which to remove `None`. - return new_examples + Example: + ```python + >>> [ + ... { + ... "a": {"aa": None, "ab": 1}, + ... "b": "my_string", + ... } + ... 
]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")


-class RewardTrainer(Trainer):
+@dataclass
+class DataCollatorForLanguageModeling(DataCollatorMixin):
     """
-    Trainer for custom reward.
+    Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch.
+
+    This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key.
+    If the input contains a `"completion_mask"`, it is used to set the labels to `-100` for tokens that are not in the
+    completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not
+    in the assistant part of the sequence. The collator returns a dictionary containing the following keys:
+    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch.
+    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
+    - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.
+    - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to
+      `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are
+      not in the assistant part of the sequence are set to -100.

     Args:
-        model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*):
-            Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`].
-        args ([`RewardConfig`], *optional*):
-            Training arguments.
+        pad_token_id (`int`):
+            Token ID to use for padding.
+        completion_only_loss (`bool`, *optional*, defaults to `True`):
+            When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens
+            that are not in the completion.
+        padding_free (`bool`, *optional*, defaults to `False`):
+            If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
+            generated accordingly.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            Type of Tensor to return. Only `"pt"` is currently supported.
+
+    Examples:
+    ```python
+    >>> from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
+
+    >>> collator = DataCollatorForLanguageModeling(pad_token_id=0)
+    >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+    >>> collator(examples)
+    {'input_ids': tensor([[ 1, 2, 3],
+                          [ 4, 5, 0]]),
+     'attention_mask': tensor([[ 1, 1, 1],
+                               [ 1, 1, 0]]),
+     'position_ids': tensor([[0, 1, 2],
+                             [0, 1, 0]]),
+     'labels': tensor([[ 1, 2, 3],
+                       [ 4, 5, -100]])}
+
+    >>> # With completion mask
+    >>> examples = [
+    ...     {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
+    ...     {"input_ids": [4, 5], "completion_mask": [0, 1]},
+    ...
] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[-100, 2, 3], + [-100, 5, -100]])} + + >>> # With padding_free + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1]]), + 'position_ids': tensor([[0, 1, 2, 0, 1]]), + 'labels': tensor([[1, 2, 3, 4, 5]])} + ``` + """ + + pad_token_id: int + completion_only_loss: bool = True + padding_free: bool = False + return_position_ids: bool = True + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + # Convert to tensor + input_ids = [torch.tensor(example["input_ids"]) for example in examples] + + # Check if we have meaningful seq_lengths from packing (restarting sequences) + has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free + + # For packing with position_ids, we should NOT create attention_mask as it causes + # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask + if not has_packed_position_ids: + attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] + + if self.return_position_ids: + if "seq_lengths" in examples[0]: + position_ids = self.get_position_ids_from_packed_seq_lengths( + [example["seq_lengths"] for example in examples] + ) + else: + position_ids = [torch.arange(len(ids)) for ids in input_ids] + if "labels" in examples[0]: + labels = [torch.tensor(example["labels"]) for example in examples] + else: + labels = [torch.tensor(example["input_ids"]) for example in examples] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples] + + # If padding_free, flatten everything into a single sequence + output = {} + if self.padding_free: + input_ids = [torch.cat(input_ids, dim=0)] + if not has_packed_position_ids: + attention_mask = [torch.cat(attention_mask, dim=0)] + if self.return_position_ids: + position_ids = [torch.cat(position_ids, dim=0)] + labels = [torch.cat(labels, dim=0)] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.cat(completion_mask, dim=0)] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.cat(assistant_masks, dim=0)] + + # Pad + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + if not has_packed_position_ids: + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.return_position_ids: + output["position_ids"] = pad( + position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"] = pad( + labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + 
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion + if "assistant_masks" in examples[0]: + assistant_masks = pad( + assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][assistant_masks == 0] = -100 + return output + + @staticmethod + def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]: + """ + Get position IDs for packed sequences. + + Args: + batch_seq_lengths (`list[list[int]]`): + A list of lists containing the lengths of each individual document in the packed batch. + + Return: + `list[torch.Tensor]`: + A list of tensors containing the position IDs for each packed sequence. + """ + # Get lengths per row + example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths] + # Flat list of lengths + batch_seq_lengths = torch.tensor( + [seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths] + ) + position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype) + position_ids[0] = 0 + # Reset position ids to 0 at the start of each sequence + position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1) + position_ids = position_ids.cumsum(0) + # Split back into one tensor per example + return list(position_ids.split(example_lengths)) + + +@dataclass +class DataCollatorForVisionLanguageModeling(DataCollatorMixin): + """ + Data collator for vision-language modeling tasks. + + Unlike text-only datasets—where the collator typically receives pre-tokenized inputs ready for batching, + vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, + making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and + image processing on-the-fly to efficiently prepare batches. + + Each input example should be a dictionary containing at least: + - An `"images"` key holding the image data. + - [language modeling](#language-modeling) type: either a `"messages"` key for conversational inputs or a `"text"` + key for standard text inputs. + - [prompt-completion](#prompt-completion) type: keys `"prompt"` and `"completion"` for the prompt and completion. + + The collator outputs a dictionary including: + - `"input_ids"`: Tensor of token IDs. + - `"attention_mask"`: Tensor indicating attention mask. + - `"pixel_values"`: Tensor representing image pixel values. + - `"labels"`: Tensor for training labels. + + Additional keys may be present depending on the processor, such as `"image_grid_thw"`. + + Args: + processor (`ProcessorMixin`): + The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` and + include a `tokenizer` with a defined `pad_token_id`. + max_length (`int` or `None`, optional, defaults to `None`): + Maximum sequence length for input tokens. If `None`, no truncation is applied. + completion_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the completion part of the sequence. When `True`, the labels for the prompt + part are set to -100. It requires the dataset type to be prompt-completion. + pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + dataset_text_field (`str`, optional, defaults to `"text"`): + Name of the column that contains text data in the dataset. 
This parameter is only relevant for [standard + datasets format](dataset_formats#standard). + return_tensors (`str`, optional, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + + Example: + ```python + >>> from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling + >>> from transformers import AutoProcessor + + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> collator = DataCollatorForVisionLanguageModeling(processor) + >>> examples = [ + ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, + ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], + [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], + [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], + ..., + [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], + [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], + [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), + 'image_grid_thw': tensor([[1, 4, 4], + [1, 4, 4]]), + 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]])} + ``` + """ + + processor: ProcessorMixin + max_length: Optional[int] = None + completion_only_loss: bool = False # default not used in practice; SFTTrainer always passes the relevant value + pad_to_multiple_of: Optional[int] = None + dataset_text_field: str = "text" + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + if "messages" in examples[0] or self.dataset_text_field in examples[0]: + if self.completion_only_loss: + raise ValueError( + "The `completion_only_loss` argument is not supported for language modeling datasets." 
+                )
+            return self._collate_language_modeling(examples)
+        elif "prompt" in examples[0] and "completion" in examples[0]:
+            return self._collate_prompt_completion(examples)
+        else:
+            raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")
+
+    def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
+        images = [example["images"] for example in examples]
+
+        if "messages" in examples[0]:  # conversational case
+            for example in examples:
+                prepare_multimodal_messages(example["messages"], len(example["images"]))
+            messages = [example["messages"] for example in examples]
+            texts = self.processor.apply_chat_template(messages)
+        elif self.dataset_text_field in examples[0]:  # standard case
+            texts = [example[self.dataset_text_field] for example in examples]
+        else:
+            raise KeyError(
+                "The input examples must contain either 'messages' for conversational data or 'text' for standard "
+                "data."
+            )
+
+        output = self.processor(
+            images=images,
+            text=texts,
+            padding=True,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            truncation=self.max_length is not None,
+            max_length=self.max_length,
+            return_tensors=self.return_tensors,
+            add_special_tokens=False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+        )
+        labels = output["input_ids"].clone()
+        labels[output["attention_mask"] == 0] = -100
+        # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
+        # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision
+        # token definitions vary across architectures.
+        output["labels"] = labels
+        return output
+
+    def _collate_prompt_completion(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
+        if self.pad_to_multiple_of is not None:
+            raise NotImplementedError(
+                "Padding to a multiple of a value is not yet implemented for vision-language modeling and "
+                "prompt-completion data."
+            )
+        images = [example["images"] for example in examples]
+        if is_conversational(examples[0]):  # conversational case
+            for example in examples:
+                prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
+            examples = [apply_chat_template(example, self.processor) for example in examples]
+
+        prompts = [example["prompt"] for example in examples]
+        completions = [example["completion"] for example in examples]
+
+        processed_prompts = self.processor(
+            images=images,
+            text=prompts,
+            padding=True,
+            padding_side="left",
+            return_tensors=self.return_tensors,
+            add_special_tokens=False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+        )
+        processed_completions = self.processor(
+            text=completions,
+            padding=True,
+            padding_side="right",
+            return_tensors=self.return_tensors,
+            add_special_tokens=False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+        )
+
+        # Concatenate prompts and completions
+        prompt_ids, completion_ids = processed_prompts["input_ids"], processed_completions["input_ids"]
+        prompt_mask, completion_mask = processed_prompts["attention_mask"], processed_completions["attention_mask"]
+        input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
+        attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
+        completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+
+        # Flush left to reduce padding
+        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+
+        # Truncate if necessary
+        if self.max_length is not None:
+            input_ids = input_ids[:, : self.max_length]
+            attention_mask = attention_mask[:, : self.max_length]
+            completion_mask = completion_mask[:, : self.max_length]
+
+        # Create labels and mask padding tokens
+        labels = input_ids.clone()
+        labels[attention_mask == 0] = -100
+        if self.completion_only_loss:
+            labels[completion_mask == 0] = -100
+
+        # Build the output dictionary
+        output = processed_prompts  # we take processed_prompts because it contains the images
+        output["input_ids"] = input_ids
+        output["attention_mask"] = attention_mask
+        output["labels"] = labels
+        return output
+
+
+def dft_loss(outputs, labels, num_items_in_batch):
+    """
+    DFT loss function, as presented in [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward
+    Rectification](https://huggingface.co/papers/2508.05629)
+    """
+    labels = nn.functional.pad(labels, (0, 1), value=-100)
+    shift_labels = labels[..., 1:].contiguous()
+    loss_mask = shift_labels != -100
+    shift_labels[~loss_mask] = 0
+    logprobs = selective_log_softmax(outputs.logits, shift_labels)
+    per_token_loss = -logprobs.exp().detach() * logprobs
+    loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
+    return loss
+
+
+class SFTTrainer(Trainer):
+    """
+    Trainer for Supervised Fine-Tuning (SFT) method.
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+ + Example: + + ```python + from datasets import load_dataset + from trl import SFTTrainer + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. + If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss + as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. + args ([`SFTConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum - length of the sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]): - Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a - dictionary string to float. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model + and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. 
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss + function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) + used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean + `compute_result` argument. This will be triggered after the last eval batch to signal that the function + needs to calculate and return the global summary statistics rather than accumulating the batch-level + statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and - return the logits to be used for metrics computation. - peft_config (`dict`, *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be - wrapped with the specified PEFT adapter. + A function that preprocess the logits right before caching them at each evaluation step. 
Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + formatting_func (`Callable`, *optional*): + Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly + converts the dataset into a [language modeling](#language-modeling) type. """ - _tag_names = ["trl", "reward-trainer"] + _tag_names = ["trl", "sft"] def __init__( self, - model: Optional[Union[PreTrainedModel, nn.Module]] = None, - args: Optional[RewardConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, + model: Union[str, nn.Module, PreTrainedModel], + args: Optional[Union[SFTConfig, TrainingArguments]] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, + peft_config: Optional["PeftConfig"] = None, + formatting_func: Optional[Callable[[dict], str]] = None, ): - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): - model = prepare_peft_model(model, peft_config, args) + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(model) + # Model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype + else: + raise ValueError( + "Invalid `dtype` passed to `SFTConfig`. 
Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model_id) - if compute_metrics is None: - compute_metrics = compute_accuracy + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if data_collator is None: - if processing_class is None: + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) + if eos_token_id is None: raise ValueError( - "A processing_class must be specified when using the default RewardDataCollatorWithPadding" + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." ) + tokenizer.eos_token_id = eos_token_id - max_length = args.max_length + if args.chat_template_path is not None: + if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): + with open(args.chat_template_path, encoding="utf-8") as chat_template_file: + processing_class.chat_template = chat_template_file.read() + added_tokens = [] + else: + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) + else: + added_tokens = [] - data_collator = RewardDataCollatorWithPadding(processing_class) + # Catch some wrong configurations related to VLMs + if self._is_vlm and args.packing: + raise ValueError( + "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." + ) + if self._is_vlm and args.padding_free: + raise ValueError( + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `SFTConfig`." + ) + if self._is_vlm and args.assistant_only_loss: + raise ValueError( + "Assistant-only loss is not yet supported for vision-language models. Please set " + "`assistant_only_loss=False` in the `SFTConfig`." 
+ ) + + # PEFT configuration and model wrapping + if peft_config is not None: + if added_tokens: + # Ensure that the added tokens are trainable + if peft_config.trainable_token_indices is None: + peft_config.trainable_token_indices = {"embed_tokens": added_tokens} + elif "embed_tokens" not in peft_config.trainable_token_indices: + peft_config.trainable_token_indices["embed_tokens"] = added_tokens + else: + peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + + # Ensure that the lm_head is trainable + if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: + logger.warning( + "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " + "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " + "tokens, leading to degraded generation quality. To fix this, add " + "`modules_to_save=['lm_head']` to your PEFT configuration." + ) + + if peft_config.modules_to_save is None: + peft_config.modules_to_save = ["lm_head"] + else: + peft_config.modules_to_save.append("lm_head") + + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. + self.num_virtual_tokens = 0 - if args.remove_unused_columns: - try: # for bc before https://github.com/huggingface/transformers/pull/25435 - args.remove_unused_columns = False - except FrozenInstanceError: - args = replace(args, remove_unused_columns=False) - # warn users + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) + + # Data collator + # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing + # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. + self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd") + use_flash_attention = model.config._attn_implementation in [ + "flash_attention_2", + "flash_attention_3", + "kernels-community/vllm-flash-attn3", + ] + if self.padding_free: + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if args.packing and args.packing_strategy == "wrapped": + logger.warning( + "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " + "recommended. Please refer to the documentation to understand why this is not recommended." + ) + if not use_flash_attention: logger.warning( - "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" - " we have set it for you, but you should do it yourself in the future.", + "Padding-free training is enabled, but the attention implementation is not set to " + "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " + "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " + "other implementations may lead to unexpected behavior. 
To ensure compatibility, set "
+                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
+                    "attention mechanism can handle flattened sequences."
+            )
+        if args.per_device_train_batch_size == 1 and not args.packing:
+            logger.warning(
+                "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
+                "of 1 annihilates the benefits of padding-free training. Please consider increasing the batch size "
+                "to at least 2."
             )
-            self.use_reward_data_collator = True
+        # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format
+        # is prompt-completion, and False if the dataset format is language modeling.
+        dataset_sample = next(iter(train_dataset))
+        if args.completion_only_loss is None:
+            self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
         else:
-            self.use_reward_data_collator = False
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-        # issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        if "input_ids_chosen" not in train_dataset.column_names:
-            with PartialState().main_process_first():
-                fn_kwargs = {"tokenizer": processing_class}
-                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                train_dataset = train_dataset.map(
-                    _tokenize,
-                    batched=True,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=args.dataset_num_proc,
+            self.completion_only_loss = args.completion_only_loss
+
+        if data_collator is None and not self._is_vlm:
+            # Get the pad token: if not provided, use the one from the processing class or the eos token
+            # if the processing class does not have a pad token.
+            pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
+            pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
+            if pad_token_id is None:
+                raise ValueError(
+                    f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                    "in the vocabulary before using it as a padding token."
                 )
-            # This filter is important because otherwise you get samples that exceed the model's context length and
-            # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-            # user might get surprised if N samples are missing from training.
-            train_dataset = train_dataset.filter(
-                lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                num_proc=args.dataset_num_proc,
+            data_collator = DataCollatorForLanguageModeling(
+                pad_token_id=pad_token_id,
+                completion_only_loss=self.completion_only_loss,
+                padding_free=self.padding_free,
+                # Using position_ids without flash_attn hurts the training
+                return_position_ids=use_flash_attention,
+                pad_to_multiple_of=args.pad_to_multiple_of,
+            )
+        elif data_collator is None and self._is_vlm:
+            data_collator = DataCollatorForVisionLanguageModeling(
+                processor=processing_class,
+                max_length=args.max_length,
+                completion_only_loss=self.completion_only_loss,
+                pad_to_multiple_of=args.pad_to_multiple_of,
+                dataset_text_field=args.dataset_text_field,
+            )
+
+        if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
+            logger.warning(
+                "You are using packing, but the attention implementation is not set to 'flash_attention_2' or "
+                "'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash "
+                "Attention is the only known attention mechanism that reliably supports this. Using other "
+                "implementations may lead to cross-contamination between batches. To avoid this, either disable "
+                "packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or "
+                "`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration."
+            )
+        if args.assistant_only_loss and not is_conversational(dataset_sample):
+            raise ValueError(
+                "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only "
+                "supported for conversational datasets."
+            )
+
+        # Dataset
+        # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
+        # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
+        skip_prepare_dataset = (
+            args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
+        )
+        if not skip_prepare_dataset:
+            if self.completion_only_loss and formatting_func:
+                raise ValueError(
+                    "A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
+                    "Using a formatter converts the dataset to a language modeling type, conflicting with "
+                    "completion-only loss. To resolve this, apply your formatting function before passing the "
+                    "dataset, or disable `completion_only_loss` in `SFTConfig`."
                 )
-                if eval_dataset is not None:
-                    eval_dataset = eval_dataset.map(
-                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                    )
-                    eval_dataset = eval_dataset.map(
-                        _tokenize,
-                        fn_kwargs=fn_kwargs,
-                        batched=True,
-                        num_proc=args.dataset_num_proc,
-                    )
-                    # This filter is important because otherwise you get samples that exceed the model's context length and
-                    # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                    # user might get surprised if N samples are missing from training.
- eval_dataset = eval_dataset.filter( - lambda x: len(x["input_ids_chosen"]) <= max_length - and len(x["input_ids_rejected"]) <= max_length, - num_proc=args.dataset_num_proc, + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) + if eval_dataset is not None: + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" ) + # Loss function + if args.loss_type == "nll": + pass # use the default loss + elif args.loss_type == "dft": + if compute_loss_func is not None: + raise ValueError( + "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " + "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a " + "`compute_loss_func` is not allowed." + ) + compute_loss_func = dft_loss + else: + raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + super().__init__( model=model, args=args, @@ -231,124 +873,322 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, + compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - def compute_loss( + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + + def _prepare_dataset( self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - rewards_chosen = model( - input_ids=inputs["input_ids_chosen"], - attention_mask=inputs["attention_mask_chosen"], - return_dict=True, - )["logits"] - rewards_rejected = model( - input_ids=inputs["input_ids_rejected"], - attention_mask=inputs["attention_mask_rejected"], - return_dict=True, - )["logits"] - # calculate loss, optionally modulate with margin - if "margin" in inputs: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() - else: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: 
SFTConfig, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) - if self.args.center_rewards_coefficient is not None: - loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names - if return_outputs: - return loss, { - "rewards_chosen": rewards_chosen, - "rewards_rejected": rewards_rejected, - } - return loss + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - inputs = self._prepare_inputs(inputs) - if ignore_keys is None: - if hasattr(self.model, "config"): - ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] + with PartialState().main_process_first(): + # Apply the formatting function if any + if formatting_func is not None and is_processed: + logger.warning( + "You passed a dataset that is already processed (contains an `input_ids` field) together with a " + "formatting function. Therefore `formatting_func` will be ignored. 
Either remove the " + "`formatting_func` or pass a dataset that is not already processed.", + ) - with torch.no_grad(): - loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + if formatting_func is not None and not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" - if prediction_loss_only: - return (loss, None, None) + def _func(example): + return {"text": formatting_func(example)} - loss = loss.detach() - logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) - logits = nested_detach(logits) - # Stack accepted against rejected, mean over logits - # and softmax to get preferences between accepted and rejected to sum to 1 - logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + dataset = dataset.map(_func, batched=False, **map_kwargs) - labels = torch.zeros(logits.shape[0]) - labels = self._prepare_inputs(labels) + if not is_processed: + # Convert the dataset to ChatML if needed + first_example = next(iter(dataset)) + if is_conversational_from_value(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = next(iter(dataset)).keys() + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, + ) - return loss, logits, labels + # Apply the chat template if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - def evaluate(self, *args, **kwargs): - num_print_samples = kwargs.pop("num_print_samples", 4) - self.visualize_samples(num_print_samples) - return super().evaluate(*args, **kwargs) + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example - def visualize_samples(self, num_print_samples: int): - """ - Visualize the reward model logits prediction + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) - Args: - num_print_samples (`int`, defaults to `4`): - The number of samples to print. Set to `-1` to print all samples. 
- """ - eval_dataloader = self.get_eval_dataloader() - table = defaultdict(list) - for _, inputs in enumerate(eval_dataloader): - _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) - chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class) - rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class) - table["chosen_text"].extend(gather_object(chosen_text)) - table["rejected_text"].extend(gather_object(rejected_text)) - table["logits"].extend( - gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) - ) - if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: - break - df = pd.DataFrame(table) - if self.accelerator.process_index == 0: - if is_rich_available(): - print_rich_table(df[:num_print_samples]) - if "wandb" in self.args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="completions.csv", - table=df, + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(example, processing_class, dataset_text_field, assistant_only_loss): + if "prompt" in example: # prompt-completion case + output = {} + if is_conversational(example): + prompt_ids = processing_class.apply_chat_template( + example["prompt"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_completion_processed = processing_class.apply_chat_template( + example["prompt"] + example["completion"], + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + prompt_completion_ids = prompt_completion_processed["input_ids"] + if "assistant_masks" in prompt_completion_processed: + output["assistant_masks"] = prompt_completion_processed["assistant_masks"] + else: + prompt_ids = processing_class(text=example["prompt"])["input_ids"] + prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[ + "input_ids" + ] + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + # Create a completion mask + completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) + output["input_ids"] = prompt_completion_ids + output["completion_mask"] = completion_mask + + else: # language modeling case + if is_conversational(example): + processed = processing_class.apply_chat_template( + example["messages"], + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + if "assistant_masks" in processed and 1 not in processed["assistant_masks"]: + raise RuntimeError( + "You're using `assistant_only_loss=True`, but at least one example has no " + "assistant tokens. This usually means the tokenizer's chat template doesn't " + "generate assistant masks — it may be missing the `{% generation %}` keyword. 
Please " + "check the template and ensure it's correctly configured to support assistant " + "masking." + ) + output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed} + else: + output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]} + return output + + dataset = dataset.map( + tokenize, + fn_kwargs={ + "processing_class": processing_class, + "dataset_text_field": args.dataset_text_field, + "assistant_only_loss": args.assistant_only_loss, + }, + **map_kwargs, ) + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + + columns = ["input_ids"] + if "completion_mask" in dataset.column_names: + columns.append("completion_mask") + if "assistant_masks" in dataset.column_names: + columns.append("assistant_masks") + + dataset = dataset.select_columns(columns) + + # Packing adds new column "seq_lengths" needed for document aware FlashAttention + dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) + elif args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + # For Liger kernel, ensure only the essential columns + if args.use_liger_kernel: + collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} + dataset = dataset.select_columns(collator_expected_keys.intersection(dataset.column_names)) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the + # dataset. So we need to override the default signature columns to include "completion_mask" as well. + if self._signature_columns is None: + if self._is_vlm: + self._signature_columns = ["messages", "prompt", "completion", "images"] + else: + self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + mode = "train" if self.model.training else "eval" + + # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used. + # This can be removed when this issue is fixed. + labels = inputs["labels"] + + # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing + inputs["use_cache"] = False + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + # Compute entropy + if not self.args.use_liger_kernel: # liger doesn't return logits + with torch.no_grad(): + per_token_entropy = entropy_from_logits(outputs.logits) + if "attention_mask" in inputs: + attention_mask = inputs["attention_mask"] + # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). 
+ virtual_attention_mask = torch.ones( + attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device + ) + attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) + entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() + elif "position_ids" in inputs: + entropy = torch.mean(per_token_entropy) + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) + + if mode == "train": + # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, + # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. + if "attention_mask" in inputs: + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + elif "position_ids" in inputs: + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if not self.args.use_liger_kernel: + with torch.no_grad(): + if "shift_labels" in inputs: + # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: + # - The first discarded token from inputs["labels"] actually belongs to process n-1 + # - The last logits require the label from process n+1 + shift_logits = outputs.logits.contiguous() + shift_labels = inputs["shift_labels"] + else: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do + # not correspond to actual input labels. + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] + + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + if self.aux_loss_enabled: + aux_loss = outputs.aux_loss + aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() + self._metrics[mode]["aux_loss"].append(aux_loss) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. 
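To make the masked token-accuracy bookkeeping above concrete, here is a tiny toy example (random logits and hand-written labels, not taken from the trainer) that counts correct predictions only where the shifted labels are not `-100`:

```python
import torch

shift_logits = torch.randn(2, 4, 8)  # (batch, sequence, vocab)
shift_labels = torch.tensor([[3, 5, -100, -100],
                             [1, -100, 2, 7]])

predictions = shift_logits.argmax(dim=-1)
mask = shift_labels != -100                      # ignore padding / prompt positions
correct = ((predictions == shift_labels) & mask).sum()
total = mask.sum()
accuracy = (correct / total).item() if total > 0 else 0.0
print(accuracy)  # fraction of the 5 supervised positions predicted correctly
```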
+ def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs.update(metrics) + super().log(logs, start_time) + self._metrics[mode].clear() + # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial): if self.args.hub_model_id is None: @@ -404,10 +1244,10 @@ def create_model_card( model_name=model_name, hub_model_id=self.hub_model_id, dataset_name=dataset_name, - tags=tags, + tags=list(tags), wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), - trainer_name="Reward", + trainer_name="SFT", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) From 24a27dbbe7c47b97f0c90159d1b479729face155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 20:36:05 +0000 Subject: [PATCH 02/62] remove vision and dft --- trl/trainer/reward_trainer.py | 305 +++------------------------------- 1 file changed, 19 insertions(+), 286 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index bdaeedfdd2f..1348e57bc0c 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -27,13 +27,10 @@ from datasets import Dataset, IterableDataset from transformers import ( AutoConfig, - AutoProcessor, - BaseImageProcessor, + AutoTokenizer, DataCollator, - FeatureExtractionMixin, PreTrainedModel, PreTrainedTokenizerBase, - ProcessorMixin, Trainer, TrainingArguments, is_wandb_available, @@ -44,23 +41,19 @@ from transformers.utils import is_peft_available from ..data_utils import ( - apply_chat_template, is_conversational, is_conversational_from_value, maybe_convert_to_chatml, pack_dataset, - prepare_multimodal_messages, truncate_dataset, ) from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .sft_config import SFTConfig from .utils import ( entropy_from_logits, - flush_left, generate_model_card, get_comet_experiment_url, pad, - selective_log_softmax, ) @@ -285,215 +278,6 @@ def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) return list(position_ids.split(example_lengths)) -@dataclass -class DataCollatorForVisionLanguageModeling(DataCollatorMixin): - """ - Data collator for vision-language modeling tasks. - - Unlike text-only datasets—where the collator typically receives pre-tokenized inputs ready for batching, - vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, - making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and - image processing on-the-fly to efficiently prepare batches. - - Each input example should be a dictionary containing at least: - - An `"images"` key holding the image data. 
- - [language modeling](#language-modeling) type: either a `"messages"` key for conversational inputs or a `"text"` - key for standard text inputs. - - [prompt-completion](#prompt-completion) type: keys `"prompt"` and `"completion"` for the prompt and completion. - - The collator outputs a dictionary including: - - `"input_ids"`: Tensor of token IDs. - - `"attention_mask"`: Tensor indicating attention mask. - - `"pixel_values"`: Tensor representing image pixel values. - - `"labels"`: Tensor for training labels. - - Additional keys may be present depending on the processor, such as `"image_grid_thw"`. - - Args: - processor (`ProcessorMixin`): - The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` and - include a `tokenizer` with a defined `pad_token_id`. - max_length (`int` or `None`, optional, defaults to `None`): - Maximum sequence length for input tokens. If `None`, no truncation is applied. - completion_only_loss (`bool`, *optional*, defaults to `False`): - Whether to compute loss only on the completion part of the sequence. When `True`, the labels for the prompt - part are set to -100. It requires the dataset type to be prompt-completion. - pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): - If set, the sequences will be padded to a multiple of this value. - dataset_text_field (`str`, optional, defaults to `"text"`): - Name of the column that contains text data in the dataset. This parameter is only relevant for [standard - datasets format](dataset_formats#standard). - return_tensors (`str`, optional, defaults to `"pt"`): - The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. - - Example: - ```python - >>> from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling - >>> from transformers import AutoProcessor - - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> collator = DataCollatorForVisionLanguageModeling(processor) - >>> examples = [ - ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, - ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}, - ... 
] - >>> collator(examples) - {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, - 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, - 419, 30, 151645, 198], - [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, - 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, - 2168, 13, 151645, 198]]), - 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), - 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], - [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], - [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], - ..., - [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], - [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], - [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), - 'image_grid_thw': tensor([[1, 4, 4], - [1, 4, 4]]), - 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, - 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, - 419, 30, 151645, 198], - [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, - 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, - 2168, 13, 151645, 198]])} - ``` - """ - - processor: ProcessorMixin - max_length: Optional[int] = None - completion_only_loss: bool = False # default not used in practice; SFTTrainer always passes the relevant value - pad_to_multiple_of: Optional[int] = None - dataset_text_field: str = "text" - return_tensors: str = "pt" - - def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - if "messages" in examples[0] or self.dataset_text_field in examples[0]: - if self.completion_only_loss: - raise ValueError( - "The `completion_only_loss` argument is not supported for language modeling datasets." - ) - return self._collate_language_modeling(examples) - elif "prompt" in examples[0] and "completion" in examples[0]: - return self._collate_prompt_completion(examples) - else: - raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.") - - def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - images = [example["images"] for example in examples] - - if "messages" in examples[0]: # conversational case - for example in examples: - prepare_multimodal_messages(example["messages"], len(example["images"])) - messages = [example["messages"] for example in examples] - texts = self.processor.apply_chat_template(messages) - elif self.dataset_text_field in examples[0]: # standard case - texts = [example[self.dataset_text_field] for example in examples] - else: - raise KeyError( - "The input examples must contain either 'messages' for conversational data or 'text' for standard " - "data." 
- ) - - output = self.processor( - images=images, - text=texts, - padding=True, - padding_side="right", - pad_to_multiple_of=self.pad_to_multiple_of, - truncation=self.max_length is not None, - max_length=self.max_length, - return_tensors=self.return_tensors, - add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens - ) - labels = output["input_ids"].clone() - labels[output["attention_mask"] == 0] = -100 - # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in - # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision - # token definitions vary across architectures. - output["labels"] = labels - return output - - def _collate_prompt_completion(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: - if self.pad_to_multiple_of is not None: - raise NotImplementedError( - "Padding to a multiple of a value is not yet implemented for vision-language modeling and " - "prompt-completion data yet." - ) - images = [example["images"] for example in examples] - if is_conversational(examples[0]): # conversational case - for example in examples: - prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"])) - examples = [apply_chat_template(example, self.processor) for example in examples] - - prompts = [example["prompt"] for example in examples] - completions = [example["completion"] for example in examples] - - processed_prompts = self.processor( - images=images, - text=prompts, - padding=True, - padding_side="left", - return_tensors=self.return_tensors, - add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens - ) - processed_completions = self.processor( - text=completions, - padding=True, - padding_side="right", - return_tensors=self.return_tensors, - add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens - ) - - # Concatenate prompts and completions - prompt_ids, completion_ids = processed_prompts["input_ids"], processed_completions["input_ids"] - prompt_mask, completion_mask = processed_prompts["attention_mask"], processed_completions["attention_mask"] - input_ids = torch.cat((prompt_ids, completion_ids), dim=1) - attention_mask = torch.cat((prompt_mask, completion_mask), dim=1) - completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1) - - # Flush left to reduce padding - attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) - - # Truncate if necessary - if self.max_length is not None: - input_ids = input_ids[:, : self.max_length] - attention_mask = attention_mask[:, : self.max_length] - completion_mask = completion_mask[:, : self.max_length] - - # Create labels and mask padding tokens - labels = input_ids.clone() - labels[attention_mask == 0] = -100 - if self.completion_only_loss: - labels[completion_mask == 0] = -100 - - # Build the output dictionary - output = processed_prompts # we take processed_prompts because it contains the images - output["input_ids"] = input_ids - output["attention_mask"] = attention_mask - 
output["labels"] = labels - return output - - -def dft_loss(outputs, labels, num_items_in_batch): - """ - DFT loss function, as presented in [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward - Rectification](https://huggingface.co/papers/2508.05629) - """ - labels = nn.functional.pad(labels, (0, 1), value=-100) - shift_labels = labels[..., 1:].contiguous() - loss_mask = shift_labels != -100 - shift_labels[~loss_mask] = 0 - logprobs = selective_log_softmax(outputs.logits, shift_labels) - per_token_loss = -logprobs.exp().detach() * logprobs - loss = (per_token_loss * loss_mask).sum() / num_items_in_batch - return loss - - class SFTTrainer(Trainer): """ Trainer for Supervised Fine-Tuning (SFT) method. @@ -528,8 +312,7 @@ class SFTTrainer(Trainer): Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model - and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`]. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: @@ -541,10 +324,11 @@ class SFTTrainer(Trainer): The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If `None`, the processing class is loaded from the model's name - with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. - If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): + Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with + [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be + set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the + default. compute_loss_func (`Callable`, *optional*): A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated batch (batch_size * gradient_accumulation_steps) and returns the loss. 
For example, see the default [loss @@ -594,7 +378,7 @@ def __init__( data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, @@ -643,28 +427,19 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model_id) + processing_class = AutoTokenizer.from_pretrained(model_id) # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - self._is_vlm = True - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - self._is_vlm = False - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if args.eos_token is not None: eos_token = args.eos_token - eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) + eos_token_id = processing_class.convert_tokens_to_ids(eos_token) if eos_token_id is None: raise ValueError( f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " "in the vocabulary before using it as an EOS token." ) - tokenizer.eos_token_id = eos_token_id + processing_class.eos_token_id = eos_token_id if args.chat_template_path is not None: if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): @@ -678,22 +453,6 @@ def __init__( else: added_tokens = [] - # Catch some wrong configurations related to VLMs - if self._is_vlm and args.packing: - raise ValueError( - "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." - ) - if self._is_vlm and args.padding_free: - raise ValueError( - "Padding-free training is yet not supported for vision-language models. Please set " - "`padding_free=False` in the `SFTConfig`." - ) - if self._is_vlm and args.assistant_only_loss: - raise ValueError( - "Assistant-only loss is not yet supported for vision-language models. Please set " - "`assistant_only_loss=False` in the `SFTConfig`." - ) - # PEFT configuration and model wrapping if peft_config is not None: if added_tokens: @@ -770,11 +529,11 @@ def __init__( else: self.completion_only_loss = args.completion_only_loss - if data_collator is None and not self._is_vlm: + if data_collator is None: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. 
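The EOS handling kept above requires any custom `eos_token` to already exist in the tokenizer vocabulary; otherwise a `ValueError` is raised before training starts. A small sketch, again assuming the Qwen2 instruct tokenizer, with `<|im_end|>` as a purely illustrative choice:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

eos_token = "<|im_end|>"  # e.g. what one might pass as `eos_token` in the config
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
if eos_token_id is None:
    # Mirrors the ValueError raised above for tokens missing from the vocabulary.
    raise ValueError(f"'{eos_token}' is not in the tokenizer vocabulary.")
tokenizer.eos_token_id = eos_token_id
print(tokenizer.eos_token, tokenizer.eos_token_id)
```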
- pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token + pad_token_id = processing_class.convert_tokens_to_ids(pad_token) if pad_token_id is None: raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " @@ -789,14 +548,6 @@ def __init__( return_position_ids=use_flash_attention, pad_to_multiple_of=args.pad_to_multiple_of, ) - elif data_collator is None and self._is_vlm: - data_collator = DataCollatorForVisionLanguageModeling( - processor=processing_class, - max_length=args.max_length, - completion_only_loss=self.completion_only_loss, - pad_to_multiple_of=args.pad_to_multiple_of, - dataset_text_field=args.dataset_text_field, - ) if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: logger.warning( @@ -814,10 +565,9 @@ def __init__( ) # Dataset - # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where - # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead. - skip_prepare_dataset = ( - args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm + # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`. + skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get( + "skip_prepare_dataset", False ) if not skip_prepare_dataset: if self.completion_only_loss and formatting_func: @@ -842,20 +592,6 @@ def __init__( eval_dataset, processing_class, args, packing, formatting_func, "eval" ) - # Loss function - if args.loss_type == "nll": - pass # use the default loss - elif args.loss_type == "dft": - if compute_loss_func is not None: - raise ValueError( - "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " - "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a " - "`compute_loss_func` is not allowed." - ) - compute_loss_func = dft_loss - else: - raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") - # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} self._total_train_tokens = 0 @@ -896,7 +632,7 @@ def __init__( def _prepare_dataset( self, dataset: Union[Dataset, IterableDataset], - processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + processing_class: PreTrainedTokenizerBase, args: SFTConfig, packing: bool, formatting_func: Optional[Callable[[dict], str]], @@ -1075,10 +811,7 @@ def _set_signature_columns_if_needed(self): # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the # dataset. So we need to override the default signature columns to include "completion_mask" as well. 
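Two guards retained above are easy to misread, so here is a toy restatement with stand-in values: dataset preparation is skipped only when `skip_prepare_dataset` is explicitly set in `dataset_kwargs`, and a `formatting_func` cannot be combined with completion-only loss because formatting turns the data into a language-modeling dataset:

```python
dataset_kwargs = {"skip_prepare_dataset": False}  # stand-in for the config's `dataset_kwargs`
completion_only_loss = True


def formatting_func(example):  # hypothetical formatter
    return example["text"]


skip_prepare_dataset = dataset_kwargs is not None and dataset_kwargs.get("skip_prepare_dataset", False)
incompatible = completion_only_loss and formatting_func is not None
print(skip_prepare_dataset, incompatible)  # False True -> the trainer would raise a ValueError here
```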
if self._signature_columns is None: - if self._is_vlm: - self._signature_columns = ["messages", "prompt", "completion", "images"] - else: - self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] + self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ From 91e546f5371e2ba5fb024c2524e04f4249d27c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 20:41:02 +0000 Subject: [PATCH 03/62] sft to reward --- trl/trainer/reward_config.py | 20 +++++++-------- trl/trainer/reward_trainer.py | 46 +++++++++++++++++------------------ 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index a5b9f2c4267..6abb22c5865 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -19,13 +19,13 @@ @dataclass -class SFTConfig(TrainingArguments): +class RewardConfig(TrainingArguments): r""" - Configuration class for the [`SFTTrainer`]. + Configuration class for the [`RewardTrainer`]. - This class includes only the parameters that are specific to SFT training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. + This class includes only the parameters that are specific to Reward training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. Using [`~transformers.HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the @@ -36,8 +36,8 @@ class SFTConfig(TrainingArguments): model_init_kwargs (`dict[str, Any]`, *optional*): Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` - argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to - include the load balancing/auxilliary loss as a part of the final loss, remember to set + argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want + to include the load balancing/auxilliary loss as a part of the final loss, remember to set `output_router_logits=True` in this dictionary. chat_template_path (`str`, *optional*): If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory @@ -133,9 +133,9 @@ class SFTConfig(TrainingArguments): default=None, metadata={ "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " - "the `SFTTrainer` is provided as a string. If you're training a MoE architecture and want to include the " - "load balancing/auxilliary loss as a part of the final loss, remember to set `output_router_logits=True` " - "in this dictionary." + "the `RewardTrainer` is provided as a string. If you're training a MoE architecture and want to include " + "the load balancing/auxilliary loss as a part of the final loss, remember to set " + "`output_router_logits=True` in this dictionary." 
}, ) chat_template_path: Optional[str] = field( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 1348e57bc0c..1cc576472b0 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -48,7 +48,7 @@ truncate_dataset, ) from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model -from .sft_config import SFTConfig +from .reward_config import RewardConfig from .utils import ( entropy_from_logits, generate_model_card, @@ -132,7 +132,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): Examples: ```python - >>> from trl.trainer.sft_trainer import DataCollatorForLanguageModeling + >>> from trl.trainer.reward_trainer import DataCollatorForLanguageModeling >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] @@ -278,9 +278,9 @@ def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) return list(position_ids.split(example_lengths)) -class SFTTrainer(Trainer): +class RewardTrainer(Trainer): """ - Trainer for Supervised Fine-Tuning (SFT) method. + Trainer for Outcome-supervised Reward Models (ORM). This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. @@ -288,11 +288,11 @@ class SFTTrainer(Trainer): ```python from datasets import load_dataset - from trl import SFTTrainer + from trl import RewardTrainer dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") - trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer = RewardTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) trainer.train() ``` @@ -308,13 +308,13 @@ class SFTTrainer(Trainer): - A [`~transformers.PreTrainedModel`] object. If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. - args ([`SFTConfig`], *optional*): + args ([`RewardConfig`], *optional*): Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`]. + Will default to [`~trainer.reward_trainer.DataCollatorForLanguageModeling`]. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: - [Standard](dataset_formats#standard): Each sample contains plain text. @@ -337,10 +337,10 @@ class SFTTrainer(Trainer): compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing - [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean - `compute_result` argument. This will be triggered after the last eval batch to signal that the function - needs to calculate and return the global summary statistics rather than accumulating the batch-level - statistics. 
+ [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a + boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the + function needs to calculate and return the global summary statistics rather than accumulating the + batch-level statistics. callbacks (list of [`~transformers.TrainerCallback`], *optional*): List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). @@ -369,12 +369,12 @@ class SFTTrainer(Trainer): converts the dataset into a [language modeling](#language-modeling) type. """ - _tag_names = ["trl", "sft"] + _tag_names = ["trl", "reward"] def __init__( self, model: Union[str, nn.Module, PreTrainedModel], - args: Optional[Union[SFTConfig, TrainingArguments]] = None, + args: Optional[Union[RewardConfig, TrainingArguments]] = None, data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, @@ -392,12 +392,12 @@ def __init__( if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path model_name = model_name.split("/")[-1] - args = SFTConfig(f"{model_name}-SFT") - elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + args = RewardConfig(f"{model_name}-Reward") + elif isinstance(args, TrainingArguments) and not isinstance(args, RewardConfig): dict_args = args.to_dict() dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token dict_args.pop("push_to_hub_token") - args = SFTConfig(**dict_args) + args = RewardConfig(**dict_args) # Model model_init_kwargs = args.model_init_kwargs or {} @@ -411,7 +411,7 @@ def __init__( model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." ) config = AutoConfig.from_pretrained(model_id) @@ -421,7 +421,7 @@ def __init__( model_id = model.config._name_or_path if args.model_init_kwargs is not None: logger.warning( - "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " + "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. " "The `model_init_kwargs` will be ignored." ) @@ -575,7 +575,7 @@ def __init__( "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " "Using a formatter converts the dataset to a language modeling type, conflicting with " "completion-only loss. To resolve this, apply your formatting function before passing the " - "dataset, or disable `completion_only_loss` in `SFTConfig`." + "dataset, or disable `completion_only_loss` in `RewardConfig`." 
) train_dataset = self._prepare_dataset( train_dataset, processing_class, args, args.packing, formatting_func, "train" @@ -633,7 +633,7 @@ def _prepare_dataset( self, dataset: Union[Dataset, IterableDataset], processing_class: PreTrainedTokenizerBase, - args: SFTConfig, + args: RewardConfig, packing: bool, formatting_func: Optional[Callable[[dict], str]], dataset_name: str, @@ -980,7 +980,7 @@ def create_model_card( tags=list(tags), wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), - trainer_name="SFT", + trainer_name="Reward", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) From ce6d95deba34b653b610b3c7100a3483816c9870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 20:42:12 +0000 Subject: [PATCH 04/62] `DataCollatorForLanguageModeling` to `DataCollatorForPreference` --- trl/trainer/reward_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 1cc576472b0..e9ed399658e 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -101,7 +101,7 @@ def remove_none_values(example: TListOrMapping) -> TListOrMapping: @dataclass -class DataCollatorForLanguageModeling(DataCollatorMixin): +class DataCollatorForPreference(DataCollatorMixin): """ Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. @@ -132,9 +132,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): Examples: ```python - >>> from trl.trainer.reward_trainer import DataCollatorForLanguageModeling + >>> from trl.trainer.reward_trainer import DataCollatorForPreference - >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) + >>> collator = DataCollatorForPreference(pad_token_id=0) >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] >>> collator(examples) {'input_ids': tensor([[ 1, 2, 3], @@ -162,7 +162,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): [-100, 5, -100]])} >>> # With padding_free - >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator = DataCollatorForPreference(pad_token_id=0, padding_free=True) >>> collator(examples) {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), @@ -312,7 +312,7 @@ class RewardTrainer(Trainer): Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`~trainer.reward_trainer.DataCollatorForLanguageModeling`]. + Will default to [`~trainer.reward_trainer.DataCollatorForPreference`]. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: @@ -540,7 +540,7 @@ def __init__( f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." 
) - data_collator = DataCollatorForLanguageModeling( + data_collator = DataCollatorForPreference( pad_token_id=pad_token_id, completion_only_loss=self.completion_only_loss, padding_free=self.padding_free, From de984d59e00808da509bd1255cfa0f6e9bf4af28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 20:43:38 +0000 Subject: [PATCH 05/62] remove support for TrainingArguments --- trl/trainer/reward_trainer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index e9ed399658e..c356f4d5aea 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -374,7 +374,7 @@ class RewardTrainer(Trainer): def __init__( self, model: Union[str, nn.Module, PreTrainedModel], - args: Optional[Union[RewardConfig, TrainingArguments]] = None, + args: Optional[RewardConfig] = None, data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, @@ -393,11 +393,6 @@ def __init__( model_name = model if isinstance(model, str) else model.config._name_or_path model_name = model_name.split("/")[-1] args = RewardConfig(f"{model_name}-Reward") - elif isinstance(args, TrainingArguments) and not isinstance(args, RewardConfig): - dict_args = args.to_dict() - dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token - dict_args.pop("push_to_hub_token") - args = RewardConfig(**dict_args) # Model model_init_kwargs = args.model_init_kwargs or {} From 78a1ee9c784f154e592b688d9b14a8206d926c0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 21:20:01 +0000 Subject: [PATCH 06/62] properly load model --- trl/trainer/reward_trainer.py | 40 +++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index c356f4d5aea..2f6b44c1d34 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -13,9 +13,12 @@ # limitations under the License. import contextlib +import logging import os +import re from collections import defaultdict from collections.abc import Mapping +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Optional, TypeVar, Union @@ -23,16 +26,16 @@ import torch import torch.nn as nn import transformers -from accelerate import PartialState, logging +from accelerate import PartialState +from accelerate.logging import get_logger from datasets import Dataset, IterableDataset from transformers import ( - AutoConfig, + AutoModelForSequenceClassification, AutoTokenizer, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, - TrainingArguments, is_wandb_available, ) from transformers.data.data_collator import DataCollatorMixin @@ -63,7 +66,31 @@ if is_wandb_available(): import wandb -logger = logging.get_logger(__name__) +logger = get_logger(__name__) + + +# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly +# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to +# avoid confusing users. 
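Patch 06 above switches model loading to `AutoModelForSequenceClassification`, whose freshly initialized classification head is what triggers the transformers "newly initialized weights" warning filtered out by the context manager defined just below. A hedged sketch of that loading step; `num_labels=1` is an illustrative choice for a scalar reward head, not necessarily what the patch passes:

```python
from transformers import AutoModelForSequenceClassification

# Loading a causal LM checkpoint this way attaches a new, randomly initialized score head,
# which is exactly the situation the warning filter below is meant to cover.
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=1)
print(model.config.num_labels)  # 1
```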
+@contextmanager +def suppress_from_pretrained_warning(logger: logging.Logger): + pattern = re.compile( + r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: " + r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and " + r"inference\.$" + ) + + class _Filter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return not pattern.search(record.getMessage()) + + f = _Filter() + logger.addFilter(f) + try: + yield + finally: + logger.removeFilter(f) + TListOrMapping = TypeVar("TListOrMapping", list, Mapping) @@ -409,9 +436,8 @@ def __init__( "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." ) - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) + with suppress_from_pretrained_warning(transformers.modeling_utils.logger): + model = AutoModelForSequenceClassification.from_pretrained(model_id, **model_init_kwargs) else: model_id = model.config._name_or_path if args.model_init_kwargs is not None: From bfd20060940dc63f95a8ae06d50acf42ecf5caea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 21:30:20 +0000 Subject: [PATCH 07/62] remove position_ids, packing, padding-free, seq_length --- trl/trainer/reward_config.py | 46 +------- trl/trainer/reward_trainer.py | 201 ++++------------------------------ 2 files changed, 22 insertions(+), 225 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 6abb22c5865..119046401c2 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -63,22 +63,9 @@ class may differ from those in [`~transformers.TrainingArguments`]. it falls back to `processing_class.eos_token`. max_length (`int` or `None`, *optional*, defaults to `1024`): Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. - If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. - packing (`bool`, *optional*, defaults to `False`): - Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce - padding. Uses `max_length` to define sequence length. - packing_strategy (`str`, *optional*, defaults to `"bfd"`): - Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`. - padding_free (`bool`, *optional*, defaults to `False`): - Whether to perform forward passes without padding by flattening all sequences in the batch into a single - continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only - supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When - packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this - parameter. + If `None`, no truncation is applied. pad_to_multiple_of (`int`, *optional*): If set, the sequences will be padded to a multiple of this value. - eval_packing (`bool`, *optional*): - Whether to pack the eval dataset. If `None`, uses the same value as `packing`. > Parameters that control the training @@ -183,42 +170,13 @@ class may differ from those in [`~transformers.TrainingArguments`]. 
default=1024, metadata={ "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" - "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " - "sequence length." - }, - ) - packing: bool = field( - default=False, - metadata={ - "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency " - "and reduce padding. Uses `max_length` to define sequence length." - }, - ) - packing_strategy: str = field( - default="bfd", - metadata={ - "help": "Strategy for packing sequences. Can be either `'bfd'` (best-fit decreasing, default), or " - "`'wrapped'`." - }, - ) - padding_free: bool = field( - default=False, - metadata={ - "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " - "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " - "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " - "structure. When packing is enabled with strategy `'bfd'`, padding-free is enabled, regardless of the " - "value of this parameter." + "the right. If `None`, no truncation is applied." }, ) pad_to_multiple_of: Optional[int] = field( default=None, metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, ) - eval_packing: Optional[bool] = field( - default=None, - metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, - ) # Parameters that control the training completion_only_loss: Optional[bool] = field( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 2f6b44c1d34..0a792d26fbb 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -43,13 +43,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..data_utils import ( - is_conversational, - is_conversational_from_value, - maybe_convert_to_chatml, - pack_dataset, - truncate_dataset, -) +from ..data_utils import is_conversational, is_conversational_from_value, maybe_convert_to_chatml, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig from .utils import ( @@ -138,7 +132,6 @@ class DataCollatorForPreference(DataCollatorMixin): in the assistant part of the sequence. The collator returns a dictionary containing the following keys: - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. - - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch. - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are not in the assistant part of the sequence are set to -100. @@ -149,9 +142,6 @@ class DataCollatorForPreference(DataCollatorMixin): completion_only_loss (`bool`, *optional*, defaults to `True`): When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens that are no in the completion. 
- padding_free (`bool`, *optional*, defaults to `False`): - If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be - generated accordingly. pad_to_multiple_of (`int`, *optional*): If set, the sequences will be padded to a multiple of this value. return_tensors (`str`, *optional*, defaults to `"pt"`): @@ -168,8 +158,6 @@ class DataCollatorForPreference(DataCollatorMixin): [ 4, 5, 0]]), 'attention_mask': tensor([[ 1, 1, 1], [ 1, 1, 0]]), - 'position_ids': tensor([[0, 1, 2], - [0, 1, 0]]), 'labels': tensor([[ 1, 2, 3], [ 4, 5, -100]])} @@ -183,47 +171,21 @@ class DataCollatorForPreference(DataCollatorMixin): [ 4, 5, 0]]), 'attention_mask': tensor([[ 1, 1, 1], [ 1, 1, 0]]), - 'position_ids': tensor([[0, 1, 2], - [0, 1, 0]]), 'labels': tensor([[-100, 2, 3], [-100, 5, -100]])} - - >>> # With padding_free - >>> collator = DataCollatorForPreference(pad_token_id=0, padding_free=True) - >>> collator(examples) - {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), - 'attention_mask': tensor([[1, 1, 1, 1, 1]]), - 'position_ids': tensor([[0, 1, 2, 0, 1]]), - 'labels': tensor([[1, 2, 3, 4, 5]])} ``` """ pad_token_id: int completion_only_loss: bool = True - padding_free: bool = False - return_position_ids: bool = True pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: # Convert to tensor input_ids = [torch.tensor(example["input_ids"]) for example in examples] + attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] - # Check if we have meaningful seq_lengths from packing (restarting sequences) - has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free - - # For packing with position_ids, we should NOT create attention_mask as it causes - # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask - if not has_packed_position_ids: - attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] - - if self.return_position_ids: - if "seq_lengths" in examples[0]: - position_ids = self.get_position_ids_from_packed_seq_lengths( - [example["seq_lengths"] for example in examples] - ) - else: - position_ids = [torch.arange(len(ids)) for ids in input_ids] if "labels" in examples[0]: labels = [torch.tensor(example["labels"]) for example in examples] else: @@ -233,19 +195,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d if "assistant_masks" in examples[0]: assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples] - # If padding_free, flatten everything into a single sequence output = {} - if self.padding_free: - input_ids = [torch.cat(input_ids, dim=0)] - if not has_packed_position_ids: - attention_mask = [torch.cat(attention_mask, dim=0)] - if self.return_position_ids: - position_ids = [torch.cat(position_ids, dim=0)] - labels = [torch.cat(labels, dim=0)] - if self.completion_only_loss and "completion_mask" in examples[0]: - completion_mask = [torch.cat(completion_mask, dim=0)] - if "assistant_masks" in examples[0]: - assistant_masks = [torch.cat(assistant_masks, dim=0)] # Pad output["input_ids"] = pad( @@ -254,14 +204,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ) - if not has_packed_position_ids: - output["attention_mask"] = pad( - attention_mask, padding_value=0, padding_side="right", 
pad_to_multiple_of=self.pad_to_multiple_of - ) - if self.return_position_ids: - output["position_ids"] = pad( - position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of - ) + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) output["labels"] = pad( labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of ) @@ -277,33 +222,6 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d output["labels"][assistant_masks == 0] = -100 return output - @staticmethod - def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]: - """ - Get position IDs for packed sequences. - - Args: - batch_seq_lengths (`list[list[int]]`): - A list of lists containing the lengths of each individual document in the packed batch. - - Return: - `list[torch.Tensor]`: - A list of tensors containing the position IDs for each packed sequence. - """ - # Get lengths per row - example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths] - # Flat list of lengths - batch_seq_lengths = torch.tensor( - [seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths] - ) - position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype) - position_ids[0] = 0 - # Reset position ids to 0 at the start of each sequence - position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1) - position_ids = position_ids.cumsum(0) - # Split back into one tensor per example - return list(position_ids.split(example_lengths)) - class RewardTrainer(Trainer): """ @@ -510,38 +428,6 @@ def __init__( self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) # Data collator - # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing - # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. - self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd") - use_flash_attention = model.config._attn_implementation in [ - "flash_attention_2", - "flash_attention_3", - "kernels-community/vllm-flash-attn3", - ] - if self.padding_free: - if data_collator is not None: - raise ValueError("Passing a custom data collator is not supported when using padding-free.") - if args.packing and args.packing_strategy == "wrapped": - logger.warning( - "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " - "recommended. Please refer to the documentation to understand why this is not recommended." - ) - if not use_flash_attention: - logger.warning( - "Padding-free training is enabled, but the attention implementation is not set to " - "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " - "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " - "other implementations may lead to unexpected behavior. To ensure compatibility, set " - "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " - "attention mechanism can handle flattened sequences." - ) - if args.per_device_train_batch_size == 1 and not args.packing: - logger.warning( - "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " - "of 1 anihilate the benefits of padding-free training. 
Please consider increasing the batch size " - "to at least 2." - ) - # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format # is prompt-completion, and False if the dataset format is language modeling. dataset_sample = next(iter(train_dataset)) @@ -564,21 +450,9 @@ def __init__( data_collator = DataCollatorForPreference( pad_token_id=pad_token_id, completion_only_loss=self.completion_only_loss, - padding_free=self.padding_free, - # Using position_ids without flash_attn hurts the training - return_position_ids=use_flash_attention, pad_to_multiple_of=args.pad_to_multiple_of, ) - if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: - logger.warning( - "You are using packing, but the attention implementation is not set to 'flash_attention_2' or " - "'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash " - "Attention is the only known attention mechanisms that reliably support this. Using other " - "implementations may lead to cross-contamination between batches. To avoid this, either disable " - "packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or " - "`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration." - ) if args.assistant_only_loss and not is_conversational(dataset_sample): raise ValueError( "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only " @@ -598,20 +472,15 @@ def __init__( "completion-only loss. To resolve this, apply your formatting function before passing the " "dataset, or disable `completion_only_loss` in `RewardConfig`." ) - train_dataset = self._prepare_dataset( - train_dataset, processing_class, args, args.packing, formatting_func, "train" - ) + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, formatting_func, "train") if eval_dataset is not None: - packing = args.packing if args.eval_packing is None else args.eval_packing if isinstance(eval_dataset, dict): eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + key: self._prepare_dataset(dataset, processing_class, args, formatting_func, key) for key, dataset in eval_dataset.items() } else: - eval_dataset = self._prepare_dataset( - eval_dataset, processing_class, args, packing, formatting_func, "eval" - ) + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, formatting_func, "eval") # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} @@ -655,7 +524,6 @@ def _prepare_dataset( dataset: Union[Dataset, IterableDataset], processing_class: PreTrainedTokenizerBase, args: RewardConfig, - packing: bool, formatting_func: Optional[Callable[[dict], str]], dataset_name: str, ) -> Union[Dataset, IterableDataset]: @@ -798,30 +666,14 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss) **map_kwargs, ) - # Pack or truncate - if packing: - if args.max_length is None: - raise ValueError("When packing is enabled, `max_length` can't be `None`.") - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Packing {dataset_name} dataset" - - columns = ["input_ids"] - if "completion_mask" in dataset.column_names: - columns.append("completion_mask") - if "assistant_masks" in dataset.column_names: - columns.append("assistant_masks") - - dataset = dataset.select_columns(columns) - - # 
Packing adds new column "seq_lengths" needed for document aware FlashAttention - dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) - elif args.max_length is not None: + # Truncate + if args.max_length is not None: if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Truncating {dataset_name} dataset" dataset = truncate_dataset(dataset, args.max_length, map_kwargs) # For Liger kernel, ensure only the essential columns if args.use_liger_kernel: - collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} + collator_expected_keys = {"input_ids", "completion_mask", "assistant_masks"} dataset = dataset.select_columns(collator_expected_keys.intersection(dataset.column_names)) return dataset @@ -832,7 +684,7 @@ def _set_signature_columns_if_needed(self): # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the # dataset. So we need to override the default signature columns to include "completion_mask" as well. if self._signature_columns is None: - self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] + self._signature_columns = ["input_ids", "labels", "completion_mask", "assistant_masks"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -854,31 +706,18 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if not self.args.use_liger_kernel: # liger doesn't return logits with torch.no_grad(): per_token_entropy = entropy_from_logits(outputs.logits) - if "attention_mask" in inputs: - attention_mask = inputs["attention_mask"] - # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). - virtual_attention_mask = torch.ones( - attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device - ) - attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) - entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() - elif "position_ids" in inputs: - entropy = torch.mean(per_token_entropy) - else: - raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + attention_mask = inputs["attention_mask"] + # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). + virtual_attention_mask = torch.ones( + attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device + ) + attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) + entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() entropy = self.accelerator.gather_for_metrics(entropy).mean().item() self._metrics[mode]["entropy"].append(entropy) if mode == "train": - # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, - # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. 
- if "attention_mask" in inputs: - num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() - elif "position_ids" in inputs: - local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) - num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() - else: - raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() self._total_train_tokens += num_tokens_in_batch self._metrics[mode]["num_tokens"] = [self._total_train_tokens] From c12b4cc14ba3caca66b94dc466f4509f21fcb870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 23:00:59 +0000 Subject: [PATCH 08/62] remove completion_only_loss, mompletion mask, formatting_func, assistant_only_loss; support implicit/explicit - standard/conversational; remove liger --- trl/trainer/reward_trainer.py | 361 +++++++++++----------------------- 1 file changed, 110 insertions(+), 251 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 0a792d26fbb..5f57e6cbf4a 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -46,12 +46,7 @@ from ..data_utils import is_conversational, is_conversational_from_value, maybe_convert_to_chatml, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig -from .utils import ( - entropy_from_logits, - generate_model_card, - get_comet_experiment_url, - pad, -) +from .utils import entropy_from_logits, generate_model_card, get_comet_experiment_url, pad if is_peft_available(): @@ -127,21 +122,14 @@ class DataCollatorForPreference(DataCollatorMixin): Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. - If the input contains a `"completion_mask"`, it is used to set the labels to `-100` for tokens that are not in the - completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not - in the assistant part of the sequence. The collator returns a dictionary containing the following keys: + The collator returns a dictionary containing the following keys: - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. - - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to - `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are - not in the assistant part of the sequence are set to -100. + - `"labels"`: Tensor of labels, padded to the maximum length of the batch. Args: pad_token_id (`int`): Token ID to use for padding. - completion_only_loss (`bool`, *optional*, defaults to `True`): - When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens - that are no in the completion. pad_to_multiple_of (`int`, *optional*): If set, the sequences will be padded to a multiple of this value. 
return_tensors (`str`, *optional*, defaults to `"pt"`): @@ -160,24 +148,10 @@ class DataCollatorForPreference(DataCollatorMixin): [ 1, 1, 0]]), 'labels': tensor([[ 1, 2, 3], [ 4, 5, -100]])} - - >>> # With completion mask - >>> examples = [ - ... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, - ... {"input_ids": [4, 5], "completion_mask": [0, 1]}, - ... ] - >>> collator(examples) - {'input_ids': tensor([[ 1, 2, 3], - [ 4, 5, 0]]), - 'attention_mask': tensor([[ 1, 1, 1], - [ 1, 1, 0]]), - 'labels': tensor([[-100, 2, 3], - [-100, 5, -100]])} ``` """ pad_token_id: int - completion_only_loss: bool = True pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" @@ -190,10 +164,6 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d labels = [torch.tensor(example["labels"]) for example in examples] else: labels = [torch.tensor(example["input_ids"]) for example in examples] - if self.completion_only_loss and "completion_mask" in examples[0]: - completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] - if "assistant_masks" in examples[0]: - assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples] output = {} @@ -210,16 +180,6 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d output["labels"] = pad( labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of ) - if self.completion_only_loss and "completion_mask" in examples[0]: - completion_mask = pad( - completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of - ) - output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion - if "assistant_masks" in examples[0]: - assistant_masks = pad( - assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of - ) - output["labels"][assistant_masks == 0] = -100 return output @@ -309,9 +269,6 @@ class RewardTrainer(Trainer): Note that the labels (second parameter) will be `None` if the dataset does not have them. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - formatting_func (`Callable`, *optional*): - Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly - converts the dataset into a [language modeling](#language-modeling) type. """ _tag_names = ["trl", "reward"] @@ -331,7 +288,6 @@ def __init__( optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, peft_config: Optional["PeftConfig"] = None, - formatting_func: Optional[Callable[[dict], str]] = None, ): # Args if args is None: @@ -428,14 +384,6 @@ def __init__( self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) # Data collator - # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format - # is prompt-completion, and False if the dataset format is language modeling. 
- dataset_sample = next(iter(train_dataset)) - if args.completion_only_loss is None: - self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample - else: - self.completion_only_loss = args.completion_only_loss - if data_collator is None: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. @@ -449,38 +397,24 @@ def __init__( ) data_collator = DataCollatorForPreference( pad_token_id=pad_token_id, - completion_only_loss=self.completion_only_loss, pad_to_multiple_of=args.pad_to_multiple_of, ) - if args.assistant_only_loss and not is_conversational(dataset_sample): - raise ValueError( - "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only " - "supported for conversational datasets." - ) - # Dataset # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`. skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get( "skip_prepare_dataset", False ) if not skip_prepare_dataset: - if self.completion_only_loss and formatting_func: - raise ValueError( - "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " - "Using a formatter converts the dataset to a language modeling type, conflicting with " - "completion-only loss. To resolve this, apply your formatting function before passing the " - "dataset, or disable `completion_only_loss` in `RewardConfig`." - ) - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, formatting_func, "train") + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") if eval_dataset is not None: if isinstance(eval_dataset, dict): eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, formatting_func, key) + key: self._prepare_dataset(dataset, processing_class, args, key) for key, dataset in eval_dataset.items() } else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, formatting_func, "eval") + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} @@ -524,7 +458,6 @@ def _prepare_dataset( dataset: Union[Dataset, IterableDataset], processing_class: PreTrainedTokenizerBase, args: RewardConfig, - formatting_func: Optional[Callable[[dict], str]], dataset_name: str, ) -> Union[Dataset, IterableDataset]: # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from @@ -532,159 +465,87 @@ def _prepare_dataset( if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` dataset = dataset.with_transform(remove_none_values) - # If the dataset is already preprocessed (tokenized), skip the processing steps. - column_names = list(next(iter(dataset)).keys()) - is_processed = "input_ids" in column_names - # Build the kwargs for the `map` function map_kwargs = {} if isinstance(dataset, Dataset): # IterableDataset does not support num_proc map_kwargs["num_proc"] = args.dataset_num_proc with PartialState().main_process_first(): - # Apply the formatting function if any - if formatting_func is not None and is_processed: - logger.warning( - "You passed a dataset that is already processed (contains an `input_ids` field) together with a " - "formatting function. Therefore `formatting_func` will be ignored. 
Either remove the " - "`formatting_func` or pass a dataset that is not already processed.", + # Convert the dataset to ChatML if needed + first_example = next(iter(dataset)) + if is_conversational_from_value(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = next(iter(dataset)).keys() + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, ) - if formatting_func is not None and not is_processed: + # Add EOS token to the end of the sequences if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" - - def _func(example): - return {"text": formatting_func(example)} - - dataset = dataset.map(_func, batched=False, **map_kwargs) - - if not is_processed: - # Convert the dataset to ChatML if needed - first_example = next(iter(dataset)) - if is_conversational_from_value(first_example): - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" - column_names = next(iter(dataset)).keys() - dataset = dataset.map( - maybe_convert_to_chatml, - remove_columns="conversations" if "conversations" in column_names else None, - **map_kwargs, - ) + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - # Apply the chat template if needed - first_example = next(iter(dataset)) - if not is_conversational(first_example): - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - - def add_eos(example, eos_token): - if "text" in example and not example["text"].endswith(eos_token): # language modeling case - example["text"] = example["text"] + eos_token - elif "completion" in example and not example["completion"].endswith(eos_token): - example["completion"] = example["completion"] + eos_token - return example - - dataset = dataset.map( - add_eos, - fn_kwargs={"eos_token": processing_class.eos_token}, - remove_columns="messages" if "messages" in column_names else None, # renamed to "text" - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - def tokenize(example, processing_class, dataset_text_field, assistant_only_loss): - if "prompt" in example: # prompt-completion case - output = {} - if is_conversational(example): - prompt_ids = processing_class.apply_chat_template( - example["prompt"], - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - prompt_completion_processed = processing_class.apply_chat_template( - example["prompt"] + example["completion"], - return_dict=True, - return_assistant_tokens_mask=assistant_only_loss, - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - prompt_completion_ids = prompt_completion_processed["input_ids"] - if "assistant_masks" in prompt_completion_processed: - output["assistant_masks"] = prompt_completion_processed["assistant_masks"] - else: - prompt_ids = processing_class(text=example["prompt"])["input_ids"] - prompt_completion_ids = processing_class(text=example["prompt"] + 
example["completion"])[ - "input_ids" - ] - - # Check if the tokenized prompt starts with the tokenized prompt+completion - if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: - logger.warning( - "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " - "This may be due to unexpected tokenizer behavior, whitespace issues, or special " - "token handling. Verify that the tokenizer is processing text consistently." - ) - - # Create a completion mask - completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) - output["input_ids"] = prompt_completion_ids - output["completion_mask"] = completion_mask - - else: # language modeling case - if is_conversational(example): - processed = processing_class.apply_chat_template( - example["messages"], - return_dict=True, - return_assistant_tokens_mask=assistant_only_loss, - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - if "assistant_masks" in processed and 1 not in processed["assistant_masks"]: - raise RuntimeError( - "You're using `assistant_only_loss=True`, but at least one example has no " - "assistant tokens. This usually means the tokenizer's chat template doesn't " - "generate assistant masks — it may be missing the `{% generation %}` keyword. Please " - "check the template and ensure it's correctly configured to support assistant " - "masking." - ) - output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed} - else: - output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]} - return output + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if "rejected" in example and not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + return example dataset = dataset.map( - tokenize, - fn_kwargs={ - "processing_class": processing_class, - "dataset_text_field": args.dataset_text_field, - "assistant_only_loss": args.assistant_only_loss, - }, + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, **map_kwargs, ) + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(example, processing_class): + if "prompt" in example: # explicit prompt case + example["chosen"] = example["prompt"] + example["chosen"] + example["rejected"] = example["prompt"] + example["rejected"] + + if is_conversational(example): + chosen_input_ids = processing_class.apply_chat_template( + example["chosen"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + rejected_input_ids = processing_class.apply_chat_template( + example["chosen"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} + else: + output = { + "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], + "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], + } + return output + + dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + # Truncate if args.max_length is not None: if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Truncating {dataset_name} dataset" dataset = truncate_dataset(dataset, args.max_length, 
map_kwargs) - # For Liger kernel, ensure only the essential columns - if args.use_liger_kernel: - collator_expected_keys = {"input_ids", "completion_mask", "assistant_masks"} - dataset = dataset.select_columns(collator_expected_keys.intersection(dataset.column_names)) return dataset def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" - # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the - # dataset. So we need to override the default signature columns to include "completion_mask" as well. + # and "attention_mask"). if self._signature_columns is None: - self._signature_columns = ["input_ids", "labels", "completion_mask", "assistant_masks"] + self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "labels"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -703,64 +564,62 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # Compute entropy - if not self.args.use_liger_kernel: # liger doesn't return logits - with torch.no_grad(): - per_token_entropy = entropy_from_logits(outputs.logits) - attention_mask = inputs["attention_mask"] - # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). - virtual_attention_mask = torch.ones( - attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device - ) - attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) - entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() - entropy = self.accelerator.gather_for_metrics(entropy).mean().item() - self._metrics[mode]["entropy"].append(entropy) + with torch.no_grad(): + per_token_entropy = entropy_from_logits(outputs.logits) + attention_mask = inputs["attention_mask"] + # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). + virtual_attention_mask = torch.ones( + attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device + ) + attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) + entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) if mode == "train": num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() self._total_train_tokens += num_tokens_in_batch self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - # Compute token accuracy if we have labels and if the model is not using Liger (no logits) - if not self.args.use_liger_kernel: - with torch.no_grad(): - if "shift_labels" in inputs: - # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: - # - The first discarded token from inputs["labels"] actually belongs to process n-1 - # - The last logits require the label from process n+1 - shift_logits = outputs.logits.contiguous() - shift_labels = inputs["shift_labels"] - else: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do - # not correspond to actual input labels. 
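# --- Illustrative sketch (not part of this patch) -----------------------------------
# Shows what the new `tokenize` step in `_prepare_dataset` produces for an
# explicit-prompt, standard-format preference example. The checkpoint name below is an
# arbitrary assumption; any Hugging Face tokenizer behaves the same way here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

example = {"prompt": "The capital of France is", "chosen": " Paris.", "rejected": " Lyon."}

# Explicit-prompt case: the prompt is prepended to both candidate completions.
chosen_text = example["prompt"] + example["chosen"]
rejected_text = example["prompt"] + example["rejected"]

row = {
    "chosen_input_ids": tokenizer(text=chosen_text)["input_ids"],
    # The rejected side is tokenized from the rejected text, not the chosen one.
    "rejected_input_ids": tokenizer(text=rejected_text)["input_ids"],
}
print(row)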
- shift_logits = shift_logits[:, self.num_virtual_tokens :, :] - - # Get predictions - predictions = shift_logits.argmax(dim=-1) - - # Create mask for non-padding tokens (assuming ignore_index is -100) - mask = shift_labels != -100 - - # Calculate accuracy only on non-padding tokens - correct_predictions = (predictions == shift_labels) & mask - total_tokens = mask.sum() - correct_tokens = correct_predictions.sum() - - # Gather the correct_tokens and total_tokens across all processes - correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) - total_tokens = self.accelerator.gather_for_metrics(total_tokens) - - # Compute the mean token accuracy and log it - total_sum = total_tokens.sum() - accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 - self._metrics[mode]["mean_token_accuracy"].append(accuracy) - if self.aux_loss_enabled: - aux_loss = outputs.aux_loss - aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() - self._metrics[mode]["aux_loss"].append(aux_loss) + # Compute token accuracy if we have labels + with torch.no_grad(): + if "shift_labels" in inputs: + # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: + # - The first discarded token from inputs["labels"] actually belongs to process n-1 + # - The last logits require the label from process n+1 + shift_logits = outputs.logits.contiguous() + shift_labels = inputs["shift_labels"] + else: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do + # not correspond to actual input labels. + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] + + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + if self.aux_loss_enabled: + aux_loss = outputs.aux_loss + aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() + self._metrics[mode]["aux_loss"].append(aux_loss) return (loss, outputs) if return_outputs else loss From 59b955a7e5baf6f9094b135327d6cc37f7d36913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 23:52:53 +0000 Subject: [PATCH 09/62] update config --- trl/trainer/reward_config.py | 67 ++++-------------------------------- 1 file changed, 7 insertions(+), 60 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 119046401c2..88050ba8c1a 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -47,12 +47,6 @@ class may differ from those in [`~transformers.TrainingArguments`]. 
> Parameters that control the data preprocessing - dataset_text_field (`str`, *optional*, defaults to `"text"`): - Name of the column that contains text data in the dataset. - dataset_kwargs (`dict[str, Any]`, *optional*): - Dictionary of optional keyword arguments for the dataset preparation. The only supported key is - `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` - regardless of the provided value, since preprocessing is done on the fly. dataset_num_proc (`int`, *optional*): Number of processes to use for processing the dataset. eos_token (`str`, *optional*): @@ -69,19 +63,9 @@ class may differ from those in [`~transformers.TrainingArguments`]. > Parameters that control the training - completion_only_loss (`bool`, *optional*): - Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed - only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If - `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: - loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full - sequence for [language modeling](#language-modeling) datasets. - assistant_only_loss (`bool`, *optional*, defaults to `False`): - Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only - on the assistant responses, which is supported only for [conversational](#conversational) datasets. If - `False`, loss is computed on the entire sequence. - loss_type (`str`, *optional*, defaults to `"nll"`): - Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic - Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). + center_rewards_coefficient (`float`, *optional*): + Coefficient to incentivize the reward model to output mean-zero rewards (proposed by + https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. activation_offloading (`bool`, *optional*, defaults to `False`): Whether to offload the activations to the CPU. """ @@ -90,7 +74,7 @@ class may differ from those in [`~transformers.TrainingArguments`]. # Parameters whose default values are overridden from TrainingArguments learning_rate: float = field( - default=2e-5, + default=1e-4, metadata={"help": "The initial learning rate for AdamW."}, ) logging_steps: float = field( @@ -136,19 +120,6 @@ class may differ from those in [`~transformers.TrainingArguments`]. ) # Parameters that control the data preprocessing - dataset_text_field: str = field( - default="text", - metadata={"help": "Name of the column that contains text data in the dataset."}, - ) - dataset_kwargs: Optional[dict[str, Any]] = field( - default=None, - metadata={ - "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " - "`skip_prepare_dataset`. If the model is a VLM, `skip_prepare_dataset` value is ignored. When the model " - "is a VLM, `skip_prepare_dataset` is automatically treated as `True` regardless of the provided value, " - "since preprocessing is done on the fly." - }, - ) dataset_num_proc: Optional[int] = field( default=None, metadata={"help": "Number of processes to use for processing the dataset."}, @@ -179,35 +150,11 @@ class may differ from those in [`~transformers.TrainingArguments`]. 
) # Parameters that control the training - completion_only_loss: Optional[bool] = field( + center_rewards_coefficient: Optional[float] = field( default=None, metadata={ - "help": ( - "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is " - "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, " - "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: " - "loss is computed on the completion for prompt-completion datasets, and on the full sequence for " - "language modeling datasets." - ) - }, - ) - assistant_only_loss: bool = field( - default=False, - metadata={ - "help": ( - "Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is " - "computed only on the assistant responses, which is supported only for conversational datasets. If `False`, " - "loss is computed on the entire sequence." - ) - }, - ) - loss_type: str = field( - default="nll", - metadata={ - "help": ( - 'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` ' - "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)." - ) + "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " + "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." }, ) activation_offloading: bool = field( From e718f38beb91ac7878968e6af60b49bcc5903055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 15 Sep 2025 23:54:51 +0000 Subject: [PATCH 10/62] now it looks good --- trl/trainer/reward_trainer.py | 160 +++++++++++++--------------------- 1 file changed, 63 insertions(+), 97 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 5f57e6cbf4a..577784b41dc 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -46,7 +46,7 @@ from ..data_utils import is_conversational, is_conversational_from_value, maybe_convert_to_chatml, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig -from .utils import entropy_from_logits, generate_model_card, get_comet_experiment_url, pad +from .utils import generate_model_card, get_comet_experiment_url, pad if is_peft_available(): @@ -119,13 +119,14 @@ def remove_none_values(example: TListOrMapping) -> TListOrMapping: @dataclass class DataCollatorForPreference(DataCollatorMixin): """ - Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch. - This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. - The collator returns a dictionary containing the following keys: - - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. + This collator expects each example in the input list to be a dictionary containing at least the + `"chosen_input_ids"` and `"rejected_input_ids"` keys. The collator returns a dictionary containing the following + keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch + corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`. 
- `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. - - `"labels"`: Tensor of labels, padded to the maximum length of the batch. Args: pad_token_id (`int`): @@ -140,14 +141,19 @@ class DataCollatorForPreference(DataCollatorMixin): >>> from trl.trainer.reward_trainer import DataCollatorForPreference >>> collator = DataCollatorForPreference(pad_token_id=0) - >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ... ] >>> collator(examples) - {'input_ids': tensor([[ 1, 2, 3], - [ 4, 5, 0]]), - 'attention_mask': tensor([[ 1, 1, 1], - [ 1, 1, 0]]), - 'labels': tensor([[ 1, 2, 3], - [ 4, 5, -100]])} + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]])} ``` """ @@ -157,13 +163,10 @@ class DataCollatorForPreference(DataCollatorMixin): def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: # Convert to tensor - input_ids = [torch.tensor(example["input_ids"]) for example in examples] - attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] - - if "labels" in examples[0]: - labels = [torch.tensor(example["labels"]) for example in examples] - else: - labels = [torch.tensor(example["input_ids"]) for example in examples] + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + input_ids = chosen_input_ids + rejected_input_ids + attention_mask = [torch.ones_like(ids) for ids in input_ids] output = {} @@ -175,10 +178,10 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d pad_to_multiple_of=self.pad_to_multiple_of, ) output["attention_mask"] = pad( - attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of - ) - output["labels"] = pad( - labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + attention_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, ) return output @@ -311,7 +314,7 @@ def __init__( f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." ) with suppress_from_pretrained_warning(transformers.modeling_utils.logger): - model = AutoModelForSequenceClassification.from_pretrained(model_id, **model_init_kwargs) + model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs) else: model_id = model.config._name_or_path if args.model_init_kwargs is not None: @@ -401,20 +404,15 @@ def __init__( ) # Dataset - # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`. 
- skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get( - "skip_prepare_dataset", False - ) - if not skip_prepare_dataset: - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - if eval_dataset is not None: - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} @@ -545,7 +543,7 @@ def _set_signature_columns_if_needed(self): # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" # and "attention_mask"). if self._signature_columns is None: - self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "labels"] + self._signature_columns = ["chosen_input_ids", "rejected_input_ids"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -553,73 +551,41 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N """ mode = "train" if self.model.training else "eval" - # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used. - # This can be removed when this issue is fixed. - labels = inputs["labels"] - # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing inputs["use_cache"] = False - (loss, outputs) = super().compute_loss( - model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch - ) + outputs = model(**inputs) - # Compute entropy - with torch.no_grad(): - per_token_entropy = entropy_from_logits(outputs.logits) - attention_mask = inputs["attention_mask"] - # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1). 
- virtual_attention_mask = torch.ones( - attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device - ) - attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1) - entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() - entropy = self.accelerator.gather_for_metrics(entropy).mean().item() - self._metrics[mode]["entropy"].append(entropy) + # Split the rewards into chosen and rejected + rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2) + + # Calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if self.args.center_rewards_coefficient is not None: + loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) if mode == "train": num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() self._total_train_tokens += num_tokens_in_batch self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - # Compute token accuracy if we have labels + # Compute min, mean, max, accuracy and margin with torch.no_grad(): - if "shift_labels" in inputs: - # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: - # - The first discarded token from inputs["labels"] actually belongs to process n-1 - # - The last logits require the label from process n+1 - shift_logits = outputs.logits.contiguous() - shift_labels = inputs["shift_labels"] - else: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do - # not correspond to actual input labels. 
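# --- Illustrative sketch (not part of this patch) -----------------------------------
# Standalone view of the pairwise loss computed above: the batch's scalar rewards are
# split into chosen/rejected halves and scored with -log(sigmoid(chosen - rejected)),
# optionally shifted by a per-pair margin and regularized toward mean-zero rewards via
# `center_rewards_coefficient`. The numbers below are made up for illustration.
import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, 0.3, -0.5, 0.1])  # first half = chosen, second half = rejected
rewards_chosen, rewards_rejected = torch.chunk(logits, chunks=2)

margin = torch.tensor([0.0, 0.5])  # optional per-pair margin; treated as zero when absent
center_rewards_coefficient = 0.01  # recommended value from the config docstring

loss = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
loss = loss + center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
print(loss)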
- shift_logits = shift_logits[:, self.num_virtual_tokens :, :] - - # Get predictions - predictions = shift_logits.argmax(dim=-1) - - # Create mask for non-padding tokens (assuming ignore_index is -100) - mask = shift_labels != -100 - - # Calculate accuracy only on non-padding tokens - correct_predictions = (predictions == shift_labels) & mask - total_tokens = mask.sum() - correct_tokens = correct_predictions.sum() - - # Gather the correct_tokens and total_tokens across all processes - correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) - total_tokens = self.accelerator.gather_for_metrics(total_tokens) - - # Compute the mean token accuracy and log it - total_sum = total_tokens.sum() - accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 - self._metrics[mode]["mean_token_accuracy"].append(accuracy) - if self.aux_loss_enabled: - aux_loss = outputs.aux_loss - aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() - self._metrics[mode]["aux_loss"].append(aux_loss) + all_rewards = self.accelerator.gather(outputs.logits) + self._metrics[mode]["min_reward"].append(all_rewards.min()) + self._metrics[mode]["mean_reward"].append(all_rewards.mean()) + self._metrics[mode]["max_reward"].append(all_rewards.max()) + + mean_accuracy = (rewards_chosen > rewards_rejected).float().mean() + mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean() + self._metrics[mode]["accuracy"].append(mean_accuracy) + + mean_margin = (rewards_chosen - rewards_rejected).mean() + mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean() + self._metrics[mode]["margin"].append(mean_margin) return (loss, outputs) if return_outputs else loss From 4a8e579d8cc5684a2e95ea016fc605dfe0918f0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 00:18:02 +0000 Subject: [PATCH 11/62] template for reward --- trl/templates/rm_card.md | 55 +++++++++++++++++++++++++++++++++++ trl/trainer/reward_trainer.py | 1 + trl/trainer/utils.py | 6 +++- 3 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 trl/templates/rm_card.md diff --git a/trl/templates/rm_card.md b/trl/templates/rm_card.md new file mode 100644 index 00000000000..685ed007bd5 --- /dev/null +++ b/trl/templates/rm_card.md @@ -0,0 +1,55 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_name }} + +This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +text = "The capital of France is Paris." +rewarder = pipeline(model="{{ hub_model_id }}", device="cuda") +output = rewarder(text)[0] +print(output["score"]) +``` + +## Training procedure + +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} + +This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 
+ +### Framework versions + +- TRL: {{ trl_version }} +- Transformers: {{ transformers_version }} +- Pytorch: {{ pytorch_version }} +- Datasets: {{ datasets_version }} +- Tokenizers: {{ tokenizers_version }} + +## Citations + +{% if trainer_citation %}Cite {{ trainer_name }} as: + +```bibtex +{{ trainer_citation }} +```{% endif %} + +Cite TRL as: + +```bibtex +{% raw %}@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +}{% endraw %} +``` diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 577784b41dc..eea4a369167 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -666,6 +666,7 @@ def create_model_card( wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), trainer_name="Reward", + template_file="rm_model_card.md", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 0e4a6465673..78088549648 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1208,6 +1208,7 @@ def generate_model_card( wandb_url: Optional[str], trainer_name: str, trainer_citation: Optional[str] = None, + template_file: Optional[str] = None, paper_title: Optional[str] = None, paper_id: Optional[str] = None, comet_url: Optional[str] = None, @@ -1234,6 +1235,8 @@ def generate_model_card( Trainer name. trainer_citation (`str` or `None`, defaults to `None`): Trainer citation as a BibTeX entry. + template_file (`str` *optional*): + Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`. paper_title (`str` or `None`, defaults to `None`): Paper title. 
paper_id (`str` or `None`, defaults to `None`): @@ -1251,9 +1254,10 @@ def generate_model_card( model_name=model_name, tags=["generated_from_trainer", *tags], ) + template_file = template_file or "lm_model_card.md" card = ModelCard.from_template( card_data, - template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")), + template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")), base_model=base_model, model_name=model_name, hub_model_id=hub_model_id, From 156bff9f3c467539ee81c5e5aa5b36516aa394f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 02:42:31 +0000 Subject: [PATCH 12/62] new tiny model + fix tiny reward model --- scripts/generate_tiny_models.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 79693999395..0150e527a24 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -73,6 +73,7 @@ Qwen3ForSequenceClassification, Qwen3MoeConfig, Qwen3MoeForCausalLM, + Qwen3MoeForSequenceClassification, SmolVLMForConditionalGeneration, T5ForConditionalGeneration, ) @@ -240,15 +241,27 @@ def init_weights_tiny_model(model): ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) - config = config_class( - vocab_size=len(tokenizer.vocab), - hidden_size=8, - num_attention_heads=4, - num_key_value_heads=2, - num_hidden_layers=2, - intermediate_size=32, - num_labels=1, - ) + config = AutoConfig.from_pretrained(model_id) + config.hidden_size = 16 + config.num_attention_heads = 4 + config.num_key_value_heads = 2 + config.num_hidden_layers = 2 + config.num_labels = 1 + model = model_class(config) + push_to_hub(model, tokenizer, "tiny", suffix) + +# MoE Reward models +for model_id, config_class, model_class, suffix in [ + ("Qwen/Qwen3-30B-A3B", Qwen3MoeConfig, Qwen3MoeForSequenceClassification, None), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id) + config.hidden_size = 16 + config.num_attention_heads = 4 + config.num_hidden_layers = 2 + config.num_labels = 1 + config.num_experts = 4 + config.num_experts_per_tok = 2 model = model_class(config) push_to_hub(model, tokenizer, "tiny", suffix) From 978dad318dd91785d821cfd005a5bc3f988e03e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 02:42:48 +0000 Subject: [PATCH 13/62] fix template name --- trl/templates/{rm_card.md => rm_model_card.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename trl/templates/{rm_card.md => rm_model_card.md} (100%) diff --git a/trl/templates/rm_card.md b/trl/templates/rm_model_card.md similarity index 100% rename from trl/templates/rm_card.md rename to trl/templates/rm_model_card.md From 1478879e170e09417647e3d2a74f1a40bf374a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 02:43:48 +0000 Subject: [PATCH 14/62] rm promptencoder; rm non-chatml support ; fix padding token; fix processing + logging --- trl/trainer/reward_trainer.py | 59 +++++++++++++---------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index eea4a369167..74fa79c9b18 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -43,7 +43,7 @@ from transformers.trainer_utils import 
EvalPrediction from transformers.utils import is_peft_available -from ..data_utils import is_conversational, is_conversational_from_value, maybe_convert_to_chatml, truncate_dataset +from ..data_utils import is_conversational, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig from .utils import generate_model_card, get_comet_experiment_url, pad @@ -274,7 +274,7 @@ class RewardTrainer(Trainer): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. """ - _tag_names = ["trl", "reward"] + _tag_names = ["trl", "reward-trainer"] def __init__( self, @@ -376,28 +376,25 @@ def __init__( else: peft_config.modules_to_save.append("lm_head") - # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the - # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. - self.num_virtual_tokens = 0 - if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) - if model.active_adapter in model.peft_config: - peft_model_config = model.peft_config[model.active_adapter] - self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) + + # Pad token (needed for SequenceClassification models) + # If not provided, use the one from the processing class or the eos token if the processing class does not have + # a pad token. + pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token + pad_token_id = processing_class.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + model.config.pad_token_id = pad_token_id + processing_class.pad_token_id = pad_token_id # Data collator if data_collator is None: - # Get the pad token: if not provided, use the one from the processing class or the eos token - # if the processing class does not have a pad token. - pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token - pad_token_id = processing_class.convert_tokens_to_ids(pad_token) - if pad_token_id is None: - raise ValueError( - f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " - "in the vocabulary before using it as a padding token." 
- ) data_collator = DataCollatorForPreference( pad_token_id=pad_token_id, pad_to_multiple_of=args.pad_to_multiple_of, @@ -469,18 +466,6 @@ def _prepare_dataset( map_kwargs["num_proc"] = args.dataset_num_proc with PartialState().main_process_first(): - # Convert the dataset to ChatML if needed - first_example = next(iter(dataset)) - if is_conversational_from_value(first_example): - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" - column_names = next(iter(dataset)).keys() - dataset = dataset.map( - maybe_convert_to_chatml, - remove_columns="conversations" if "conversations" in column_names else None, - **map_kwargs, - ) - # Add EOS token to the end of the sequences if needed first_example = next(iter(dataset)) if not is_conversational(first_example): @@ -516,7 +501,7 @@ def tokenize(example, processing_class): **example.get("chat_template_kwargs", {}), ) rejected_input_ids = processing_class.apply_chat_template( - example["chosen"], + example["rejected"], tools=example.get("tools"), **example.get("chat_template_kwargs", {}), ) @@ -575,17 +560,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Compute min, mean, max, accuracy and margin with torch.no_grad(): all_rewards = self.accelerator.gather(outputs.logits) - self._metrics[mode]["min_reward"].append(all_rewards.min()) - self._metrics[mode]["mean_reward"].append(all_rewards.mean()) - self._metrics[mode]["max_reward"].append(all_rewards.max()) + self._metrics[mode]["min_reward"].append(all_rewards.min().item()) + self._metrics[mode]["mean_reward"].append(all_rewards.mean().item()) + self._metrics[mode]["max_reward"].append(all_rewards.max().item()) mean_accuracy = (rewards_chosen > rewards_rejected).float().mean() - mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean() + mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item() self._metrics[mode]["accuracy"].append(mean_accuracy) mean_margin = (rewards_chosen - rewards_rejected).mean() mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean() - self._metrics[mode]["margin"].append(mean_margin) + self._metrics[mode]["margin"].append(mean_margin.item()) return (loss, outputs) if return_outputs else loss From 2a9b8bd900d5596787ea42d1701b64348a16621f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 04:31:41 +0000 Subject: [PATCH 15/62] fix tiny models --- scripts/generate_tiny_models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 0150e527a24..b4193f98f3f 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -235,10 +235,10 @@ def init_weights_tiny_model(model): push_to_hub(model, tokenizer, "small") # Reward models -for model_id, config_class, model_class, suffix in [ - ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"), - ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"), - ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None), +for model_id, model_class, suffix in [ + ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"), + ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"), + ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) config = 
AutoConfig.from_pretrained(model_id) @@ -251,8 +251,8 @@ def init_weights_tiny_model(model): push_to_hub(model, tokenizer, "tiny", suffix) # MoE Reward models -for model_id, config_class, model_class, suffix in [ - ("Qwen/Qwen3-30B-A3B", Qwen3MoeConfig, Qwen3MoeForSequenceClassification, None), +for model_id, model_class, suffix in [ + ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) config = AutoConfig.from_pretrained(model_id) From 67de45a9000f9ec4dad38748bdc7178c18b4d6aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 04:31:51 +0000 Subject: [PATCH 16/62] fix eval --- trl/trainer/reward_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 74fa79c9b18..6e2d7bf5f2b 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -436,6 +436,10 @@ def __init__( preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty. + self.can_return_loss = True + self.label_names = [] + # Initialize activation offloading context if self.args.activation_offloading: self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) From f92041c07089821239661f60841ca2e276e4d079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 04:32:02 +0000 Subject: [PATCH 17/62] test!!! --- tests/test_reward_trainer.py | 758 ++++++++++++++++++++++++++++------- 1 file changed, 624 insertions(+), 134 deletions(-) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index a92c2c8cb0f..c1fd6566d45 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -12,217 +12,707 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pathlib +import unittest import torch -from datasets import Dataset, load_dataset +from datasets import load_dataset +from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template -from trl.trainer.reward_trainer import _tokenize +from trl import RewardConfig, RewardTrainer +from trl.trainer.reward_trainer import DataCollatorForPreference from .testing_utils import TrlTestCase if is_peft_available(): - from peft import LoraConfig, TaskType + from peft import LoraConfig, PeftModel, get_peft_model + + +class TestDataCollatorForPreference(TrlTestCase): + def test_basic_padding(self): + """Test basic padding functionality without completion masks.""" + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + + def test_pad_to_multiple_of(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]]) + ) + + def test_single_example(self): + """Test collator with a single example.""" + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + + def test_different_pad_token_id(self): + """Test with different pad token ID.""" + collator = DataCollatorForPreference(pad_token_id=999) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) class RewardTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() - self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id) - self.model.config.pad_token_id = self.tokenizer.pad_token_id - - def test_preprocessing_conversational(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + @parameterized.expand( + [ + 
("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",), + ("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification",), + ("trl-internal-testing/tiny-LlamaForSequenceClassification-3.2",), + ] + ) + def test_train(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @parameterized.expand( + [ + ("standard_preference",), + ("conversational_preference",), + ("standard_implicit_prompt_preference",), + ("conversational_implicit_prompt_preference",), + ] + ) + def test_train_dataset_types(self, config_name): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_model(self): + # Instantiate the model + model = AutoModelForSequenceClassification.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" ) - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) - self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:]) - def test_preprocessing_standard(self): - # No chat template, so we load a fresh tokenizer - tokenizer = AutoTokenizer.from_pretrained(self.model_id) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the 
training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_from_causal_lm(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
         )
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
-        self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
 
-    def test_train_full(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_model_dtype(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            model_init_kwargs={"dtype": torch.float16},
+            learning_rate=0.1,
+            report_to="none",
+        )
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
         )
+
+        # Save the initial parameters to compare them later
         previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
 
+        # Check that the training loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-        # Check that the parameters have changed
+
+        # Check the params have changed
         for n, param in previous_trainable_params.items():
+            # For some reason model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
+            # locally. 
We ignore this parameter for now + if "layernorm" in n: + continue new_param = trainer.model.get_parameter(n) - if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + # Check the torch dtype + self.assertEqual(new_param.dtype, torch.float16) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_dense_with_peft_config(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") - def test_train_full_pretokenized(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), ) + + # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model trainer.train() + # Check that the training loss is not None self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + + # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") @require_peft - def test_train_lora(self): - peft_config = LoraConfig( - task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + def test_train_moe_with_peft_config(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( - 
model=self.model, + model=model_id, args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset, - peft_config=peft_config, + train_dataset=dataset, + peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]), ) - previous_trainable_params = {} - previous_non_trainable_params = {} - # due to a change in the way the modules to save are dealt in PEFT. - trainable_params_name = ["lora", "modules_to_save"] + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_peft_model(self): + # Get the base model + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + + # Get the base model parameter names + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Turn the model into a peft model + lora_config = LoraConfig() + model = get_peft_model(model, lora_config) - # check gradients are not None - for n, param in trainer.model.named_parameters(): - if any(t in n for t in trainable_params_name): - previous_trainable_params[n] = param.clone() - else: - previous_non_trainable_params[n] = param.clone() + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + # Train the model trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_dense_with_peft_config_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = 
"trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + + trainer = RewardTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the non trainable parameters have not changed - for n, param in previous_non_trainable_params.items(): + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") @require_peft - def test_train_lora_pretokenized(self): - peft_config = LoraConfig( - task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + def test_train_moe_with_peft_config_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + trainer = RewardTrainer( - model=self.model, + model=model_id, args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset, - peft_config=peft_config, + train_dataset=dataset, + peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]), ) - previous_trainable_params = {} - previous_non_trainable_params = {} - # due to a change in the way the modules to save are dealt in PEFT. 
- trainable_params_name = ["lora", "modules_to_save"] + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - # check gradients are not None - for n, param in trainer.model.named_parameters(): - if any(t in n for t in trainable_params_name): - previous_trainable_params[n] = param.clone() - else: - previous_non_trainable_params[n] = param.clone() + # Train the model + trainer.train() + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_with_peft_model_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + model = get_peft_model(model, LoraConfig()) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Verify model is a PeftModel + self.assertIsInstance(trainer.model, PeftModel) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_iterable_dataset(self): + # Get the dataset + dataset = load_dataset( + "trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train", streaming=True + ) - # Check that the non trainable parameters have not changed - for n, param in previous_non_trainable_params.items(): + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + 
train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") - def test_margin(self): - dummy_dataset_dict = { - "input_ids_chosen": [ - torch.LongTensor([0, 1, 2]), - ], - "attention_mask_chosen": [ - torch.LongTensor([1, 1, 1]), - ], - "input_ids_rejected": [ - torch.LongTensor([0, 2]), - ], - "attention_mask_rejected": [ - torch.LongTensor([1, 1]), - ], - "margin": [ - torch.FloatTensor([1.0]), - ], - } - dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + def test_train_with_chat_template_kwargs(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5") + # The following template is a simplified version of the Qwen chat template, where an additional argument + # `role_capital` is used to control the capitalization of roles. + tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n\\n" + message.content + "\\n" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}' + + dataset.add_column("chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))]) + + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_set_chat_template_from_model(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none") + # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification", + args=training_args, + train_dataset=dataset, ) - batch = [dummy_dataset[0]] - batch = trainer.data_collator(batch) - batch = {k: v.to(trainer.model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True) + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - l_val = -torch.nn.functional.logsigmoid( - outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] - ).mean() + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") - self.assertLess(abs(loss - l_val), 1e-6) + def test_train_with_set_chat_template_from_path(self): + # Get the dataset + dataset = 
load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"), + report_to="none", + ) + # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default + trainer = RewardTrainer( + model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - def test_tags(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + # Check that the template saved in the output directory is the same as the one used for training + template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" + self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + + with open(template_path) as f: + template_content = f.read() + with open(training_args.chat_template_path) as f: + original_template_content = f.read() + self.assertEqual( + template_content, original_template_content, "Chat template content does not match the original" + ) + + @unittest.skip("Skipping until we have a dataset with tool calls") + def test_train_toolcall_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/toolcall", split="train") + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_eval(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], ) - self.assertEqual(trainer.model.model_tags, trainer._tag_names) + + # Train the model + trainer.train() + + # Check that the eval loss is not None + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + def 
test_train_with_multiple_eval_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset={"data1": dataset["test"], "data2": dataset["test"]}, + ) + # Train the model + trainer.train() + + # Check that the eval losses are not None + self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) + self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + + def test_train_with_gradient_checkpointing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_tag_added(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + ) + + for tag in ["reward-trainer", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + @require_peft + def test_tag_added_peft(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + for tag in ["reward-trainer", "trl"]: + self.assertIn(tag, trainer.model.model_tags) From fd4b0a04b920bf83869e88debd84e85f322ad98d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 15:56:25 +0000 Subject: [PATCH 18/62] tiny GPTNoeX --- scripts/generate_tiny_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b4193f98f3f..fb4971b15f4 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -236,6 +236,7 @@ def init_weights_tiny_model(model): # Reward models for model_id, model_class, suffix in [ + ("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None), ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"), ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"), ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None), From 59e75f3c3aedcd6f7815f95808de58acbd704f9b Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 15:56:39 +0000 Subject: [PATCH 19/62] fix peft target modules --- tests/test_reward_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index c1fd6566d45..79c226c53c5 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -299,7 +299,7 @@ def test_train_moe_with_peft_config(self): model=model_id, args=training_args, train_dataset=dataset, - peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]), + peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]), ) # Save the initial parameters to compare them later @@ -410,7 +410,7 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self): model=model_id, args=training_args, train_dataset=dataset, - peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]), + peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]), ) # Save the initial parameters to compare them later From c43b6334d12f5c3bb9b47b66a24828ca474466d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 15:56:54 +0000 Subject: [PATCH 20/62] add indication peft_config --- trl/trainer/reward_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 6e2d7bf5f2b..615164d22f4 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -271,7 +271,9 @@ class RewardTrainer(Trainer): Note that the labels (second parameter) will be `None` if the dataset does not have them. peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded + model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration + to ensure that the reward head is properly trained. 
""" _tag_names = ["trl", "reward-trainer"] From 09f5bffb7d77a4c0b94d622ac47bf30117fc6707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 16:01:44 +0000 Subject: [PATCH 21/62] move `remove_none_values` --- trl/trainer/reward_trainer.py | 40 ++--------------------------------- trl/trainer/sft_trainer.py | 38 ++------------------------------- trl/trainer/utils.py | 39 ++++++++++++++++++++++++++++++++-- 3 files changed, 41 insertions(+), 76 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 615164d22f4..417730238c5 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -17,11 +17,10 @@ import os import re from collections import defaultdict -from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -46,7 +45,7 @@ from ..data_utils import is_conversational, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig -from .utils import generate_model_card, get_comet_experiment_url, pad +from .utils import generate_model_card, get_comet_experiment_url, pad, remove_none_values if is_peft_available(): @@ -81,41 +80,6 @@ def filter(self, record: logging.LogRecord) -> bool: logger.removeFilter(f) -TListOrMapping = TypeVar("TListOrMapping", list, Mapping) - - -def remove_none_values(example: TListOrMapping) -> TListOrMapping: - """ - Recursively removes entries with `None` values from a nested structure (list or dictionary). - - Args: - example (`list` or `Mapping`): - Input nested structure (list or dictionary) from which to remove `None`. - - Example: - ```python - >>> [ - ... { - ... "a": {"aa": None, "ab": 1}, - ... "b": "my_string", - ... } - ... ] - >>> remove_none_values(example) - [{'a': {'ab': 1}, 'b': 'my_string'}] - ``` - """ - if isinstance(example, list): - return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] - elif isinstance(example, Mapping): - return { - key: remove_none_values(value) if isinstance(value, (dict, list)) else value - for key, value in example.items() - if value is not None - } - else: - raise TypeError("Input must be a list or a dictionary.") - - @dataclass class DataCollatorForPreference(DataCollatorMixin): """ diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index bdaeedfdd2f..c214db67c9b 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -15,10 +15,9 @@ import contextlib import os from collections import defaultdict -from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -60,6 +59,7 @@ generate_model_card, get_comet_experiment_url, pad, + remove_none_values, selective_log_softmax, ) @@ -72,40 +72,6 @@ logger = logging.get_logger(__name__) -TListOrMapping = TypeVar("TListOrMapping", list, Mapping) - - -def remove_none_values(example: TListOrMapping) -> TListOrMapping: - """ - Recursively removes entries with `None` values from a nested structure (list or dictionary). 
- - Args: - example (`list` or `Mapping`): - Input nested structure (list or dictionary) from which to remove `None`. - - Example: - ```python - >>> [ - ... { - ... "a": {"aa": None, "ab": 1}, - ... "b": "my_string", - ... } - ... ] - >>> remove_none_values(example) - [{'a': {'ab': 1}, 'b': 'my_string'}] - ``` - """ - if isinstance(example, list): - return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] - elif isinstance(example, Mapping): - return { - key: remove_none_values(value) if isinstance(value, (dict, list)) else value - for key, value in example.items() - if value is not None - } - else: - raise TypeError("Input must be a list or a dictionary.") - @dataclass class DataCollatorForLanguageModeling(DataCollatorMixin): diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 78088549648..3c45729f84c 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -16,10 +16,10 @@ import importlib.resources as pkg_resources import json import random -from collections.abc import Sequence, Sized +from collections.abc import Mapping, Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional, TypeVar, Union import numpy as np import pandas as pd @@ -1868,3 +1868,38 @@ def process_sequence(ids, mask): truncated_mask.append(new_mask) return torch.stack(truncated_seq), torch.stack(truncated_mask) + + +TListOrMapping = TypeVar("TListOrMapping", list, Mapping) + + +def remove_none_values(example: TListOrMapping) -> TListOrMapping: + """ + Recursively removes entries with `None` values from a nested structure (list or dictionary). + + Args: + example (`list` or `Mapping`): + Input nested structure (list or dictionary) from which to remove `None`. + + Example: + ```python + >>> [ + ... { + ... "a": {"aa": None, "ab": 1}, + ... "b": "my_string", + ... } + ... 
] + >>> remove_none_values(example) + [{'a': {'ab': 1}, 'b': 'my_string'}] + ``` + """ + if isinstance(example, list): + return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] + elif isinstance(example, Mapping): + return { + key: remove_none_values(value) if isinstance(value, (dict, list)) else value + for key, value in example.items() + if value is not None + } + else: + raise TypeError("Input must be a list or a dictionary.") From 37a15a898573d27b4b7f875ffef744d9094af423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 17:07:26 +0000 Subject: [PATCH 22/62] remove compute_loss_func; fix model docstring; allow is_processed; fix tiny gptnoex --- scripts/generate_tiny_models.py | 3 +- tests/test_reward_trainer.py | 33 +++++++++ trl/trainer/reward_trainer.py | 123 ++++++++++++++++---------------- 3 files changed, 95 insertions(+), 64 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index fb4971b15f4..b6d532488a0 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -43,6 +43,7 @@ GPT2LMHeadModel, GPTNeoXConfig, GPTNeoXForCausalLM, + GPTNeoXForSequenceClassification, GptOssConfig, GptOssForCausalLM, Idefics2Config, @@ -236,7 +237,7 @@ def init_weights_tiny_model(model): # Reward models for model_id, model_class, suffix in [ - ("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None), + ("EleutherAI/pythia-14m", GPTNeoXForSequenceClassification, None), ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"), ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"), ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None), diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index 79c226c53c5..ae28993ff60 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -466,6 +466,39 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + def test_train_with_pretokenized_data(self): + # Get the dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def tokenize_example(example): + return { + "chosen_input_ids": tokenizer(example["chosen"]).input_ids, + "rejected_input_ids": tokenizer(example["rejected"]).input_ids, + } + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"]) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + def 
test_train_with_iterable_dataset(self):
        # Get the dataset
        dataset = load_dataset(
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 417730238c5..aa06a52baa9 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -85,8 +85,8 @@ class DataCollatorForPreference(DataCollatorMixin):
     """
     Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
 
-    This collator expects each example in the input list to be a dictionary containing at least the
-    `"chosen_input_ids"` and `"rejected_input_ids"` keys. The collator returns a dictionary containing the following
+    This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and
+    `"rejected_input_ids"` keys. The collator returns a dictionary containing the following
     keys:
 
     - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
       corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`.
     - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
@@ -162,7 +162,7 @@ class RewardTrainer(Trainer):
    from datasets import load_dataset
    from trl import RewardTrainer
 
-    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
+    dataset = load_dataset("trl-lib/tldr-preference", split="train[:1%]")
 
    trainer = RewardTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
@@ -175,25 +175,24 @@ class RewardTrainer(Trainer):
            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
-              using `.from_pretrained` (where `` is derived from the model
-              config) with the keyword arguments in `args.model_init_kwargs`.
+              using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
+              `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object.
 
-            If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss
-            as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
        args ([`RewardConfig`], *optional*):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or
            `eval_dataset`. Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
-            Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
-            [prompt-completion](#prompt-completion) type. The format of the samples can be either:
+            Dataset to use for training. This trainer supports the [preference](#preference) type (both implicit and
+            explicit prompt). The format of the samples can be either:
 
            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
 
-            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
+            The trainer also supports processed datasets (tokenized) as long as they contain the `chosen_input_ids` and
+            `rejected_input_ids` fields.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): @@ -201,11 +200,6 @@ class RewardTrainer(Trainer): [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the default. - compute_loss_func (`Callable`, *optional*): - A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated - batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss - function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) - used by [`Trainer`]. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing @@ -250,7 +244,6 @@ def __init__( train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, processing_class: Optional[PreTrainedTokenizerBase] = None, - compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), @@ -394,7 +387,6 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - compute_loss_func=compute_loss_func, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, @@ -430,60 +422,65 @@ def _prepare_dataset( if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` dataset = dataset.with_transform(remove_none_values) + # If the dataset is already preprocessed (tokenized), skip the processing steps. 
+ column_names = list(next(iter(dataset)).keys()) + is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names + # Build the kwargs for the `map` function map_kwargs = {} if isinstance(dataset, Dataset): # IterableDataset does not support num_proc map_kwargs["num_proc"] = args.dataset_num_proc with PartialState().main_process_first(): - # Add EOS token to the end of the sequences if needed - first_example = next(iter(dataset)) - if not is_conversational(first_example): - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - - def add_eos(example, eos_token): - if not example["chosen"].endswith(eos_token): - example["chosen"] = example["chosen"] + eos_token - if "rejected" in example and not example["rejected"].endswith(eos_token): - example["rejected"] = example["rejected"] + eos_token - return example - - dataset = dataset.map( - add_eos, - fn_kwargs={"eos_token": processing_class.eos_token}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - def tokenize(example, processing_class): - if "prompt" in example: # explicit prompt case - example["chosen"] = example["prompt"] + example["chosen"] - example["rejected"] = example["prompt"] + example["rejected"] - - if is_conversational(example): - chosen_input_ids = processing_class.apply_chat_template( - example["chosen"], - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), + if not is_processed: + # Add EOS token to the end of the sequences if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if "rejected" in example and not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + return example + + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + **map_kwargs, ) - rejected_input_ids = processing_class.apply_chat_template( - example["rejected"], - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} - else: - output = { - "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], - "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], - } - return output - dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(example, processing_class): + if "prompt" in example: # explicit prompt case + example["chosen"] = example["prompt"] + example["chosen"] + example["rejected"] = example["prompt"] + example["rejected"] + + if is_conversational(example): + chosen_input_ids = processing_class.apply_chat_template( + example["chosen"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + rejected_input_ids = processing_class.apply_chat_template( + example["rejected"], + 
tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} + else: + output = { + "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], + "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], + } + return output + + dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) # Truncate if args.max_length is not None: From 7f3f4fda773a5beaeb2ca9a33352afd157a443df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 17:27:17 +0000 Subject: [PATCH 23/62] support SequenceClassification models in clone_chat_template --- tests/test_dataset_formatting.py | 55 +++++++++++++++++++++++--------- trl/models/utils.py | 3 +- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index bfedc947b14..c85845e34c3 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -15,7 +15,7 @@ from typing import Callable from datasets import Dataset, load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from trl.extras.dataset_formatting import get_formatting_func_from_dataset from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format @@ -159,47 +159,59 @@ def test_example_with_setup_model(self): class CloneChatTemplateTestCase(TrlTestCase): - def setUp(self): - super().setUp() + def test_clone(self): # This tokenizer doesn't have a chat_template by default - self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") - self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") # This one has a chat_template by default - self.source = "trl-internal-testing/tiny-Qwen3ForCausalLM" - - def test_clone(self): - _, modified_tokenizer, _ = clone_chat_template(self.model, self.tokenizer, self.source) + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) # Check if special tokens are correctly set self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") def test_clone_with_resize(self): + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" modified_model, modified_tokenizer, _ = clone_chat_template( - self.model, self.tokenizer, self.source, resize_to_multiple_of=123 + model, tokenizer, source, resize_to_multiple_of=123 ) # Check that the input embeddings have been resized to a multiple of 123 self.assertEqual((modified_model.vocab_size % 123), 0) # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(self.model.vocab_size, len(modified_tokenizer.vocab)) + self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): + # This 
tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" # This will add , , ... to the tokenizer modified_model, modified_tokenizer, _ = clone_chat_template( - self.model, self.tokenizer, self.source, resize_to_multiple_of=123 + model, tokenizer, source, resize_to_multiple_of=123 ) # Try if we can resize a tokenizer that already has extra these extra tokens modified_model, modified_tokenizer, _ = clone_chat_template( - modified_model, modified_tokenizer, self.source, resize_to_multiple_of=124 + modified_model, modified_tokenizer, source, resize_to_multiple_of=124 ) # Check that the input embeddings have been resized to a multiple of 123 self.assertEqual((modified_model.vocab_size % 124), 0) # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(self.model.vocab_size, len(modified_tokenizer.vocab)) + self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) def test_apply_new_chat_template(self): - _, modified_tokenizer, _ = clone_chat_template(self.model, self.tokenizer, self.source) + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) messages = [ {"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hello"}, @@ -211,3 +223,16 @@ def test_apply_new_chat_template(self): prompt, "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n", ) + + def test_clone_with_sequence_classification_model(self): + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptNeoXForSequenceClassification") + model = AutoModelForSequenceClassification.from_pretrained( + "trl-internal-testing/tiny-GptNeoXForSequenceClassification" + ) + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) + + # Check if special tokens are correctly set + self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") diff --git a/trl/models/utils.py b/trl/models/utils.py index 345ad8065b0..baa23e8e140 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -223,7 +223,8 @@ def clone_chat_template( # Set the EOS token from the source tokenizer (important for generation) tokenizer.eos_token = tokenizer_source.eos_token model.config.eos_token_id = tokenizer.eos_token_id - model.generation_config.eos_token_id = tokenizer.eos_token_id + if model.generation_config is not None: # for SequenceClassification models, generation_config is None + model.generation_config.eos_token_id = tokenizer.eos_token_id # Resize model embeddings to include any new tokens, optionally rounding up to a multiple model.resize_token_embeddings( From 82f238a66fe47235787c21501cf0c679f7370557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 
Sep 2025 18:24:19 +0000 Subject: [PATCH 24/62] fix test --- tests/test_reward_trainer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index ae28993ff60..65498bebf8e 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -586,6 +586,11 @@ def test_train_with_set_chat_template_from_model(self): # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) + # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. + if n == "gpt_neox.final_layer_norm.bias": + continue self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") def test_train_with_set_chat_template_from_path(self): @@ -617,6 +622,11 @@ def test_train_with_set_chat_template_from_path(self): # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) + # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. + if n == "gpt_neox.final_layer_norm.bias": + continue self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") # Check that the template saved in the output directory is the same as the one used for training From 41cbaa0f254f07d8a5039cf39d17ddf6cad568fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 18:30:46 +0000 Subject: [PATCH 25/62] fix sft docstring --- trl/trainer/sft_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index a5b9f2c4267..1d11a57b2c4 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -58,7 +58,7 @@ class SFTConfig(TrainingArguments): eos_token (`str`, *optional*): Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`. - pad_token (`int`, *optional*): + pad_token (`str`, *optional*): Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, it falls back to `processing_class.eos_token`. max_length (`int` or `None`, *optional*, defaults to `1024`): From 0ee9fc0fec7dad43f15640f5aec6ca89ae8b1a9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 18:31:00 +0000 Subject: [PATCH 26/62] docstring --- trl/trainer/reward_config.py | 6 ++---- trl/trainer/reward_trainer.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 88050ba8c1a..39c7e8bb1bb 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -52,7 +52,7 @@ class may differ from those in [`~transformers.TrainingArguments`]. eos_token (`str`, *optional*): Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`. - pad_token (`int`, *optional*): + pad_token (`str`, *optional*): Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, it falls back to `processing_class.eos_token`. 
max_length (`int` or `None`, *optional*, defaults to `1024`): @@ -104,9 +104,7 @@ class may differ from those in [`~transformers.TrainingArguments`]. default=None, metadata={ "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " - "the `RewardTrainer` is provided as a string. If you're training a MoE architecture and want to include " - "the load balancing/auxilliary loss as a part of the final loss, remember to set " - "`output_router_logits=True` in this dictionary." + "the `RewardTrainer` is provided as a string." }, ) chat_template_path: Optional[str] = field( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index aa06a52baa9..4eed58aa5b6 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -159,12 +159,12 @@ class RewardTrainer(Trainer): Example: ```python - from datasets import load_dataset from trl import RewardTrainer + from datasets import load_dataset - dataset = load_dataset("trl-lib/tldr-preference", split="train[:1%]") + dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") - trainer = RewardTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset) trainer.train() ``` From d84b85031d9df884c47842ee1ac75e353772a8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 18:31:14 +0000 Subject: [PATCH 27/62] simplify example in readme --- README.md | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 2df75017d02..b7a0668adc9 100644 --- a/README.md +++ b/README.md @@ -136,23 +136,13 @@ trainer.train() Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): ```python -from trl import RewardConfig, RewardTrainer +from trl import RewardTrainer from datasets import load_dataset -from transformers import AutoModelForSequenceClassification, AutoTokenizer - -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -model = AutoModelForSequenceClassification.from_pretrained( - "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1 -) -model.config.pad_token_id = tokenizer.pad_token_id dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2) trainer = RewardTrainer( - args=training_args, - model=model, - processing_class=tokenizer, + model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset, ) trainer.train() From 59b4ff0a9ac5db7a54fd8f898fcd03cb2aed2108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 19:34:29 +0000 Subject: [PATCH 28/62] two papers --- docs/source/paper_index.md | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4e5d3bdba56..5a93bfba5e5 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -528,3 +528,53 @@ training_args = CPOConfig( ... ) ``` + +## Reward Modeling + +Papers relating to the [`RewardTrainer`] + +### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking + +**📜 Paper**: https://huggingface.co/papers/2312.09244 + +This paper proposed an auxiliary loss function designed to directly learn a centered reward model. 
This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination. + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right]. +$$ + +To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows: + +```python +from trl import RewardConfig + +training_args = RewardConfig( + center_rewards_coefficient=0.01, # η in the paper + ... +) +``` + +### Llama 2: Open Foundation and Fine-Tuned Chat Models + +**📜 Paper**: https://huggingface.co/papers/2307.09288 + +In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings. + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right]. +$$ + +You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up a the "Margin Small" setting of the paper. + +```python +def add_margin(example): + preference_to_margin = { + "significantly better": 1.0, + "better": 2.0/3.0, + "slightly better", 1.0/3.0, + "negligibly better / unsure": 0.0, + } + return {"margin": preference_to_margin[example["preference_label"]]} + +dataset = dataset.map(add_margin) +``` From aff5097bd1da23ad4b13c73f7923523e2764ec1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 20:03:12 +0000 Subject: [PATCH 29/62] margin and center_rewards_coefficient --- docs/source/paper_index.md | 2 +- tests/test_reward_trainer.py | 73 +++++++++++++++++++++++++++++++++++ trl/trainer/reward_trainer.py | 27 +++++++++++-- 3 files changed, 98 insertions(+), 4 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 5a93bfba5e5..58c62d17e0e 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -571,7 +571,7 @@ def add_margin(example): preference_to_margin = { "significantly better": 1.0, "better": 2.0/3.0, - "slightly better", 1.0/3.0, + "slightly better": 1.0/3.0, "negligibly better / unsure": 0.0, } return {"margin": preference_to_margin[example["preference_label"]]} diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index 65498bebf8e..b4d53e16941 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -92,6 +92,21 @@ def test_different_pad_token_id(self): result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) ) + def test_collate_with_margin(self): + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.2}, + ] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + 
torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2])) + class RewardTrainerTester(TrlTestCase): @parameterized.expand( @@ -759,3 +774,61 @@ def test_tag_added_peft(self): for tag in ["reward-trainer", "trl"]: self.assertIn(tag, trainer.model.model_tags) + + def test_train_with_margin(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def add_margin(example): + # dummy margin based on the length of the chosen summary + return {"margin": len(example["chosen"])} + + dataset = dataset.map(add_margin) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_center_rewards_coefficient(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 4eed58aa5b6..29423e5be23 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -86,12 +86,14 @@ class DataCollatorForPreference(DataCollatorMixin): Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch. This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and - `"rejected_input_ids"` keys. The collator returns a dictionary containing the following - keys: + `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys: - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`. - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a + `"margin"` key with a tensor of margins. 
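For readers tracing how the collated `"margin"` tensor is consumed: the Llama 2-style objective described in the paper index above subtracts the margin inside the log-sigmoid. The following is a minimal, self-contained sketch, not the trainer's exact implementation; it only assumes the collator's layout in which the first half of the batch is chosen and the second half rejected.

```python
import torch
import torch.nn.functional as F


def preference_loss(rewards, margin=None):
    # rewards: scores for the collated batch, first half chosen, second half rejected
    chosen_rewards, rejected_rewards = rewards.chunk(2)
    diff = chosen_rewards - rejected_rewards
    if margin is not None:
        # Larger margins require a larger gap between chosen and rejected scores
        diff = diff - margin
    # Bradley-Terry negative log-likelihood
    return -F.logsigmoid(diff).mean()


# Toy usage mirroring the collator output shown above
rewards = torch.tensor([1.2, 0.8, 0.3, 0.9])
margin = torch.tensor([0.1, 0.2])
print(preference_loss(rewards, margin))
```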
+ Args: pad_token_id (`int`): Token ID to use for padding. @@ -118,6 +120,21 @@ class DataCollatorForPreference(DataCollatorMixin): [1, 1, 0], [1, 1, 0], [1, 0, 0]])} + + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]]), + 'margin': tensor([0.5, 0.0])} ``` """ @@ -129,6 +146,8 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d # Convert to tensor chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + if "margin" in examples[0]: + margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float) input_ids = chosen_input_ids + rejected_input_ids attention_mask = [torch.ones_like(ids) for ids in input_ids] @@ -147,6 +166,8 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ) + if "margin" in examples[0]: + output["margin"] = margins return output @@ -495,7 +516,7 @@ def _set_signature_columns_if_needed(self): # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" # and "attention_mask"). if self._signature_columns is None: - self._signature_columns = ["chosen_input_ids", "rejected_input_ids"] + self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ From ce17355d26be716b645bc3f94027527ab285866b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 22:14:55 +0000 Subject: [PATCH 30/62] cli + documentation --- docs/source/clis.md | 96 +++++++++++++ docs/source/quickstart.md | 21 ++- docs/source/reward_trainer.md | 248 +++++++++++++++++++++++++++------- docs/source/sft_trainer.md | 2 +- tests/test_cli.py | 7 + trl/cli.py | 11 ++ trl/scripts/reward.py | 102 ++++++++++++++ 7 files changed, 438 insertions(+), 49 deletions(-) create mode 100644 trl/scripts/reward.py diff --git a/docs/source/clis.md b/docs/source/clis.md index 54b7501c1aa..232ce9d86cf 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -9,6 +9,7 @@ Currently supported commands are: - `trl dpo`: fine-tune a LLM with DPO - `trl grpo`: fine-tune a LLM with GRPO - `trl kto`: fine-tune a LLM with KTO +- `trl reward`: train a Reward Model - `trl rloo`: fine-tune a LLM with RLOO - `trl sft`: fine-tune a LLM with SFT @@ -41,6 +42,15 @@ trl dpo \ --dataset_name anthropic/hh-rlhf ``` + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized +``` + @@ -78,6 +88,21 @@ Launch with: trl dpo --config dpo_config.yaml ``` + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -138,6 +163,33 @@ Launch with: ```bash trl dpo --config dpo_config.yaml ``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_processes 4 +``` + + + + +```yaml +# reward_config.yaml 
+model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +num_processes: 4 +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -217,6 +269,33 @@ Launch with: ```bash trl dpo --config dpo_config.yaml ``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -258,6 +337,23 @@ Launch with: trl dpo --config dpo_config.yaml ``` + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: trl-lib/tldr-preference + - path: trl-lib/lm-human-preferences-sentiment +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index 908230131e0..4898019bd9c 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -1,6 +1,6 @@ # Quickstart -TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO). +TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO). ## Quick Examples @@ -51,6 +51,21 @@ trainer = DPOTrainer( trainer.train() ``` +### Reward Modeling + +```python +from trl import RewardTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, +) +trainer.train() +``` + ## Command Line Interface Skip the code entirely - train directly from your terminal: @@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \ # DPO: Align with preferences trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --dataset_name trl-lib/ultrafeedback_binarized + +# Reward: Train a reward model +trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized ``` ## What's Next? diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index 972c45b0c88..8970fc00181 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -2,84 +2,234 @@ [![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl) -TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. +## Overview -Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). +TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models. -## Expected dataset type +This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada). -The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`). 
-The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +## Quick start -You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`. +This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), large-scale, fine-grained, diverse preference dataset. -## Using the `RewardTrainer` +```python +from trl import RewardTrainer +from datasets import load_dataset + +trainer = RewardTrainer( + model="Qwen/Qwen3-0.6B", + train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"), +) +trainer.train() +``` -After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers. -You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. + -### Leveraging 🤗 PEFT to train a reward model +## Expected dataset type and format -Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! +[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets type (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ```python -from peft import LoraConfig, TaskType -from transformers import AutoModelForSequenceClassification, AutoTokenizer -from trl import RewardTrainer, RewardConfig - -model = AutoModelForSequenceClassification.from_pretrained("gpt2") -peft_config = LoraConfig( - task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, -) +# Standard preference (implicit prompt) +{"chosen": "The sky is blue.", + "rejected": "The sky is green."} + +# Conversational preference (implicit prompt) +{"chosen": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}]} + +# Standard preference (explicit prompt) +{"prompt": "The sky is", + "chosen": " blue.", + "rejected": " green."} + +# Conversational preference (explicit prompt) +{"prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}]} +``` -... +If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. 
Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset: -trainer = RewardTrainer( - model=model, - args=training_args, - processing_class=tokenizer, - train_dataset=dataset, - peft_config=peft_config, -) +```python +from datasets import load_dataset +import json -trainer.train() +dataset = load_dataset("lmarena-ai/arena-human-preference-55k") + +# Filter out ties +dataset = dataset.filter(lambda example: example["winner_tie"] == 0) + + +# Create 'chosen' and 'rejected' fields based on the winner column +def response_a_b_to_chosen_rejected(example): + if example["winner_model_a"] == 1: + example["chosen"] = example["response_a"] + example["rejected"] = example["response_b"] + else: + example["chosen"] = example["response_b"] + example["rejected"] = example["response_a"] + return example + + +dataset = dataset.map(response_a_b_to_chosen_rejected) + +# Convert to conversational format +def make_conversation(example): + prompt = json.loads(example["prompt"])[0] # '["What color is the sky?"]' -> "What color is the sky?" + chosen = json.loads(example["chosen"])[0] + rejected = json.loads(example["rejected"])[0] + return { + "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}], + "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}], + } + + +dataset = dataset.map(make_conversation) + +# Keep only necessary columns +dataset = dataset.select_columns(["chosen", "rejected"]) + +print(next(iter(dataset["train"]))) ``` -### Adding a margin to the loss +```json +{ + "chosen": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."}, + ], + "rejected": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."}, + ], +} +``` + +## Looking deeper into the training method + +Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences. + +This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**. + +### Preprocessing and tokenization + +During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference). +The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization. + +### Computing the loss + +Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)−r(x, y^-)) \\), where \\( σ \\) is the sigmoid function. 
+
+The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
+$$
+
+
+
+The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
+
+
+
+## Logged metrics
-
-As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
+While training and evaluating we record the following reward metrics:
+
+* `global_step`: The total number of optimizer steps taken so far.
+* `epoch`: The current epoch number, based on dataset iteration.
+* `num_tokens`: The total number of tokens processed so far.
+* `loss`: The average loss over the last logging interval.
+* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
+* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
+* `mean_reward`: The average reward score assigned by the model over the last logging interval.
+* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
+* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
+* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
+* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
+
+## Customization
+
+### Model initialization
+
+You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
+
+```python
+model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
+```
+
+you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
 
 ```python
-def add_margin(row):
-    # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
-    return {'margin': row['score_chosen'] - row['score_rejected']}
+from trl import RewardConfig
 
-dataset = dataset.map(add_margin)
+training_args = RewardConfig(
+    model_init_kwargs={"dtype": torch.bfloat16},
+)
 ```
 
-### Centering rewards
+Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
 
-In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
+### Train adapters with PEFT -[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs: +We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model. -$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$ +```python +from datasets import load_dataset +from trl import RewardTrainer +from peft import LoraConfig + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + "Qwen/Qwen3-4B", + train_dataset=dataset, + peft_config=LoraConfig(modules_to_save=["score"]). # important to include the score head when base model is not a sequence classification model +) + +trainer.train() +``` -This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`). +You can also continue training your [`peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed. ```python -training_args = RewardConfig( - center_rewards_coefficient=0.01, - ... +from datasets import load_dataset +from trl import RewardTrainer +from peft import AutoPeftModelForCausalLM + +model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True) +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = RewardTrainer( + model=model, + train_dataset=dataset, ) + +trainer.train() ``` -For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932). + + +When training adapters, you typically use a higher learning rate (≈1e‑3) since only new parameters are being learned. + +```python +RewardConfig(learning_rate=1e-3, ...) +``` + + + +## Tool Calling with Reward Modeling + +The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include: + +* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages) +* The list of available tools in the `tools` column, typically provided as JSON schemas + +For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section. 
## RewardTrainer @@ -91,3 +241,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/tr ## RewardConfig [[autodoc]] RewardConfig + +## DataCollatorPreference + +[[autodoc]] trainer.sft_trainer.DataCollatorPreference diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index dae04e1483f..1ddec50da7b 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -23,7 +23,7 @@ trainer = SFTTrainer( trainer.train() ``` - + ## Expected dataset type and format diff --git a/tests/test_cli.py b/tests/test_cli.py index 2f8891c7333..23b5d6bcff7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -67,6 +67,13 @@ def test_kto(self): with patch("sys.argv", command.split(" ")): main() + def test_reward(self): + from trl.cli import main + + command = f"trl reward --output_dir {self.tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_implicit_prompt_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + def test_rloo(self): from trl.cli import main diff --git a/trl/cli.py b/trl/cli.py index eba3bb6fff9..b6b8e3f922a 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -23,6 +23,7 @@ from .scripts.env import print_env from .scripts.grpo import make_parser as make_grpo_parser from .scripts.kto import make_parser as make_kto_parser +from .scripts.reward import make_parser as make_reward_parser from .scripts.rloo import make_parser as make_rloo_parser from .scripts.sft import make_parser as make_sft_parser from .scripts.utils import TrlParser @@ -44,6 +45,7 @@ def main(): subparsers.add_parser("env", help="Print the environment information") make_grpo_parser(subparsers) make_kto_parser(subparsers) + make_reward_parser(subparsers) make_rloo_parser(subparsers) make_sft_parser(subparsers) make_vllm_serve_parser(subparsers) @@ -110,6 +112,15 @@ def main(): args.training_script_args = sys.argv[2:] # remove "trl" and "kto" launch_command(args) # launch training + elif args.command == "reward": + # Get the default args for the launch command + reward_training_script = resources.files("trl.scripts").joinpath("reward.py") + args = launch_command_parser().parse_args([str(reward_training_script)]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "reward" + launch_command(args) # launch training + elif args.command == "rloo": # Get the default args for the launch command rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py") diff --git a/trl/scripts/reward.py b/trl/scripts/reward.py new file mode 100644 index 00000000000..62567c22633 --- /dev/null +++ b/trl/scripts/reward.py @@ -0,0 +1,102 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import argparse +import os +from typing import Optional + +from accelerate import logging +from datasets import load_dataset + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + RewardConfig, + RewardTrainer, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the RewardTrainer + trainer = RewardTrainer( + model=model_args.model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: Optional[argparse._SubParsersAction] = None): + dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser( + "reward", help="Run the reward training script", dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. 
+    script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
+        return_remaining_strings=True
+    )
+    main(script_args, training_args, model_args, dataset_args)

From 6f0fb0960109d955dcbd3091842ab9dcc8d49628 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 16 Sep 2025 22:26:41 +0000
Subject: [PATCH 31/62] fix iframe

---
 docs/source/reward_trainer.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md
index 8970fc00181..962d8ab54f8 100644
--- a/docs/source/reward_trainer.md
+++ b/docs/source/reward_trainer.md
@@ -23,7 +23,7 @@ trainer = RewardTrainer(
 trainer.train()
 ```
 
-
+
 
 ## Expected dataset type and format

From 6a576935787b5cd4b2dea8344d32db22cfcb8107 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 16 Sep 2025 22:27:21 +0000
Subject: [PATCH 32/62] fix doc

---
 docs/source/reward_trainer.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md
index 962d8ab54f8..02a1ce0a999 100644
--- a/docs/source/reward_trainer.md
+++ b/docs/source/reward_trainer.md
@@ -242,6 +242,6 @@ For details on the expected dataset structure, see the [Dataset Format — Tool
 
 [[autodoc]] RewardConfig
 
-## DataCollatorPreference
+## DataCollatorForPreference
 
-[[autodoc]] trainer.sft_trainer.DataCollatorPreference
+[[autodoc]] trainer.sft_trainer.DataCollatorForPreference

From f2414741491df64c9f5ee197892cf5746df5374e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 16 Sep 2025 22:33:38 +0000
Subject: [PATCH 33/62] focus dude

---
 docs/source/reward_trainer.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md
index 02a1ce0a999..3dd8c6dda95 100644
--- a/docs/source/reward_trainer.md
+++ b/docs/source/reward_trainer.md
@@ -244,4 +244,4 @@ For details on the expected dataset structure, see the [Dataset Format — Tool
 
 ## DataCollatorForPreference
 
-[[autodoc]] trainer.sft_trainer.DataCollatorForPreference
+[[autodoc]] trainer.reward_trainer.DataCollatorForPreference

From c76499dfc57c7d37d0d9d875eb77b05a22912224 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 16 Sep 2025 22:42:28 +0000
Subject: [PATCH 34/62] nits

---
 docs/source/reward_trainer.md | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md
index 3dd8c6dda95..f5f59495192 100644
--- a/docs/source/reward_trainer.md
+++ b/docs/source/reward_trainer.md
@@ -62,7 +62,6 @@ dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
 # Filter out ties
 dataset = dataset.filter(lambda example: example["winner_tie"] == 0)
 
-
 # Create 'chosen' and 'rejected' fields based on the winner column
 def response_a_b_to_chosen_rejected(example):
     if example["winner_model_a"] == 1:
@@ -73,10 +72,8 @@ def response_a_b_to_chosen_rejected(example):
         example["rejected"] = example["response_a"]
     return example
 
-
 dataset = dataset.map(response_a_b_to_chosen_rejected)
 
-
 # Convert to conversational format
 def make_conversation(example):
     prompt = json.loads(example["prompt"])[0]  # '["What color is the sky?"]' -> "What color is the sky?"
From a3410777b25880ddfca7d61290ba907af9f344eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 22:43:02 +0000 Subject: [PATCH 35/62] fix space --- docs/source/reward_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index f5f59495192..0b3de01cb58 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -119,7 +119,7 @@ The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If promp ### Computing the loss -Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)−r(x, y^-)) \\), where \\( σ \\) is the sigmoid function. +Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)−r(x, y^-)) \\), where \\( σ \\) is the sigmoid function. The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences: From 302e4c9f79236e8e6e2b4a2d9851ba13371e9336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 22:44:08 +0000 Subject: [PATCH 36/62] nit --- docs/source/reward_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index 0b3de01cb58..a8dee9e895a 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -185,7 +185,7 @@ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") trainer = RewardTrainer( "Qwen/Qwen3-4B", train_dataset=dataset, - peft_config=LoraConfig(modules_to_save=["score"]). 
# important to include the score head when base model is not a sequence classification model + peft_config=LoraConfig(modules_to_save=["score"]) # important to include the score head when base model is not a sequence classification model ) trainer.train() From 84318016e527de972df0bd852c89789f8266c5c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 22:53:43 +0000 Subject: [PATCH 37/62] fix layer_types --- scripts/generate_tiny_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b6d532488a0..e1247d29acf 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -249,6 +249,8 @@ def init_weights_tiny_model(model): config.num_key_value_heads = 2 config.num_hidden_layers = 2 config.num_labels = 1 + if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"): + config.layer_types = config.layer_types[:2] model = model_class(config) push_to_hub(model, tokenizer, "tiny", suffix) From b7f8776f11ff4de1e4d5af96117f46c1621e8f68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 23:12:06 +0000 Subject: [PATCH 38/62] update section header for dataset mixtures in CLI documentation --- docs/source/clis.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 232ce9d86cf..d6e433cada7 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -303,7 +303,7 @@ trl reward --config reward_config.yaml You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data. - + ```yaml From f5f1b2dfb4e801920c64c5d3a187b1b959eca594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 16 Sep 2025 23:16:18 +0000 Subject: [PATCH 39/62] fix some iframes --- docs/source/sft_trainer.md | 2 +- docs/source/trackio_integration.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index 1ddec50da7b..841bee62b66 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -23,7 +23,7 @@ trainer = SFTTrainer( trainer.train() ``` - + ## Expected dataset type and format diff --git a/docs/source/trackio_integration.md b/docs/source/trackio_integration.md index 9f5fe693a9f..4e93120fe19 100644 --- a/docs/source/trackio_integration.md +++ b/docs/source/trackio_integration.md @@ -64,4 +64,4 @@ trainer.train() will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio. - + From 3bf81559bf98bdd3db5d6080ad46fc734bbb15b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 17 Sep 2025 15:17:13 +0000 Subject: [PATCH 40/62] add disable_dropout parameter to RewardConfig and implement in RewardTrainer --- trl/trainer/reward_config.py | 6 ++++++ trl/trainer/reward_trainer.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 39c7e8bb1bb..595f0870530 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -44,6 +44,8 @@ class may differ from those in [`~transformers.TrainingArguments`]. or Hugging Face Hub model) or a direct path to a Jinja template file. 
When using a Jinja file, you must ensure that any special tokens referenced in the template are added to the tokenizer and that the model's embedding layer is resized accordingly. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. > Parameters that control the data preprocessing @@ -116,6 +118,10 @@ class may differ from those in [`~transformers.TrainingArguments`]. "that the model's embedding layer is resized accordingly." }, ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) # Parameters that control the data preprocessing dataset_num_proc: Optional[int] = field( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 29423e5be23..93cb528a2bc 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -45,7 +45,7 @@ from ..data_utils import is_conversational, truncate_dataset from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig -from .utils import generate_model_card, get_comet_experiment_url, pad, remove_none_values +from .utils import disable_dropout_in_model, generate_model_card, get_comet_experiment_url, pad, remove_none_values if is_peft_available(): @@ -359,6 +359,10 @@ def __init__( if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + # Pad token (needed for SequenceClassification models) # If not provided, use the one from the processing class or the eos token if the processing class does not have # a pad token. From bf77b59cff5ffa6a494c162156c788749f34b7af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 17 Sep 2025 15:17:37 +0000 Subject: [PATCH 41/62] deprecate RewardDataCollatorWithPadding and decode_and_strip_padding functions with warnings --- trl/trainer/utils.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3c45729f84c..1bea8d346b4 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -16,6 +16,7 @@ import importlib.resources as pkg_resources import json import random +import warnings from collections.abc import Mapping, Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version @@ -172,6 +173,13 @@ class RewardDataCollatorWithPadding: r""" Reward DataCollator class that pads the inputs to the maximum length of the batch. + + + This class is deprecated and will be removed in version 0.27.0. Please use + `trl.trainer.reward_trainer.DataCollatorForPreference` instead. + + + Args: tokenizer (`PreTrainedTokenizerBase`): The tokenizer used for encoding the data. @@ -188,6 +196,14 @@ class RewardDataCollatorWithPadding: pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. 
Please use " + "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: features_chosen = [] features_rejected = [] @@ -1185,6 +1201,13 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize """ Decodes the input tensor and strips the padding tokens. + + + This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy the + code into your codebase and use it from there. + + + Args: inputs (`torch.Tensor`): The input tensor to be decoded. @@ -1195,6 +1218,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize `list[str]`: The list of decoded strings with padding tokens stripped. """ + warnings.warn( + "The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want " + "to keep using it, please copy the code into your codebase and use it from there.", + DeprecationWarning, + ) decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False) return [d.replace(tokenizer.pad_token, "") for d in decoded] From 4eda563f2b311853b460446c652d9c8d9091ca13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 22 Sep 2025 09:34:00 -0600 Subject: [PATCH 42/62] Update trl/trainer/reward_trainer.py --- trl/trainer/reward_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 93cb528a2bc..3022a777855 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -198,7 +198,7 @@ class RewardTrainer(Trainer): [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. + - A sequence classification [`~transformers.PreTrainedModel`] object. args ([`RewardConfig`], *optional*): Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): From 291ba982d4abb9332e407ed8f99a35b08afb432f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 22 Sep 2025 16:16:15 +0000 Subject: [PATCH 43/62] filter --- trl/trainer/reward_config.py | 4 ++-- trl/trainer/reward_trainer.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 595f0870530..33d248e635e 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -58,8 +58,8 @@ class may differ from those in [`~transformers.TrainingArguments`]. Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, it falls back to `processing_class.eos_token`. max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. - If `None`, no truncation is applied. + Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence + exceeds this value. If `None`, no filtering is applied. pad_to_multiple_of (`int`, *optional*): If set, the sequences will be padded to a multiple of this value. 
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 3022a777855..81c29f1d2f2 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -42,7 +42,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..data_utils import is_conversational, truncate_dataset +from ..data_utils import is_conversational from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .reward_config import RewardConfig from .utils import disable_dropout_in_model, generate_model_card, get_comet_experiment_url, pad, remove_none_values @@ -507,11 +507,15 @@ def tokenize(example, processing_class): dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) - # Truncate + # Filter samples that are longer than `max_length` if args.max_length is not None: if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Truncating {dataset_name} dataset" - dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens" + dataset = dataset.filter( + lambda example: len(example["chosen_input_ids"]) <= args.max_length + and len(example["rejected_input_ids"]) <= args.max_length, + **map_kwargs, + ) return dataset From 1323901fb892fdb21465220ee7871dda9bdc6251 Mon Sep 17 00:00:00 2001 From: juejuezi Date: Tue, 23 Sep 2025 10:23:58 +0800 Subject: [PATCH 44/62] =?UTF-8?q?=F0=9F=90=AF=20fix:=20use=5Fliger=5Fkerne?= =?UTF-8?q?l=20with=20IterableDataset=20(#4087)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- trl/trainer/sft_trainer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 35e125ddca1..6c3b95a198b 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -72,6 +72,10 @@ logger = logging.get_logger(__name__) +def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]: + return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names + + @dataclass class DataCollatorForLanguageModeling(DataCollatorMixin): """ @@ -858,7 +862,7 @@ def _prepare_dataset( dataset = dataset.with_transform(remove_none_values) # If the dataset is already preprocessed (tokenized), skip the processing steps. 
- column_names = list(next(iter(dataset)).keys()) + column_names = get_dataset_column_names(dataset) is_processed = "input_ids" in column_names # Build the kwargs for the `map` function @@ -890,7 +894,7 @@ def _func(example): if is_conversational_from_value(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" - column_names = next(iter(dataset)).keys() + column_names = get_dataset_column_names(dataset) dataset = dataset.map( maybe_convert_to_chatml, remove_columns="conversations" if "conversations" in column_names else None, @@ -999,9 +1003,9 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss) map_kwargs["desc"] = f"Packing {dataset_name} dataset" columns = ["input_ids"] - if "completion_mask" in dataset.column_names: + if "completion_mask" in get_dataset_column_names(dataset): columns.append("completion_mask") - if "assistant_masks" in dataset.column_names: + if "assistant_masks" in get_dataset_column_names(dataset): columns.append("assistant_masks") dataset = dataset.select_columns(columns) @@ -1015,7 +1019,8 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss) # For Liger kernel, ensure only the essential columns if args.use_liger_kernel: collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} - dataset = dataset.select_columns(collator_expected_keys.intersection(dataset.column_names)) + column_names = get_dataset_column_names(dataset) + dataset = dataset.select_columns(collator_expected_keys.intersection(column_names)) return dataset From 7d10daa8995c9e19722a0403e58a7d01268893bf Mon Sep 17 00:00:00 2001 From: Yi Shi <96773624+singing-cat@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:58:40 +0800 Subject: [PATCH 45/62] =?UTF-8?q?=F0=9F=93=A4=20Fix=20a=20dataset=20loadin?= =?UTF-8?q?g=20bug=20in=20scripts=20(#4124)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec --- trl/scripts/dpo.py | 1 + trl/scripts/kto.py | 1 + trl/scripts/rloo.py | 1 + trl/scripts/sft.py | 1 + 4 files changed, 4 insertions(+) diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 9ad6b021681..88c93398ed4 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -135,6 +135,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." ) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index af3447603c8..4c71f0178e1 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -111,6 +111,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." 
) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: diff --git a/trl/scripts/rloo.py b/trl/scripts/rloo.py index f2d17cb7b5d..0a1941c556d 100644 --- a/trl/scripts/rloo.py +++ b/trl/scripts/rloo.py @@ -118,6 +118,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." ) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 8a29b694cd8..a9fa28a9a94 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -129,6 +129,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." ) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: From 814f97f59394cd6c74b079cdb2ca4a23a66c3a93 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 23 Sep 2025 07:19:12 +0200 Subject: [PATCH 46/62] =?UTF-8?q?=E2=9A=93=20[vllm]=20ensure=20MASTER=5FAD?= =?UTF-8?q?DR/MASTER=5FPORT=20are=20set=20safely=20(#4057)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec --- trl/trainer/grpo_trainer.py | 5 ++-- trl/trainer/online_dpo_trainer.py | 5 ++-- trl/trainer/rloo_trainer.py | 5 ++-- trl/trainer/utils.py | 43 +++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 69825102b27..9088ea39a83 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -59,6 +59,7 @@ from .utils import ( RepeatSampler, disable_dropout_in_model, + ensure_master_addr_port, entropy_from_logits, generate_model_card, get_comet_experiment_url, @@ -514,8 +515,8 @@ def __init__( os.environ["RANK"] = str(self.accelerator.process_index) os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345") + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() if self.max_prompt_length is not None and self.max_completion_length is not None: max_model_len = self.max_prompt_length + self.max_completion_length diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 7c9466a3dd9..aed16895e1d 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -68,6 +68,7 @@ DPODataCollatorWithPadding, disable_dropout_in_model, empty_cache, + ensure_master_addr_port, generate_model_card, get_comet_experiment_url, pad, @@ -521,8 +522,8 @@ def __init__( os.environ["RANK"] = str(self.accelerator.process_index) os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - 
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345") + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() self.llm = LLM(**vllm_kwargs) else: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 13ac5e34a3d..2f649a98764 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -59,6 +59,7 @@ from .utils import ( RepeatSampler, disable_dropout_in_model, + ensure_master_addr_port, entropy_from_logits, generate_model_card, get_comet_experiment_url, @@ -586,8 +587,8 @@ def decode(example, tokenizer): os.environ["RANK"] = str(self.accelerator.process_index) os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345") + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() if self.max_prompt_length is not None and self.max_completion_length is not None: max_model_len = self.max_prompt_length + self.max_completion_length diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 42d6a69e4ae..cf976c35c13 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -15,7 +15,9 @@ import dataclasses import importlib.resources as pkg_resources import json +import os import random +import socket import warnings from collections.abc import Mapping, Sequence, Sized from dataclasses import dataclass, field @@ -172,6 +174,47 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: } +def _is_port_free(port: int, host: str = "127.0.0.1") -> bool: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return True + except OSError: + return False + + +def _find_free_port() -> int: + candidates = (29500, 23456, 12355, 12345) + for p in candidates: + if _is_port_free(p): + return p + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def ensure_master_addr_port(addr: Optional[str] = None, port: Optional[int] = None) -> None: + """ + Ensure `MASTER_ADDR`/`MASTER_PORT` are set safely. + + - Respects existing environment variables. + - Defaults `MASTER_ADDR` to localhost if unset. + - Chooses a free TCP port if `MASTER_PORT` is unset to avoid collisions. + - If `MASTER_PORT` is set to `"0"` or `"auto"`, it is resolved to a free port. 
+ """ + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR") or addr or "localhost" + + env_port = os.environ.get("MASTER_PORT", "").strip().lower() + if port is None and env_port not in {"", "0", "auto"}: + try: + port = int(env_port) + except ValueError: + pass + + os.environ["MASTER_PORT"] = str(_find_free_port() if port in (None, 0) else port) + + @dataclass class RewardDataCollatorWithPadding: r""" From 05fd402af5b05594a1413ac2cb5b87db9bcaf577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 23 Sep 2025 05:21:30 +0000 Subject: [PATCH 47/62] =?UTF-8?q?=F0=9F=93=A4=20Fix=20a=20dataset=20loadin?= =?UTF-8?q?g=20bug=20in=20scripts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/scripts/grpo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 44c564113b4..5373e482db0 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -118,6 +118,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." ) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: From 7c174e047be249855ed962574f20213e0c2a0b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 23 Sep 2025 08:02:30 -0600 Subject: [PATCH 48/62] =?UTF-8?q?=F0=9F=93=8C=20Pin=20vLLM=20version=20(#4?= =?UTF-8?q?122)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/speeding_up_training.md | 8 +------- docs/source/vllm_integration.md | 6 ++++++ examples/scripts/evals/judge_tldr.py | 3 +-- examples/scripts/rloo.py | 3 +-- setup.cfg | 2 +- trl/extras/vllm_client.py | 2 +- trl/import_utils.py | 12 +++++++++++- trl/trainer/grpo_trainer.py | 2 +- trl/trainer/online_dpo_config.py | 2 +- trl/trainer/online_dpo_trainer.py | 2 +- trl/trainer/rloo_trainer.py | 2 +- 11 files changed, 26 insertions(+), 18 deletions(-) diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md index 6a3392aa6f9..57586295f8f 100644 --- a/docs/source/speeding_up_training.md +++ b/docs/source/speeding_up_training.md @@ -14,13 +14,7 @@ To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm) To use [vLLM](https://github.com/vllm-project/vllm), first install it using: ```bash -pip install vllm -``` - -or - -```bash -pip install "trl[vllm]" +pip install trl[vllm] ``` diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 0cf92df944e..9240aed62ce 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -2,6 +2,12 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥 + + +TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. Please ensure you have one of these versions installed to avoid compatibility issues. + + + ## 🚀 How can I use vLLM with TRL to speed up training? 💡 **Note**: Resources required for this specific example: a single node with 8 GPUs. 
diff --git a/examples/scripts/evals/judge_tldr.py b/examples/scripts/evals/judge_tldr.py
index e803f335be8..286dfa1576f 100644
--- a/examples/scripts/evals/judge_tldr.py
+++ b/examples/scripts/evals/judge_tldr.py
@@ -14,8 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "trl",
-#     "vllm",
+#     "trl[vllm]",
 # ]
 # ///
 
diff --git a/examples/scripts/rloo.py b/examples/scripts/rloo.py
index e9fb222f63b..bc599f7b9c0 100644
--- a/examples/scripts/rloo.py
+++ b/examples/scripts/rloo.py
@@ -14,12 +14,11 @@
 
 # /// script
 # dependencies = [
-#     "trl",
+#     "trl[vllm]",
 #     "peft",
 #     "math-verify",
 #     "latex2sympy2_extended",
 #     "trackio",
-#     "vllm",
 #     "kernels",
 # ]
 # ///
diff --git a/setup.cfg b/setup.cfg
index fa431a70501..67f304f4088 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,7 +64,7 @@ test =
 vllm =
     # vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added:
     # see https://github.com/vllm-project/vllm/pull/13164
-    vllm>=0.10.0; python_version < "3.13"
+    vllm>=0.10.0,<=0.10.2; python_version < "3.13"
     fastapi; python_version < "3.13"
     pydantic; python_version < "3.13"
     requests; python_version < "3.13"
diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 177c5815ba5..0932697d6ee 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -114,7 +114,7 @@ def __init__(
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
-            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
+            raise ImportError("vLLM is not installed. Please install it with `pip install trl[vllm]`.")
 
         self.session = requests.Session()
 
diff --git a/trl/import_utils.py b/trl/import_utils.py
index e495a845dae..0f15a17222c 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -14,6 +14,7 @@
 
 import importlib
 import os
+import warnings
 from itertools import chain
 from types import ModuleType
 from typing import Any
@@ -35,7 +36,7 @@
 _requests_available = _is_package_available("requests")
 _unsloth_available = _is_package_available("unsloth")
 _uvicorn_available = _is_package_available("uvicorn")
-_vllm_available = _is_package_available("vllm")
+_vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
 _vllm_ascend_available = _is_package_available("vllm_ascend")
 _weave_available = _is_package_available("weave")
 
@@ -81,6 +82,15 @@ def is_uvicorn_available() -> bool:
 
 
 def is_vllm_available() -> bool:
+    if _vllm_available and (
+        version.parse(_vllm_version) < version.parse("0.10.0")
+        or version.parse(_vllm_version) > version.parse("0.10.2")
+    ):
+        warnings.warn(
+            "TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. You have version "
+            f"{_vllm_version} installed. We recommend to install one of these versions to avoid compatibility issues.",
+            UserWarning,
+        )
     return _vllm_available
 
 
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 9088ea39a83..2df4aa4f6ab 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -480,7 +480,7 @@ def __init__(
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install vllm` to use it."
+                    "`pip install trl[vllm]` to use it."
) if self.vllm_mode == "server": diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index ae0eaace284..67dfa3b25f8 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -293,7 +293,7 @@ class may differ from those in [`~transformers.TrainingArguments`]. default=False, metadata={ "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " - "(`pip install vllm`)." + "(`pip install trl[vllm]`)." }, ) vllm_model_impl: str = field( diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index aed16895e1d..395a5d285e0 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -484,7 +484,7 @@ def __init__( if not is_vllm_available(): raise ImportError( "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install vllm` to use it." + "`pip install trl[vllm]` to use it." ) if self.vllm_mode == "server": diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 2f649a98764..093e3a784f8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -552,7 +552,7 @@ def decode(example, tokenizer): if not is_vllm_available(): raise ImportError( "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install vllm` to use it." + "`pip install trl[vllm]` to use it." ) if self.vllm_mode == "server": From 8843b7b92afcdbec6b91d7695053dd0870fa2373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 23 Sep 2025 08:12:13 -0600 Subject: [PATCH 49/62] =?UTF-8?q?=F0=9F=91=8B=20Remove=20`backend`=20param?= =?UTF-8?q?eter=20from=20`GuidedDecodingParams`=20(#4123)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/scripts/vllm_serve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 0acfdaa8616..3e448aedf13 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -557,7 +557,7 @@ async def generate(request: GenerateRequest): # Guided decoding, if enabled if request.guided_decoding_regex is not None: - guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex) + guided_decoding = GuidedDecodingParams(regex=request.guided_decoding_regex) else: guided_decoding = None From 95fe6b8be837049325ebcfcd8709725f6c456a14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 23 Sep 2025 08:50:52 -0600 Subject: [PATCH 50/62] =?UTF-8?q?=F0=9F=A7=B9=20Remove=20`max=5Fbatch=5Fto?= =?UTF-8?q?kens`,=20`num=5Fblocks`=20and=20`block=5Fsize`=20from=20generat?= =?UTF-8?q?ion=20kwargs=20(#4065)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/grpo_trainer.py | 5 +---- trl/trainer/online_dpo_trainer.py | 5 +---- trl/trainer/rloo_trainer.py | 5 +---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 2df4aa4f6ab..51db9a1f505 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -566,10 +566,6 @@ def __init__( "repetition_penalty": self.repetition_penalty, "cache_implementation": args.cache_implementation, } - if args.use_transformers_paged: - generation_kwargs["max_batch_tokens"] = 512 - 
generation_kwargs["num_blocks"] = 1024 - generation_kwargs["block_size"] = 128 if args.generation_kwargs is not None: generation_kwargs.update(args.generation_kwargs) self.generation_config = GenerationConfig(**generation_kwargs) @@ -1306,6 +1302,7 @@ def _generate_and_score_completions( all_outputs = unwrapped_model.generate_batch( paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 395a5d285e0..65550bd41ca 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -572,10 +572,6 @@ def __init__( generation_kwargs["min_p"] = self.min_p if args.generation_kwargs is not None: generation_kwargs.update(args.generation_kwargs) - if self.use_transformers_paged: - generation_kwargs["max_batch_tokens"] = 512 - generation_kwargs["num_blocks"] = 1024 - generation_kwargs["block_size"] = 128 # Remove None values generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} self.generation_config = GenerationConfig(**generation_kwargs) @@ -1112,6 +1108,7 @@ def _generate(self, model, prompts, images=None): generation_config=self.generation_config, progress_bar=False, ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 093e3a784f8..e87ecf95b37 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -638,10 +638,6 @@ def decode(example, tokenizer): "repetition_penalty": self.repetition_penalty, "cache_implementation": args.cache_implementation, } - if args.use_transformers_paged: - generation_kwargs["max_batch_tokens"] = 512 - generation_kwargs["num_blocks"] = 1024 - generation_kwargs["block_size"] = 128 if args.generation_kwargs is not None: generation_kwargs.update(args.generation_kwargs) self.generation_config = GenerationConfig(**generation_kwargs) @@ -1284,6 +1280,7 @@ def _generate_and_score_completions( all_outputs = unwrapped_model.generate_batch( paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") From 0641915daf157a4243910e2b25e4617f2a5381b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:18:33 +0000 Subject: [PATCH 51/62] fix template file --- trl/trainer/base_trainer.py | 2 ++ trl/trainer/reward_trainer.py | 68 ++--------------------------------- 2 files changed, 4 insertions(+), 66 deletions(-) diff --git a/trl/trainer/base_trainer.py b/trl/trainer/base_trainer.py index e7cb05def71..bb88cbfc934 
100644 --- a/trl/trainer/base_trainer.py +++ b/trl/trainer/base_trainer.py @@ -28,6 +28,7 @@ class BaseTrainer(Trainer): _tag_names = [] _name = "Base" _paper = {} + _template_file = None def create_model_card( self, @@ -78,6 +79,7 @@ def create_model_card( comet_url=get_comet_experiment_url(), trainer_name=self._name, trainer_citation=self._paper.get("citation"), + template_file=self._template_file, paper_title=self._paper.get("title"), paper_id=self._paper.get("id"), ) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 957604157fb..56ebc03ec3d 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -30,26 +30,16 @@ DataCollator, PreTrainedModel, PreTrainedTokenizerBase, - ProcessorMixin, ) from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..data_utils import maybe_apply_chat_template from ..models import prepare_peft_model from .base_trainer import BaseTrainer from .reward_config import RewardConfig -from .utils import ( - RewardDataCollatorWithPadding, - compute_accuracy, - decode_and_strip_padding, - disable_dropout_in_model, - log_table_to_comet_experiment, - print_rich_table, -) - +from .utils import disable_dropout_in_model if is_peft_available(): from peft import PeftConfig, PeftModel @@ -260,6 +250,7 @@ class RewardTrainer(BaseTrainer): _tag_names = ["trl", "reward-trainer"] _name = "Reward" + _template_file = "rm_model_card.md" def __init__( self, @@ -600,58 +591,3 @@ def _save_checkpoint(self, model, trial): model_name = self.args.hub_model_id.split("/")[-1] self.create_model_card(model_name=model_name) super()._save_checkpoint(model, trial) - - def create_model_card( - self, - model_name: Optional[str] = None, - dataset_name: Optional[str] = None, - tags: Union[str, list[str], None] = None, - ): - """ - Creates a draft of a model card using the information available to the `Trainer`. - - Args: - model_name (`str`, *optional*): - Name of the model. - dataset_name (`str`, *optional*): - Name of the dataset used for training. - tags (`str`, `list[str]`, *optional*): - Tags to be associated with the model card. 
- """ - if not self.is_world_process_zero(): - return - - if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): - base_model = self.model.config._name_or_path - else: - base_model = None - - # normalize `tags` to a mutable set - if tags is None: - tags = set() - elif isinstance(tags, str): - tags = {tags} - else: - tags = set(tags) - - if hasattr(self.model.config, "unsloth_version"): - tags.add("unsloth") - - if "JOB_ID" in os.environ: - tags.add("hf_jobs") - - tags.update(self._tag_names) - - model_card = generate_model_card( - base_model=base_model, - model_name=model_name, - hub_model_id=self.hub_model_id, - dataset_name=dataset_name, - tags=list(tags), - wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, - comet_url=get_comet_experiment_url(), - trainer_name="Reward", - template_file="rm_model_card.md", - ) - - model_card.save(os.path.join(self.args.output_dir, "README.md")) From adcb80bb13e3588906265a5faf8641a481fc6c62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:35:12 +0000 Subject: [PATCH 52/62] fix imports --- trl/trainer/reward_trainer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 56ebc03ec3d..17afc7187cb 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import logging +import os +import re from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -36,16 +40,16 @@ from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available -from ..models import prepare_peft_model +from ..data_utils import is_conversational +from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .base_trainer import BaseTrainer from .reward_config import RewardConfig -from .utils import disable_dropout_in_model +from .utils import disable_dropout_in_model, pad, remove_none_values + if is_peft_available(): from peft import PeftConfig, PeftModel -if is_wandb_available(): - import wandb logger = get_logger(__name__) From 1c4b2950f30860d1490e2a19c400d7cde834053d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:37:21 +0000 Subject: [PATCH 53/62] revert modif online dpo --- trl/trainer/online_dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 6edf3b0fb90..4ad26b72a77 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -68,6 +68,7 @@ disable_dropout_in_model, empty_cache, ensure_master_addr_port, + pad, prepare_deepspeed, truncate_right, ) From cb61502822a50ce09995389e4da170dce0883ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:45:25 +0000 Subject: [PATCH 54/62] #4048 and #4124 --- trl/scripts/reward.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/trl/scripts/reward.py b/trl/scripts/reward.py index 62567c22633..f34b04e80ee 100644 --- a/trl/scripts/reward.py +++ b/trl/scripts/reward.py @@ -53,6 +53,7 @@ def main(script_args, training_args, model_args, dataset_args): "Both `datasets` and `dataset_name` are provided. 
The `datasets` argument will be used to load the " "dataset and `dataset_name` will be ignored." ) + dataset = get_dataset(dataset_args) elif dataset_args.datasets and not script_args.dataset_name: dataset = get_dataset(dataset_args) elif not dataset_args.datasets and script_args.dataset_name: @@ -74,10 +75,16 @@ def main(script_args, training_args, model_args, dataset_args): # Train the model trainer.train() + # Log training complete + trainer.accelerator.print("✅ Training completed.") + # Save and push to Hub trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.") + if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") def make_parser(subparsers: Optional[argparse._SubParsersAction] = None): From 72b4ad3465f5a14fe5196a6b6e2831e5fb6571ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:46:22 +0000 Subject: [PATCH 55/62] #4178 --- docs/source/reward_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index a8dee9e895a..0f536e6d267 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -191,7 +191,7 @@ trainer = RewardTrainer( trainer.train() ``` -You can also continue training your [`peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed. +You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed. ```python from datasets import load_dataset From 8e4f332bda73c6ddd5a2e3c0c480f37ecd4d9d9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:51:39 +0000 Subject: [PATCH 56/62] #4161 --- docs/source/reward_trainer.md | 22 ++++++++-------------- trl/models/utils.py | 7 ++----- trl/trainer/utils.py | 16 +++++----------- 3 files changed, 15 insertions(+), 30 deletions(-) diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index 0f536e6d267..68e127343df 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -127,11 +127,8 @@ $$ \mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right]. $$ - - -The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recomended value is `1e-2`. - - +> [!TIP] +> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. 
This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recomended value is `1e-2`. ## Logged metrics @@ -209,15 +206,12 @@ trainer = RewardTrainer( trainer.train() ``` - - -When training adapters, you typically use a higher learning rate (≈1e‑3) since only new parameters are being learned. - -```python -RewardConfig(learning_rate=1e-3, ...) -``` - - +> [!TIP] +> When training adapters, you typically use a higher learning rate (≈1e‑3) since only new parameters are being learned. +> +> ```python +> RewardConfig(learning_rate=1e-3, ...) +> ``` ## Tool Calling with Reward Modeling diff --git a/trl/models/utils.py b/trl/models/utils.py index 22ea572ca06..efdba75fbda 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -94,11 +94,8 @@ def setup_chat_format( Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. - - - This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead. - - + > [!WARNING] + > This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead. If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`. diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index f89e97de1af..a12fdec7b44 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -220,13 +220,10 @@ class RewardDataCollatorWithPadding: r""" Reward DataCollator class that pads the inputs to the maximum length of the batch. - - - This class is deprecated and will be removed in version 0.27.0. Please use + > [!WARNING] + > This class is deprecated and will be removed in version 0.27.0. Please use `trl.trainer.reward_trainer.DataCollatorForPreference` instead. - - Args: tokenizer (`PreTrainedTokenizerBase`): The tokenizer used for encoding the data. @@ -1257,12 +1254,9 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize """ Decodes the input tensor and strips the padding tokens. - - - This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy the - code into your codebase and use it from there. - - + > [!WARNING] + > This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy + > the code into your codebase and use it from there. 
Args: inputs (`torch.Tensor`): From 0f3b4f8307891e18b8ce731ffc4f5deac16dd4a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:56:13 +0000 Subject: [PATCH 57/62] #4007 --- trl/trainer/reward_trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 17afc7187cb..b8452cb0adb 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -139,7 +139,7 @@ class DataCollatorForPreference(DataCollatorMixin): pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" - def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: # Convert to tensor chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] @@ -258,9 +258,9 @@ class RewardTrainer(BaseTrainer): def __init__( self, - model: Union[str, nn.Module, PreTrainedModel], + model: Union[str, PreTrainedModel], args: Optional[RewardConfig] = None, - data_collator: Optional[DataCollator] = None, # type: ignore + data_collator: Optional[DataCollator] = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, processing_class: Optional[PreTrainedTokenizerBase] = None, @@ -525,7 +525,13 @@ def _set_signature_columns_if_needed(self): if self._signature_columns is None: self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"] - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs: bool = False, + num_items_in_batch: Optional[torch.Tensor] = None, + ): """ Compute training loss and additionally compute token accuracies """ From ff1175b9bb55cf74c75d44b0074af7b5b05023d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 18:57:14 +0000 Subject: [PATCH 58/62] #4080 --- trl/trainer/reward_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index b8452cb0adb..930d76c1ae3 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -480,7 +480,7 @@ def add_eos(example, eos_token): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - def tokenize(example, processing_class): + def tokenize_fn(example, processing_class): if "prompt" in example: # explicit prompt case example["chosen"] = example["prompt"] + example["chosen"] example["rejected"] = example["prompt"] + example["rejected"] @@ -504,7 +504,7 @@ def tokenize(example, processing_class): } return output - dataset = dataset.map(tokenize, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) # Filter samples that are longer than `max_length` if args.max_length is not None: From 115123284d013951d4dac63115971c7eae996460 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 19:03:07 +0000 Subject: [PATCH 59/62] #4006 --- trl/trainer/reward_trainer.py | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 930d76c1ae3..5408db49967 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -285,8 +285,7 @@ def __init__( if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: pass # dtype is already a torch.dtype or "auto" or None elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype + model_init_kwargs["dtype"] = getattr(torch, dtype) else: raise ValueError( "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " From b7ee764d57ebdddfaa39dfee8b8a3a58a97a0a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 19:35:53 +0000 Subject: [PATCH 60/62] fix: correct spelling of 'recommended' in reward_trainer.md --- docs/source/reward_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index 68e127343df..1179ab220f8 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -128,7 +128,7 @@ $$ $$ > [!TIP] -> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recomended value is `1e-2`. +> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`. 
## Logged metrics From b82810045bee48f50177cdd182b5836289b94bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 20:01:06 +0000 Subject: [PATCH 61/62] update tiny generation script --- scripts/generate_tiny_models.py | 43 +++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 6053687769c..b6cba71e963 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -243,15 +243,20 @@ def init_weights_tiny_model(model): ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - config.hidden_size = 16 - config.num_attention_heads = 4 - config.num_key_value_heads = 2 - config.num_hidden_layers = 2 - config.num_labels = 1 + kwargs = { + "num_labels": 1, + "hidden_size": 16, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_hidden_layers": 2, + "intermediate_size": 32, + } + config = AutoConfig.from_pretrained(model_id, **kwargs) + # Bug in transformers: it ignores num_hidden_layers to build layer_types if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"): config.layer_types = config.layer_types[:2] - model = model_class(config) + model = model_class(config).to(dtype=torch.bfloat16) + init_weights_tiny_model(model) push_to_hub(model, tokenizer, "tiny", suffix) # MoE Reward models @@ -259,15 +264,19 @@ def init_weights_tiny_model(model): ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) - config = AutoConfig.from_pretrained(model_id) - config.hidden_size = 16 - config.num_attention_heads = 4 - config.num_hidden_layers = 2 - config.num_labels = 1 - config.num_experts = 4 - config.num_experts_per_tok = 2 - model = model_class(config) - push_to_hub(model, tokenizer, "tiny", suffix) + kwargs = { + "num_labels": 1, + "hidden_size": 16, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_hidden_layers": 2, + "intermediate_size": 32, + "num_experts": 4, + "num_experts_per_tok": 2, + } + config = AutoConfig.from_pretrained(model_id, **kwargs) + model = model_class(config).to(dtype=torch.bfloat16) + push_to_hub(model, tokenizer, "tiny", suffix, force=True) # Encoder-decoder models @@ -332,7 +341,5 @@ def init_weights_tiny_model(model): kwargs["perceiver_config"] = {"hidden_size": 16} config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) - model = model_class(config).to(dtype=torch.bfloat16) - push_to_hub(model, processor, "tiny") From f16282b2203177e4ff420488e12c92176b7451bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 20:01:24 +0000 Subject: [PATCH 62/62] rm force --- scripts/generate_tiny_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b6cba71e963..6c9f09b2834 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -276,7 +276,7 @@ def init_weights_tiny_model(model): } config = AutoConfig.from_pretrained(model_id, **kwargs) model = model_class(config).to(dtype=torch.bfloat16) - push_to_hub(model, tokenizer, "tiny", suffix, force=True) + push_to_hub(model, tokenizer, "tiny", suffix) # Encoder-decoder models
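
For quick reference, the tiny-model recipe that patches 61 and 62 converge on in `scripts/generate_tiny_models.py` can be exercised as a standalone snippet. The sketch below is illustrative rather than a copy of the script: the model id is only an example, the script-local `push_to_hub` and `init_weights_tiny_model` helpers are omitted, and only the public `transformers` API is used.

```python
# Standalone sketch (not part of the patch series) of the config-shrinking pattern
# used in scripts/generate_tiny_models.py. Model id and printout are illustrative.
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

model_id = "Qwen/Qwen3-4B"  # any checkpoint with a sequence-classification mapping

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Size overrides are passed straight to `from_pretrained` instead of being set on the
# config afterwards, matching the updated script.
tiny_kwargs = {
    "num_labels": 1,          # a reward model emits a single scalar per sequence
    "hidden_size": 16,
    "num_attention_heads": 4,
    "num_key_value_heads": 2,
    "num_hidden_layers": 2,
    "intermediate_size": 32,
}
config = AutoConfig.from_pretrained(model_id, **tiny_kwargs)

# Same workaround as the script: some configs derive `layer_types` from the original
# depth, so truncate it to the tiny number of layers.
if getattr(config, "layer_types", None) is not None:
    config.layer_types = config.layer_types[: tiny_kwargs["num_hidden_layers"]]

# Build the tiny model from the shrunk config (randomly initialized) in bfloat16.
model = AutoModelForSequenceClassification.from_config(config).to(dtype=torch.bfloat16)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```

Passing the size overrides as keyword arguments to `AutoConfig.from_pretrained` keeps them in one place and also covers `intermediate_size`, which the previous version of the script left at the checkpoint's full-size value.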
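
Similarly, the doc line corrected in patch 60 describes the reward-centering auxiliary loss only in prose. A minimal usage sketch follows, with the `trl-lib/ultrafeedback_binarized` dataset and the `Qwen/Qwen3-0.6B` checkpoint used purely as placeholders.

```python
# Illustrative only: enable the reward-centering auxiliary loss through RewardConfig.
# Dataset and model identifiers are placeholders; swap in your own.
from datasets import load_dataset
from trl import RewardConfig, RewardTrainer

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(
    output_dir="Qwen3-0.6B-Reward",
    center_rewards_coefficient=1e-2,  # recommended value per the corrected doc line
)

trainer = RewardTrainer(
    model="Qwen/Qwen3-0.6B",
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```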