diff --git a/README.md b/README.md index 2df75017d02..b7a0668adc9 100644 --- a/README.md +++ b/README.md @@ -136,23 +136,13 @@ trainer.train() Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): ```python -from trl import RewardConfig, RewardTrainer +from trl import RewardTrainer from datasets import load_dataset -from transformers import AutoModelForSequenceClassification, AutoTokenizer - -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") -model = AutoModelForSequenceClassification.from_pretrained( - "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1 -) -model.config.pad_token_id = tokenizer.pad_token_id dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2) trainer = RewardTrainer( - args=training_args, - model=model, - processing_class=tokenizer, + model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset, ) trainer.train() diff --git a/docs/source/clis.md b/docs/source/clis.md index 54b7501c1aa..d6e433cada7 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -9,6 +9,7 @@ Currently supported commands are: - `trl dpo`: fine-tune a LLM with DPO - `trl grpo`: fine-tune a LLM with GRPO - `trl kto`: fine-tune a LLM with KTO +- `trl reward`: train a Reward Model - `trl rloo`: fine-tune a LLM with RLOO - `trl sft`: fine-tune a LLM with SFT @@ -41,6 +42,15 @@ trl dpo \ --dataset_name anthropic/hh-rlhf ``` + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized +``` + @@ -78,6 +88,21 @@ Launch with: trl dpo --config dpo_config.yaml ``` + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -138,6 +163,33 @@ Launch with: ```bash trl dpo --config dpo_config.yaml ``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_processes 4 +``` + + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +num_processes: 4 +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -217,6 +269,33 @@ Launch with: ```bash trl dpo --config dpo_config.yaml ``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + @@ -224,7 +303,7 @@ trl dpo --config dpo_config.yaml You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data. 
- + ```yaml @@ -258,6 +337,23 @@ Launch with: trl dpo --config dpo_config.yaml ``` + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: trl-lib/tldr-preference + - path: trl-lib/lm-human-preferences-sentiment +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 3ddfe06f72f..d4b69cb44a3 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -533,3 +533,53 @@ training_args = CPOConfig( ... ) ``` + +## Reward Modeling + +Papers relating to the [`RewardTrainer`] + +### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking + +**šŸ“œ Paper**: https://huggingface.co/papers/2312.09244 + +This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination. + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right]. +$$ + +To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows: + +```python +from trl import RewardConfig + +training_args = RewardConfig( + center_rewards_coefficient=0.01, # Ī· in the paper + ... +) +``` + +### Llama 2: Open Foundation and Fine-Tuned Chat Models + +**šŸ“œ Paper**: https://huggingface.co/papers/2307.09288 + +In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings. + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right]. +$$ + +You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up a the "Margin Small" setting of the paper. + +```python +def add_margin(example): + preference_to_margin = { + "significantly better": 1.0, + "better": 2.0/3.0, + "slightly better": 1.0/3.0, + "negligibly better / unsure": 0.0, + } + return {"margin": preference_to_margin[example["preference_label"]]} + +dataset = dataset.map(add_margin) +``` diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index 908230131e0..4898019bd9c 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -1,6 +1,6 @@ # Quickstart -TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO). +TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO). 
## Quick Examples @@ -51,6 +51,21 @@ trainer = DPOTrainer( trainer.train() ``` +### Reward Modeling + +```python +from trl import RewardTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, +) +trainer.train() +``` + ## Command Line Interface Skip the code entirely - train directly from your terminal: @@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \ # DPO: Align with preferences trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ --dataset_name trl-lib/ultrafeedback_binarized + +# Reward: Train a reward model +trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized ``` ## What's Next? diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md index 972c45b0c88..1179ab220f8 100644 --- a/docs/source/reward_trainer.md +++ b/docs/source/reward_trainer.md @@ -2,84 +2,225 @@ [![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl) -TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. +## Overview -Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). +TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models. -## Expected dataset type +This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada). -The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`). -The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +## Quick start -You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`. +This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), large-scale, fine-grained, diverse preference dataset. -## Using the `RewardTrainer` +```python +from trl import RewardTrainer +from datasets import load_dataset -After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from šŸ¤— Transformers. -You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. +trainer = RewardTrainer( + model="Qwen/Qwen3-0.6B", + train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"), +) +trainer.train() +``` -### Leveraging šŸ¤— PEFT to train a reward model + -Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! 
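+Once training completes, the saved checkpoint is an ordinary sequence-classification model. As a minimal sketch (the `Qwen3-0.6B-Reward` path below is illustrative; substitute the `output_dir` you actually trained with), you can load it back and score a candidate response:
+
+```python
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+# Illustrative checkpoint path; replace with your own output_dir
+model = AutoModelForSequenceClassification.from_pretrained("Qwen3-0.6B-Reward")
+tokenizer = AutoTokenizer.from_pretrained("Qwen3-0.6B-Reward")
+
+messages = [
+    {"role": "user", "content": "What color is the sky?"},
+    {"role": "assistant", "content": "It is blue."},
+]
+text = tokenizer.apply_chat_template(messages, tokenize=False)
+inputs = tokenizer(text, return_tensors="pt")
+with torch.no_grad():
+    # num_labels=1, so the model outputs a single logit; higher means more preferred
+    score = model(**inputs).logits[0, 0].item()
+print(score)
+```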
+## Expected dataset type and format + +[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets type (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. ```python -from peft import LoraConfig, TaskType -from transformers import AutoModelForSequenceClassification, AutoTokenizer -from trl import RewardTrainer, RewardConfig - -model = AutoModelForSequenceClassification.from_pretrained("gpt2") -peft_config = LoraConfig( - task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, -) +# Standard preference (implicit prompt) +{"chosen": "The sky is blue.", + "rejected": "The sky is green."} + +# Conversational preference (implicit prompt) +{"chosen": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}]} + +# Standard preference (explicit prompt) +{"prompt": "The sky is", + "chosen": " blue.", + "rejected": " green."} + +# Conversational preference (explicit prompt) +{"prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}]} +``` -... +If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset: -trainer = RewardTrainer( - model=model, - args=training_args, - processing_class=tokenizer, - train_dataset=dataset, - peft_config=peft_config, -) +```python +from datasets import load_dataset +import json -trainer.train() +dataset = load_dataset("lmarena-ai/arena-human-preference-55k") + +# Filter out ties +dataset = dataset.filter(lambda example: example["winner_tie"] == 0) + +# Create 'chosen' and 'rejected' fields based on the winner column +def response_a_b_to_chosen_rejected(example): + if example["winner_model_a"] == 1: + example["chosen"] = example["response_a"] + example["rejected"] = example["response_b"] + else: + example["chosen"] = example["response_b"] + example["rejected"] = example["response_a"] + return example + +dataset = dataset.map(response_a_b_to_chosen_rejected) +# Convert to conversational format +def make_conversation(example): + prompt = json.loads(example["prompt"])[0] # '["What color is the sky?"]' -> "What color is the sky?" 
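+    # "chosen" and "rejected" are JSON-encoded lists of strings as well, so decode them
+    # the same way and keep the first element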
+ chosen = json.loads(example["chosen"])[0] + rejected = json.loads(example["rejected"])[0] + return { + "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}], + "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}], + } + + +dataset = dataset.map(make_conversation) + +# Keep only necessary columns +dataset = dataset.select_columns(["chosen", "rejected"]) + +print(next(iter(dataset["train"]))) ``` -### Adding a margin to the loss +```json +{ + "chosen": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."}, + ], + "rejected": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."}, + ], +} +``` + +## Looking deeper into the training method + +Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences. + +This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**. + +### Preprocessing and tokenization + +During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference). +The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization. + +### Computing the loss + +Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)āˆ’r(x, y^-)) \\), where \\( σ \\) is the sigmoid function. + +The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences: + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right]. +$$ + +> [!TIP] +> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`. + +## Logged metrics -As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. 
The reward collator will automatically pass it through and the loss will be computed accordingly. +While training and evaluating we record the following reward metrics: + +* `global_step`: The total number of optimizer steps taken so far. +* `epoch`: The current epoch number, based on dataset iteration. +* `num_tokens`: The total number of tokens processed so far. +* `loss`: The average loss over the last logging interval. +* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval. +* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval. +* `mean_reward`: The average reward score assigned by the model over the last logging interval. +* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval. +* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval. +* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used. +* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping. + +## Customization + +### Model initialization + +You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to ```python -def add_margin(row): - # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin - return {'margin': row['score_chosen'] - row['score_rejected']} +model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16) +``` -dataset = dataset.map(add_margin) +you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`]. + +```python +from trl import RewardConfig + +training_args = RewardConfig( + model_init_kwargs={"dtype": torch.bfloat16}, +) ``` -### Centering rewards +Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1. + +### Train adapters with PEFT -In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it. +We support tight integration with šŸ¤— PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model. -[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs: +```python +from datasets import load_dataset +from trl import RewardTrainer +from peft import LoraConfig -$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$ +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") -This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`). 
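+# With LoRA, only the low-rank adapter weights and any modules listed in `modules_to_save`
+# (here the `score` head) are updated during training; the base model weights stay frozen.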
+trainer = RewardTrainer(
+    "Qwen/Qwen3-4B",
+    train_dataset=dataset,
+    peft_config=LoraConfig(modules_to_save=["score"])  # important to include the score head when base model is not a sequence classification model
+)
+
+trainer.train()
+```
+
+You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument.
 
 ```python
-training_args = RewardConfig(
-    center_rewards_coefficient=0.01,
-    ...
+from datasets import load_dataset
+from trl import RewardTrainer
+from peft import AutoPeftModelForCausalLM
+
+model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
+dataset = load_dataset("trl-lib/Capybara", split="train")
+
+trainer = RewardTrainer(
+    model=model,
+    train_dataset=dataset,
 )
+
+trainer.train()
 ```
 
-For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
+> [!TIP]
+> When training adapters, you typically use a higher learning rate (ā‰ˆ1e-3) since only new parameters are being learned.
+>
+> ```python
+> RewardConfig(learning_rate=1e-3, ...)
+> ```
+
+## Tool Calling with Reward Modeling
+
+The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
+
+* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
+* The list of available tools in the `tools` column, typically provided as JSON schemas
+
+For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
 
 ## RewardTrainer
@@ -91,3 +232,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/tr
 ## RewardConfig
 
 [[autodoc]] RewardConfig
+
+## DataCollatorForPreference
+
+[[autodoc]] trainer.reward_trainer.DataCollatorForPreference
diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index cc292d65420..e855bd02b02 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -23,7 +23,7 @@ trainer = SFTTrainer(
 trainer.train()
 ```
 
-
+
 
 ## Expected dataset type and format
diff --git a/docs/source/trackio_integration.md b/docs/source/trackio_integration.md
index 9f5fe693a9f..4e93120fe19 100644
--- a/docs/source/trackio_integration.md
+++ b/docs/source/trackio_integration.md
@@ -64,4 +64,4 @@ trainer.train()
 will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.
- + diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 7a3295d2d20..6c9f09b2834 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -43,6 +43,7 @@ GPT2LMHeadModel, GPTNeoXConfig, GPTNeoXForCausalLM, + GPTNeoXForSequenceClassification, GptOssConfig, GptOssForCausalLM, Idefics2Config, @@ -73,6 +74,7 @@ Qwen3ForSequenceClassification, Qwen3MoeConfig, Qwen3MoeForCausalLM, + Qwen3MoeForSequenceClassification, SmolVLMForConditionalGeneration, T5ForConditionalGeneration, ) @@ -234,22 +236,46 @@ def init_weights_tiny_model(model): push_to_hub(model, tokenizer, "small") # Reward models -for model_id, config_class, model_class, suffix in [ - ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"), - ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"), - ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None), +for model_id, model_class, suffix in [ + ("EleutherAI/pythia-14m", GPTNeoXForSequenceClassification, None), + ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"), + ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"), + ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None), ]: tokenizer = AutoTokenizer.from_pretrained(model_id) - config = config_class( - vocab_size=len(tokenizer.vocab), - hidden_size=8, - num_attention_heads=4, - num_key_value_heads=2, - num_hidden_layers=2, - intermediate_size=32, - num_labels=1, - ) - model = model_class(config) + kwargs = { + "num_labels": 1, + "hidden_size": 16, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_hidden_layers": 2, + "intermediate_size": 32, + } + config = AutoConfig.from_pretrained(model_id, **kwargs) + # Bug in transformers: it ignores num_hidden_layers to build layer_types + if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"): + config.layer_types = config.layer_types[:2] + model = model_class(config).to(dtype=torch.bfloat16) + init_weights_tiny_model(model) + push_to_hub(model, tokenizer, "tiny", suffix) + +# MoE Reward models +for model_id, model_class, suffix in [ + ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + kwargs = { + "num_labels": 1, + "hidden_size": 16, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_hidden_layers": 2, + "intermediate_size": 32, + "num_experts": 4, + "num_experts_per_tok": 2, + } + config = AutoConfig.from_pretrained(model_id, **kwargs) + model = model_class(config).to(dtype=torch.bfloat16) push_to_hub(model, tokenizer, "tiny", suffix) @@ -315,7 +341,5 @@ def init_weights_tiny_model(model): kwargs["perceiver_config"] = {"hidden_size": 16} config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) - model = model_class(config).to(dtype=torch.bfloat16) - push_to_hub(model, processor, "tiny") diff --git a/tests/test_cli.py b/tests/test_cli.py index 2f8891c7333..23b5d6bcff7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -67,6 +67,13 @@ def test_kto(self): with patch("sys.argv", command.split(" ")): main() + def test_reward(self): + from trl.cli import main + + command = f"trl reward --output_dir {self.tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_implicit_prompt_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + 
def test_rloo(self): from trl.cli import main diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index a92c2c8cb0f..b4d53e16941 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -12,217 +12,823 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pathlib +import unittest import torch -from datasets import Dataset, load_dataset +from datasets import load_dataset +from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template -from trl.trainer.reward_trainer import _tokenize +from trl import RewardConfig, RewardTrainer +from trl.trainer.reward_trainer import DataCollatorForPreference from .testing_utils import TrlTestCase if is_peft_available(): - from peft import LoraConfig, TaskType + from peft import LoraConfig, PeftModel, get_peft_model + + +class TestDataCollatorForPreference(TrlTestCase): + def test_basic_padding(self): + """Test basic padding functionality without completion masks.""" + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + + def test_pad_to_multiple_of(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]]) + ) + + def test_single_example(self): + """Test collator with a single example.""" + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + + def test_different_pad_token_id(self): + """Test with different pad token ID.""" + collator = DataCollatorForPreference(pad_token_id=999) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + + def test_collate_with_margin(self): + self.collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1}, + {"chosen_input_ids": [6, 7], 
"rejected_input_ids": [8], "margin": 0.2}, + ] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2])) class RewardTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() - self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id) - self.model.config.pad_token_id = self.tokenizer.pad_token_id - - def test_preprocessing_conversational(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + @parameterized.expand( + [ + ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",), + ("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification",), + ("trl-internal-testing/tiny-LlamaForSequenceClassification-3.2",), + ] + ) + def test_train(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @parameterized.expand( + [ + ("standard_preference",), + ("conversational_preference",), + ("standard_implicit_prompt_preference",), + ("conversational_implicit_prompt_preference",), + ] + ) + def test_train_dataset_types(self, config_name): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, ) - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) - self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:]) - def test_preprocessing_standard(self): - # No chat template, so we load a fresh tokenizer - tokenizer = AutoTokenizer.from_pretrained(self.model_id) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the 
params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_model(self): + # Instantiate the model + model = AutoModelForSequenceClassification.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + ) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_from_causal_lm(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer}) - self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:]) - def test_train_full(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_model_dtype(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + model_init_kwargs={"dtype": torch.float16}, + learning_rate=0.1, + report_to="none", + ) trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, ) + + # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model trainer.train() + # Check that the training loss is not None self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the 
parameters have changed + + # Check the params have changed for n, param in previous_trainable_params.items(): + # For some reasonn model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does + # locally. We ignore this parameter for now + if "layernorm" in n: + continue new_param = trainer.model.get_parameter(n) - if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + # Check the torch dtype + self.assertEqual(new_param.dtype, torch.float16) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_dense_with_peft_config(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") - def test_train_full_pretokenized(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), ) + + # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model trainer.train() + # Check that the training loss is not None self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + + # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") @require_peft - def test_train_lora(self): - peft_config = LoraConfig( - task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + def test_train_moe_with_peft_config(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = 
load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( - model=self.model, + model=model_id, args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset, - peft_config=peft_config, + train_dataset=dataset, + peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]), ) - previous_trainable_params = {} - previous_non_trainable_params = {} - # due to a change in the way the modules to save are dealt in PEFT. - trainable_params_name = ["lora", "modules_to_save"] + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_peft_model(self): + # Get the base model + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + + # Get the base model parameter names + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Turn the model into a peft model + lora_config = LoraConfig() + model = get_peft_model(model, lora_config) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_dense_with_peft_config_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = 
load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + + trainer = RewardTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) - # check gradients are not None - for n, param in trainer.model.named_parameters(): - if any(t in n for t in trainable_params_name): - previous_trainable_params[n] = param.clone() - else: - previous_non_trainable_params[n] = param.clone() + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + # Train the model trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_moe_with_peft_config_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") - # Check that the non trainable parameters have not changed - for n, param in previous_non_trainable_params.items(): + trainer = RewardTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]), + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") @require_peft - def test_train_lora_pretokenized(self): - peft_config = LoraConfig( - 
task_type=TaskType.SEQ_CLS, - inference_mode=False, - r=8, - lora_alpha=32, - lora_dropout=0.1, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") - dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) - dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) + def test_train_with_peft_model_and_gradient_checkpointing(self): + # Get the base model parameter names + model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + model = AutoModelForSequenceClassification.from_pretrained(model_id) + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + model = get_peft_model(model, LoraConfig()) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Verify model is a PeftModel + self.assertIsInstance(trainer.model, PeftModel) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_pretokenized_data(self): + # Get the dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def tokenize_example(example): + return { + "chosen_input_ids": tokenizer(example["chosen"]).input_ids, + "rejected_input_ids": tokenizer(example["rejected"]).input_ids, + } + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"]) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_iterable_dataset(self): + # Get the dataset + dataset = load_dataset( + "trl-internal-testing/zen", "standard_implicit_prompt_preference", 
split="train", streaming=True + ) + + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") trainer = RewardTrainer( - model=self.model, + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_chat_template_kwargs(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5") + # The following template is a simplified version of the Qwen chat template, where an additional argument + # `role_capital` is used to control the capitalization of roles. + tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n\\n" + message.content + "\\n" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}' + + dataset.add_column("chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))]) + + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset, - peft_config=peft_config, + train_dataset=dataset, ) - previous_trainable_params = {} - previous_non_trainable_params = {} - # due to a change in the way the modules to save are dealt in PEFT. 
- trainable_params_name = ["lora", "modules_to_save"] + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") - # check gradients are not None - for n, param in trainer.model.named_parameters(): - if any(t in n for t in trainable_params_name): - previous_trainable_params[n] = param.clone() - else: - previous_non_trainable_params[n] = param.clone() + def test_train_with_set_chat_template_from_model(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none") + # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default + trainer = RewardTrainer( + model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # Check that the parameters have changed + # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. 
+ if n == "gpt_neox.final_layer_norm.bias": + continue + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_set_chat_template_from_path(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"), + report_to="none", + ) + # trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default + trainer = RewardTrainer( + model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification", + args=training_args, + train_dataset=dataset, + ) - # Check that the non trainable parameters have not changed - for n, param in previous_non_trainable_params.items(): + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. + if n == "gpt_neox.final_layer_norm.bias": + continue + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + # Check that the template saved in the output directory is the same as the one used for training + template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" + self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + + with open(template_path) as f: + template_content = f.read() + with open(training_args.chat_template_path) as f: + original_template_content = f.read() + self.assertEqual( + template_content, original_template_content, "Chat template content does not match the original" + ) + + @unittest.skip("Skipping until we have a dataset with tool calls") + def test_train_toolcall_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/toolcall", split="train") - def test_margin(self): - dummy_dataset_dict = { - "input_ids_chosen": [ - torch.LongTensor([0, 1, 2]), - ], - "attention_mask_chosen": [ - torch.LongTensor([1, 1, 1]), - ], - "input_ids_rejected": [ - torch.LongTensor([0, 2]), - ], - "attention_mask_rejected": [ - torch.LongTensor([1, 1]), - ], - "margin": [ - torch.FloatTensor([1.0]), - ], - } - dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + 
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_eval(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Train the model + trainer.train() + + # Check that the eval loss is not None + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + + def test_train_with_multiple_eval_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset={"data1": dataset["test"], "data2": dataset["test"]}, + ) + # Train the model + trainer.train() + + # Check that the eval losses are not None + self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) + self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + + def test_train_with_gradient_checkpointing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_tag_added(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + ) + + for tag in ["reward-trainer", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + @require_peft + def test_tag_added_peft(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + peft_config=LoraConfig(), ) - batch = [dummy_dataset[0]] - batch = trainer.data_collator(batch) - batch = {k: v.to(trainer.model.device) if 
isinstance(v, torch.Tensor) else v for k, v in batch.items()} - loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True) + for tag in ["reward-trainer", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + def test_train_with_margin(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") - l_val = -torch.nn.functional.logsigmoid( - outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] - ).mean() + def add_margin(example): + # dummy margin based on the length of the chosen summary + return {"margin": len(example["chosen"])} - self.assertLess(abs(loss - l_val), 1e-6) + dataset = dataset.map(add_margin) - def test_tags(self): - dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + # Initialize the trainer training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") trainer = RewardTrainer( - model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, ) - self.assertEqual(trainer.model.model_tags, trainer._tag_names) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_center_rewards_coefficient(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/trl/cli.py b/trl/cli.py index 9245f31dabd..199b1c26703 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -24,6 +24,7 @@ from .scripts.env import print_env from .scripts.grpo import make_parser as make_grpo_parser from .scripts.kto import make_parser as make_kto_parser +from .scripts.reward import make_parser as make_reward_parser from .scripts.rloo import make_parser as make_rloo_parser from .scripts.sft import make_parser as make_sft_parser from .scripts.utils import TrlParser @@ -45,6 +46,7 @@ def main(): subparsers.add_parser("env", help="Print the environment information") make_grpo_parser(subparsers) make_kto_parser(subparsers) + make_reward_parser(subparsers) make_rloo_parser(subparsers) make_sft_parser(subparsers) 
make_vllm_serve_parser(subparsers) @@ -111,6 +113,15 @@ def main(): args.training_script_args = sys.argv[2:] # remove "trl" and "kto" launch_command(args) # launch training + elif args.command == "reward": + # Get the default args for the launch command + reward_training_script = resources.files("trl.scripts").joinpath("reward.py") + args = launch_command_parser().parse_args([str(reward_training_script)]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "reward" + launch_command(args) # launch training + elif args.command == "rloo": # Get the default args for the launch command rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py") diff --git a/trl/models/utils.py b/trl/models/utils.py index 22ea572ca06..efdba75fbda 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -94,11 +94,8 @@ def setup_chat_format( Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. - - - This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead. - - + > [!WARNING] + > This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead. If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`. diff --git a/trl/scripts/reward.py b/trl/scripts/reward.py new file mode 100644 index 00000000000..f34b04e80ee --- /dev/null +++ b/trl/scripts/reward.py @@ -0,0 +1,109 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import argparse +import os +from typing import Optional + +from accelerate import logging +from datasets import load_dataset + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + RewardConfig, + RewardTrainer, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." 
+ ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the RewardTrainer + trainer = RewardTrainer( + model=model_args.model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("āœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"šŸ’¾ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"šŸ¤— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: Optional[argparse._SubParsersAction] = None): + dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser( + "reward", help="Run the reward training script", dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/trl/templates/rm_model_card.md b/trl/templates/rm_model_card.md new file mode 100644 index 00000000000..685ed007bd5 --- /dev/null +++ b/trl/templates/rm_model_card.md @@ -0,0 +1,55 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_name }} + +This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +text = "The capital of France is Paris." +rewarder = pipeline(model="{{ hub_model_id }}", device="cuda") +output = rewarder(text)[0] +print(output["score"]) +``` + +## Training procedure + +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} + +This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 
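The model card's quick start above uses the `pipeline` API. As a rough, hedged alternative, the reward head can also be queried directly through `AutoModelForSequenceClassification`, which is how the trainer instantiates the model; the hub id below is hypothetical:

```python
# Illustrative sketch only; "your-username/Qwen2.5-0.5B-Reward" is a hypothetical hub id.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "your-username/Qwen2.5-0.5B-Reward"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)  # num_labels=1 reward head

inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")
with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()  # scalar reward score
print(reward)
```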
+ +### Framework versions + +- TRL: {{ trl_version }} +- Transformers: {{ transformers_version }} +- Pytorch: {{ pytorch_version }} +- Datasets: {{ datasets_version }} +- Tokenizers: {{ tokenizers_version }} + +## Citations + +{% if trainer_citation %}Cite {{ trainer_name }} as: + +```bibtex +{{ trainer_citation }} +```{% endif %} + +Cite TRL as: + +```bibtex +{% raw %}@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +}{% endraw %} +``` diff --git a/trl/trainer/base_trainer.py b/trl/trainer/base_trainer.py index e7cb05def71..bb88cbfc934 100644 --- a/trl/trainer/base_trainer.py +++ b/trl/trainer/base_trainer.py @@ -28,6 +28,7 @@ class BaseTrainer(Trainer): _tag_names = [] _name = "Base" _paper = {} + _template_file = None def create_model_card( self, @@ -78,6 +79,7 @@ def create_model_card( comet_url=get_comet_experiment_url(), trainer_name=self._name, trainer_citation=self._paper.get("citation"), + template_file=self._template_file, paper_title=self._paper.get("title"), paper_id=self._paper.get("id"), ) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 9a3aabc39ee..33d248e635e 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional from transformers import TrainingArguments @@ -32,22 +32,53 @@ class may differ from those in [`~transformers.TrainingArguments`]. command line. Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the - limit. This argument is required if you want to use the default data collator. + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want + to include the load balancing/auxilliary loss as a part of the final loss, remember to set + `output_router_logits=True` in this dictionary. + chat_template_path (`str`, *optional*): + If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory + or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must + ensure that any special tokens referenced in the template are added to the tokenizer and that the model's + embedding layer is resized accordingly. disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model. + + > Parameters that control the data preprocessing + dataset_num_proc (`int`, *optional*): Number of processes to use for processing the dataset. + eos_token (`str`, *optional*): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to + `processing_class.eos_token`. + pad_token (`str`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. 
+ max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence + exceeds this value. If `None`, no filtering is applied. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + + > Parameters that control the training + center_rewards_coefficient (`float`, *optional*): Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the - dataset is pretokenized. + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. """ + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-4, + metadata={"help": "The initial learning rate for AdamW."}, + ) logging_steps: float = field( default=10, metadata={ @@ -70,21 +101,59 @@ class may differ from those in [`~transformers.TrainingArguments`]. }, ) - max_length: Optional[int] = field( - default=1024, + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " - "exceed the limit. This argument is required if you want to use the default data collator." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `RewardTrainer` is provided as a string." + }, + ) + chat_template_path: Optional[str] = field( + default=None, + metadata={ + "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local " + "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, " + "you must ensure that any special tokens referenced in the template are added to the tokenizer and " + "that the model's embedding layer is resized accordingly." }, ) disable_dropout: bool = field( default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, + metadata={"help": "Whether to disable dropout in the model."}, ) + + # Parameters that control the data preprocessing dataset_num_proc: Optional[int] = field( default=None, metadata={"help": "Number of processes to use for processing the dataset."}, ) + eos_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." + }, + ) + pad_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied." 
+ }, + ) + pad_to_multiple_of: Optional[int] = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + + # Parameters that control the training center_rewards_coefficient: Optional[float] = field( default=None, metadata={ @@ -92,15 +161,11 @@ class may differ from those in [`~transformers.TrainingArguments`]. "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." }, ) - remove_unused_columns: bool = field( + activation_offloading: bool = field( default=False, - metadata={ - "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " - "if the dataset is pretokenized." - }, + metadata={"help": "Whether to offload the activations to the CPU."}, ) def __post_init__(self): self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 - super().__post_init__() diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 79570eff582..5408db49967 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -12,132 +12,348 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import logging +import os +import re from collections import defaultdict -from dataclasses import FrozenInstanceError, replace +from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Optional, Union -import pandas as pd import torch import torch.nn as nn -from accelerate import PartialState, logging -from accelerate.utils import gather_object -from datasets import Dataset +import transformers +from accelerate import PartialState +from accelerate.logging import get_logger +from datasets import Dataset, IterableDataset from transformers import ( - BaseImageProcessor, + AutoModelForSequenceClassification, + AutoTokenizer, DataCollator, - FeatureExtractionMixin, PreTrainedModel, PreTrainedTokenizerBase, - ProcessorMixin, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_callback import TrainerCallback -from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction -from transformers.utils import is_peft_available, is_rich_available +from transformers.utils import is_peft_available -from ..data_utils import maybe_apply_chat_template -from ..models import prepare_peft_model +from ..data_utils import is_conversational +from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model from .base_trainer import BaseTrainer from .reward_config import RewardConfig -from .utils import ( - RewardDataCollatorWithPadding, - compute_accuracy, - decode_and_strip_padding, - disable_dropout_in_model, - log_table_to_comet_experiment, - print_rich_table, -) +from .utils import disable_dropout_in_model, pad, remove_none_values if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel + + +logger = get_logger(__name__) -logger = logging.get_logger(__name__) +# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly +# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to +# avoid confusing users. 
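Pulling the new `RewardConfig` options together, here is a rough sketch that mirrors the tiny models and datasets used in the tests earlier in this patch; the values are illustrative, not recommendations:

```python
# Hedged sketch of the new RewardConfig options; model/dataset names are the tiny
# internal test assets used elsewhere in this patch, values are illustrative only.
from datasets import load_dataset
from trl import RewardConfig, RewardTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")

training_args = RewardConfig(
    output_dir="tiny-reward-model",
    model_init_kwargs={"dtype": "float32"},  # forwarded to from_pretrained when `model` is a string
    max_length=1024,                         # pairs whose chosen or rejected side is longer are filtered out
    pad_to_multiple_of=8,
    center_rewards_coefficient=0.01,         # auxiliary mean-zero reward term
    report_to="none",
)
trainer = RewardTrainer(
    model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```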
+@contextmanager +def suppress_from_pretrained_warning(logger: logging.Logger): + pattern = re.compile( + r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: " + r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and " + r"inference\.$" + ) + class _Filter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return not pattern.search(record.getMessage()) -def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]: - """Tokenize a batch from a reward modelling dataset.""" - new_examples = { - "input_ids_chosen": [], - "attention_mask_chosen": [], - "input_ids_rejected": [], - "attention_mask_rejected": [], - } - for chosen, rejected in zip(batch["chosen"], batch["rejected"]): - tokenized_chosen = tokenizer(chosen) - tokenized_rejected = tokenizer(rejected) - new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) - new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) - new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) - new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) + f = _Filter() + logger.addFilter(f) + try: + yield + finally: + logger.removeFilter(f) + + +@dataclass +class DataCollatorForPreference(DataCollatorMixin): + """ + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch. - return new_examples + This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and + `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch + corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`. + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + + Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a + `"margin"` key with a tensor of margins. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl.trainer.reward_trainer import DataCollatorForPreference + + >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]])} + + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0}, + ... 
] + >>> collator(examples) + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]]), + 'margin': tensor([0.5, 0.0])} + ``` + """ + + pad_token_id: int + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + # Convert to tensor + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + if "margin" in examples[0]: + margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float) + input_ids = chosen_input_ids + rejected_input_ids + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + output = {} + + # Pad + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + if "margin" in examples[0]: + output["margin"] = margins + return output class RewardTrainer(BaseTrainer): """ - Trainer for custom reward. + Trainer for Outcome-supervised Reward Models (ORM). + + This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from trl import RewardTrainer + from datasets import load_dataset + + dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + + trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` Args: - model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*): - Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`]. + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in + `args.model_init_kwargs`. + - A sequence classification [`~transformers.PreTrainedModel`] object. args ([`RewardConfig`], *optional*): - Training arguments. + Configuration for this trainer. If `None`, a default configuration is used. data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum - length of the sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`], *optional*): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. 
If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]): - Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a - dictionary string to float. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.reward_trainer.DataCollatorForPreference`]. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and + explicit prompt). The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and + `rejected_input_ids` fields. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): + Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with + [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be + set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the + default. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a + boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the + function needs to calculate and return the global summary statistics rather than accumulating the + batch-level statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. 
Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and - return the logits to be used for metrics computation. - peft_config (`dict`, *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be - wrapped with the specified PEFT adapter. + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded + model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration + to ensure that the reward head is properly trained. """ _tag_names = ["trl", "reward-trainer"] _name = "Reward" + _template_file = "rm_model_card.md" def __init__( self, - model: Optional[Union[PreTrainedModel, nn.Module]] = None, + model: Union[str, PreTrainedModel], args: Optional[RewardConfig] = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, + peft_config: Optional["PeftConfig"] = None, ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = RewardConfig(f"{model_name}-Reward") + + # Model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + 
elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: + model_init_kwargs["dtype"] = getattr(torch, dtype) + else: + raise ValueError( + "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + with suppress_from_pretrained_warning(transformers.modeling_utils.logger): + model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model_id) + + # Handle pad token for processors or tokenizers + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = processing_class.convert_tokens_to_ids(eos_token) + if eos_token_id is None: + raise ValueError( + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." + ) + processing_class.eos_token_id = eos_token_id + + if args.chat_template_path is not None: + if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): + with open(args.chat_template_path, encoding="utf-8") as chat_template_file: + processing_class.chat_template = chat_template_file.read() + added_tokens = [] + else: + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) + else: + added_tokens = [] + + # PEFT configuration and model wrapping + if peft_config is not None: + if added_tokens: + # Ensure that the added tokens are trainable + if peft_config.trainable_token_indices is None: + peft_config.trainable_token_indices = {"embed_tokens": added_tokens} + elif "embed_tokens" not in peft_config.trainable_token_indices: + peft_config.trainable_token_indices["embed_tokens"] = added_tokens + else: + peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + + # Ensure that the lm_head is trainable + if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: + logger.warning( + "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " + "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " + "tokens, leading to degraded generation quality. To fix this, add " + "`modules_to_save=['lm_head']` to your PEFT configuration." + ) + + if peft_config.modules_to_save is None: + peft_config.modules_to_save = ["lm_head"] + else: + peft_config.modules_to_save.append("lm_head") + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) @@ -145,78 +361,47 @@ def __init__( if args.disable_dropout: disable_dropout_in_model(model) - if compute_metrics is None: - compute_metrics = compute_accuracy + # Pad token (needed for SequenceClassification models) + # If not provided, use the one from the processing class or the eos token if the processing class does not have + # a pad token. 
+ pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token + pad_token_id = processing_class.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + model.config.pad_token_id = pad_token_id + processing_class.pad_token_id = pad_token_id + # Data collator if data_collator is None: - if processing_class is None: - raise ValueError( - "A processing_class must be specified when using the default RewardDataCollatorWithPadding" - ) + data_collator = DataCollatorForPreference( + pad_token_id=pad_token_id, + pad_to_multiple_of=args.pad_to_multiple_of, + ) - max_length = args.max_length + # Dataset + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - data_collator = RewardDataCollatorWithPadding(processing_class) + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 - if args.remove_unused_columns: - try: # for bc before https://github.com/huggingface/transformers/pull/25435 - args.remove_unused_columns = False - except FrozenInstanceError: - args = replace(args, remove_unused_columns=False) - # warn users - logger.warning( - "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_reward_data_collator = True - else: - self.use_reward_data_collator = False - - # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the - # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result, - # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point - # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's - # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been - # issued. - model.warnings_issued["estimate_tokens"] = True - - if "input_ids_chosen" not in train_dataset.column_names: - with PartialState().main_process_first(): - fn_kwargs = {"tokenizer": processing_class} - train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}) - train_dataset = train_dataset.map( - _tokenize, - batched=True, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - ) - # This filter is important because otherwise you get samples that exceed the model's context length and - # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the - # user might get surprised if N samples are missing from training. 
- train_dataset = train_dataset.filter( - lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length, - num_proc=args.dataset_num_proc, - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class} - ) - eval_dataset = eval_dataset.map( - _tokenize, - fn_kwargs=fn_kwargs, - batched=True, - num_proc=args.dataset_num_proc, - ) - # This filter is important because otherwise you get samples that exceed the model's context length and - # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the - # user might get surprised if N samples are missing from training. - eval_dataset = eval_dataset.filter( - lambda x: len(x["input_ids_chosen"]) <= max_length - and len(x["input_ids_rejected"]) <= max_length, - num_proc=args.dataset_num_proc, - ) + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation super().__init__( model=model, @@ -225,35 +410,140 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty. + self.can_return_loss = True + self.label_names = [] + + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: PreTrainedTokenizerBase, + args: RewardConfig, + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) + + # If the dataset is already preprocessed (tokenized), skip the processing steps. 
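As the comment above notes, `_prepare_dataset` skips tokenization when the `chosen_input_ids` and `rejected_input_ids` columns are already present, so a pre-tokenized dataset can be passed as-is. A minimal hedged sketch (token ids are arbitrary placeholders):

```python
# Hedged sketch: a dataset that already has "chosen_input_ids" and "rejected_input_ids"
# bypasses the chat-template/tokenization steps and goes straight to the collator.
from datasets import Dataset
from trl import RewardTrainer

pretokenized = Dataset.from_dict(
    {
        "chosen_input_ids": [[101, 7, 42, 3], [101, 8, 3]],
        "rejected_input_ids": [[101, 7, 5, 3], [101, 9, 3]],
    }
)
trainer = RewardTrainer(
    model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    train_dataset=pretokenized,
)
```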
+ column_names = list(next(iter(dataset)).keys()) + is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + if not is_processed: + # Add EOS token to the end of the sequences if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if "rejected" in example and not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + return example + + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + **map_kwargs, + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class): + if "prompt" in example: # explicit prompt case + example["chosen"] = example["prompt"] + example["chosen"] + example["rejected"] = example["prompt"] + example["rejected"] + + if is_conversational(example): + chosen_input_ids = processing_class.apply_chat_template( + example["chosen"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + rejected_input_ids = processing_class.apply_chat_template( + example["rejected"], + tools=example.get("tools"), + **example.get("chat_template_kwargs", {}), + ) + output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} + else: + output = { + "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], + "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], + } + return output + + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + + # Filter samples that are longer than `max_length` + if args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens" + dataset = dataset.filter( + lambda example: len(example["chosen_input_ids"]) <= args.max_length + and len(example["rejected_input_ids"]) <= args.max_length, + **map_kwargs, + ) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). 
+ if self._signature_columns is None: + self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"] + def compute_loss( self, - model: Union[PreTrainedModel, nn.Module], + model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - rewards_chosen = model( - input_ids=inputs["input_ids_chosen"], - attention_mask=inputs["attention_mask_chosen"], - return_dict=True, - )["logits"] - rewards_rejected = model( - input_ids=inputs["input_ids_rejected"], - attention_mask=inputs["attention_mask_rejected"], - return_dict=True, - )["logits"] - # calculate loss, optionally modulate with margin + return_outputs: bool = False, + num_items_in_batch: Optional[torch.Tensor] = None, + ): + """ + Compute training loss and additionally compute token accuracies + """ + mode = "train" if self.model.training else "eval" + + # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing + inputs["use_cache"] = False + outputs = model(**inputs) + + # Split the rewards into chosen and rejected + rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2) + + # Calculate loss, optionally modulate with margin if "margin" in inputs: loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() else: @@ -262,86 +552,45 @@ def compute_loss( if self.args.center_rewards_coefficient is not None: loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) - if return_outputs: - return loss, { - "rewards_chosen": rewards_chosen, - "rewards_rejected": rewards_rejected, - } - return loss - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - inputs = self._prepare_inputs(inputs) - if ignore_keys is None: - if hasattr(self.model, "config"): - ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] + if mode == "train": + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + # Compute min, mean, max, accuracy and margin with torch.no_grad(): - loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) - - if prediction_loss_only: - return (loss, None, None) - - loss = loss.detach() - logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) - logits = nested_detach(logits) - # Stack accepted against rejected, mean over logits - # and softmax to get preferences between accepted and rejected to sum to 1 - logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T - - labels = torch.zeros(logits.shape[0]) - labels = self._prepare_inputs(labels) - - return loss, logits, labels - - def evaluate(self, *args, **kwargs): - num_print_samples = kwargs.pop("num_print_samples", 4) - self.visualize_samples(num_print_samples) - return super().evaluate(*args, **kwargs) - - def visualize_samples(self, num_print_samples: int): - """ - Visualize the reward model logits prediction - - Args: - num_print_samples (`int`, defaults to `4`): - The number of samples to print. 
Set to `-1` to print all samples. - """ - eval_dataloader = self.get_eval_dataloader() - table = defaultdict(list) - for _, inputs in enumerate(eval_dataloader): - _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) - chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class) - rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class) - table["chosen_text"].extend(gather_object(chosen_text)) - table["rejected_text"].extend(gather_object(rejected_text)) - table["logits"].extend( - gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) - ) - if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: - break - df = pd.DataFrame(table) - if self.accelerator.process_index == 0: - if is_rich_available(): - print_rich_table(df[:num_print_samples]) - if "wandb" in self.args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="completions.csv", - table=df, - ) + all_rewards = self.accelerator.gather(outputs.logits) + self._metrics[mode]["min_reward"].append(all_rewards.min().item()) + self._metrics[mode]["mean_reward"].append(all_rewards.mean().item()) + self._metrics[mode]["max_reward"].append(all_rewards.max().item()) + + mean_accuracy = (rewards_chosen > rewards_rejected).float().mean() + mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item() + self._metrics[mode]["accuracy"].append(mean_accuracy) + + mean_margin = (rewards_chosen - rewards_rejected).mean() + mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean() + self._metrics[mode]["margin"].append(mean_margin.item()) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 
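To make the loss in `compute_loss` above concrete, here is a small self-contained illustration on plain tensors (hedged, with arbitrary values) of how the chunked logits, the optional margin, and the optional centering term combine:

```python
# Self-contained illustration of the loss logic in compute_loss above (values are arbitrary).
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2], [0.3], [0.4], [-0.1]])  # batch order: [chosen, chosen, rejected, rejected]
rewards_chosen, rewards_rejected = torch.chunk(logits.squeeze(-1), chunks=2)

# Base pairwise loss
loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()

# With an optional per-pair margin (read from a "margin" column when present)
margin = torch.tensor([0.5, 0.0])
loss_with_margin = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()

# With the optional centering term controlled by center_rewards_coefficient
loss_centered = loss + 0.01 * torch.mean((rewards_chosen + rewards_rejected) ** 2)
```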
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()

     # Ensure the model card is saved along with the checkpoint
     def _save_checkpoint(self, model, trial):
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 6f5db07ba50..c0f26bc9373 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -15,10 +15,9 @@
 import contextlib
 import os
 from collections import defaultdict
-from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -56,6 +55,7 @@
     entropy_from_logits,
     flush_left,
     pad,
+    remove_none_values,
     selective_log_softmax,
 )

@@ -66,7 +66,6 @@
 logger = logging.get_logger(__name__)

-TListOrMapping = TypeVar("TListOrMapping", list, Mapping)


 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
@@ -77,38 +76,6 @@
 }


-def remove_none_values(example: TListOrMapping) -> TListOrMapping:
-    """
-    Recursively removes entries with `None` values from a nested structure (list or dictionary).
-
-    Args:
-        example (`list` or `Mapping`):
-            Input nested structure (list or dictionary) from which to remove `None`.
-
-    Example:
-    ```python
-    >>> [
-    ...     {
-    ...         "a": {"aa": None, "ab": 1},
-    ...         "b": "my_string",
-    ...     }
-    ... ]
-    >>> remove_none_values(example)
-    [{'a': {'ab': 1}, 'b': 'my_string'}]
-    ```
-    """
-    if isinstance(example, list):
-        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
-    elif isinstance(example, Mapping):
-        return {
-            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
-            for key, value in example.items()
-            if value is not None
-        }
-    else:
-        raise TypeError("Input must be a list or a dictionary.")
-
-
 def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 1ef75bd0da2..a12fdec7b44 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -18,11 +18,12 @@
 import os
 import random
 import socket
-from collections.abc import Sequence, Sized
+import warnings
+from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
 from itertools import accumulate
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeVar, Union

 import numpy as np
 import pandas as pd
@@ -219,6 +220,10 @@ class RewardDataCollatorWithPadding:
     r"""
     Reward DataCollator class that pads the inputs to the maximum length of the batch.

+    > [!WARNING]
+    > This class is deprecated and will be removed in version 0.27.0. Please use
+    > `trl.trainer.reward_trainer.DataCollatorForPreference` instead.
+
     Args:
         tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used for encoding the data.
@@ -235,6 +240,14 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"

+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
+            "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_chosen = []
         features_rejected = []
@@ -1241,6 +1254,10 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
     """
     Decodes the input tensor and strips the padding tokens.

+    > [!WARNING]
+    > This function is deprecated and will be removed in version 0.25.0. If you want to keep using it, please copy
+    > the code into your codebase and use it from there.
+
     Args:
         inputs (`torch.Tensor`):
             The input tensor to be decoded.
@@ -1251,6 +1268,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
         `list[str]`:
             The list of decoded strings with padding tokens stripped.
     """
+    warnings.warn(
+        "The function `decode_and_strip_padding` is deprecated and will be removed in version 0.25.0. If you want "
+        "to keep using it, please copy the code into your codebase and use it from there.",
+        DeprecationWarning,
+    )
     decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
     return [d.replace(tokenizer.pad_token, "") for d in decoded]

@@ -1264,6 +1286,7 @@ def generate_model_card(
     wandb_url: Optional[str],
     trainer_name: str,
     trainer_citation: Optional[str] = None,
+    template_file: Optional[str] = None,
     paper_title: Optional[str] = None,
     paper_id: Optional[str] = None,
     comet_url: Optional[str] = None,
@@ -1290,6 +1313,8 @@ def generate_model_card(
             Trainer name.
         trainer_citation (`str` or `None`, defaults to `None`):
             Trainer citation as a BibTeX entry.
+        template_file (`str` or `None`, defaults to `None`):
+            Template file name located in the `trl/templates` directory. If `None`, defaults to `lm_model_card.md`.
         paper_title (`str` or `None`, defaults to `None`):
             Paper title.
         paper_id (`str` or `None`, defaults to `None`):
@@ -1307,9 +1332,10 @@ def generate_model_card(
         model_name=model_name,
         tags=["generated_from_trainer", *tags],
     )
+    template_file = template_file or "lm_model_card.md"
     card = ModelCard.from_template(
         card_data,
-        template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")),
+        template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")),
         base_model=base_model,
         model_name=model_name,
         hub_model_id=hub_model_id,
@@ -1956,6 +1982,41 @@ def process_sequence(ids, mask):
     return torch.stack(truncated_seq), torch.stack(truncated_mask)


+TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
+
+
+def remove_none_values(example: TListOrMapping) -> TListOrMapping:
+    """
+    Recursively removes entries with `None` values from a nested structure (list or dictionary).
+
+    Args:
+        example (`list` or `Mapping`):
+            Input nested structure (list or dictionary) from which to remove `None`.
+
+    Example:
+    ```python
+    >>> example = [
+    ...     {
+    ...         "a": {"aa": None, "ab": 1},
+    ...         "b": "my_string",
+    ...     }
+    ... ]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")
+
+
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     """
     Create a model from a given path using the specified initialization arguments.
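
A side note on the metric bookkeeping introduced in the `compute_loss`/`log` changes above: each batch appends scalar statistics (min/mean/max reward, pairwise accuracy, and margin) to a per-mode buffer, and `log` averages that buffer and then clears it. The following is only a minimal, self-contained sketch of that flow in plain PyTorch; the reward values are made up and the cross-process `Accelerator` gathering is deliberately omitted:

```python
from collections import defaultdict

import torch

# Illustrative per-batch rewards for the chosen and rejected completions
# (in the trainer these come from the sequence-classification head's logits).
rewards_chosen = torch.tensor([1.2, 0.3, 2.1, -0.5])
rewards_rejected = torch.tensor([0.4, 0.9, 1.0, -1.2])

metrics = defaultdict(list)

# Per-batch statistics, analogous to what compute_loss appends to self._metrics.
all_rewards = torch.cat([rewards_chosen, rewards_rejected])
metrics["min_reward"].append(all_rewards.min().item())
metrics["mean_reward"].append(all_rewards.mean().item())
metrics["max_reward"].append(all_rewards.max().item())
metrics["accuracy"].append((rewards_chosen > rewards_rejected).float().mean().item())
metrics["margin"].append((rewards_chosen - rewards_rejected).mean().item())

# At log time, the accumulated lists are averaged and the buffer is cleared.
logged = {key: sum(val) / len(val) for key, val in metrics.items()}
print(logged)
metrics.clear()
```

In the trainer itself these statistics are gathered across processes (`gather` / `gather_for_metrics`) before being appended, and eval-mode keys are additionally prefixed with `eval_` at log time.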