Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/best_of_n.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=o
```

There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
This is done by passing a [`~transformers.GenerationConfig`] from the `transformers` library at the time of initialization

```python

Expand Down
2 changes: 1 addition & 1 deletion docs/source/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ trainer.train()

## Use the accelerator cache optimizer

When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`:
When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]:

```python
training_args = DPOConfig(..., optimize_device_cache=True)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/judges.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pip install trl[judges]

## Using the provided judges

TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
TRL provides several judges out of the box. For example, you can use the [`HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:

```python
from trl import HfPairwiseJudge
Expand Down
8 changes: 4 additions & 4 deletions docs/source/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Trackio, Weights & Biases (wandb) or TensorBoard.

Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for [`PPOTrainer`], or [`GRPOConfig`] for [`GRPOTrainer`]):

```python
# For PPOTrainer
Expand All @@ -19,7 +19,7 @@ grpo_config = GRPOConfig(
)
```

If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., [`PPOConfig`] or [`GRPOConfig`]).

## PPO Logging

Expand Down Expand Up @@ -83,9 +83,9 @@ Here's a brief explanation for the logged metrics provided in the data for the G

### Policy and Loss Metrics

* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in [`GRPOConfig`]) is non-zero.
* `entropy`: Average entropy of token predictions across generated completions.
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
* If Liger GRPOLoss is used (`use_liger_loss: True` in [`GRPOConfig`]):
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
* If standard GRPOLoss is used (`use_liger_loss: False`):
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
Expand Down
2 changes: 1 addition & 1 deletion docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ training_args = DPOConfig(
)
```

For the unpaired version, the user should utilize `BCOConfig` and `BCOTrainer`.
For the unpaired version, the user should utilize [`BCOConfig`] and [`BCOTrainer`].

### Self-Play Preference Optimization for Language Model Alignment

Expand Down
2 changes: 1 addition & 1 deletion docs/source/peft_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scr

## How to use it?

Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
Simply declare a [`~peft.PeftConfig`] object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.

```python
from peft import LoraConfig
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reducing_memory_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.
Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`].

> [!TIP]
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in `SFTConfig`.
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`].

```python
from trl import SFTConfig
Expand Down
23 changes: 12 additions & 11 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def maybe_apply_chat_template(
messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
to the chat template renderer.
tokenizer (`PreTrainedTokenizerBase`):
tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not support
Expand Down Expand Up @@ -328,7 +328,7 @@ def unpair_preference_dataset(
Unpair a preference dataset.

Args:
dataset (`Dataset` or `DatasetDict`):
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally
`"prompt"`.
num_proc (`int`, *optional*):
Expand All @@ -337,7 +337,7 @@ def unpair_preference_dataset(
Meaningful description to be displayed alongside with the progress bar while mapping examples.

Returns:
`Dataset`: The unpaired preference dataset.
[`~datasets.Dataset`]: The unpaired preference dataset.

Example:

Expand Down Expand Up @@ -371,7 +371,7 @@ def maybe_unpair_preference_dataset(
Unpair a preference dataset if it is paired.

Args:
dataset (`Dataset` or `DatasetDict`):
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally
`"prompt"`.
num_proc (`int`, *optional*):
Expand All @@ -380,7 +380,8 @@ def maybe_unpair_preference_dataset(
Meaningful description to be displayed alongside with the progress bar while mapping examples.

Returns:
`Dataset` or `DatasetDict`: The unpaired preference dataset if it was paired, otherwise the original dataset.
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The unpaired preference dataset if it was paired, otherwise
the original dataset.

Example:

Expand Down Expand Up @@ -473,7 +474,7 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
```

Or, with the `map` method of `datasets.Dataset`:
Or, with the `map` method of [`~datasets.Dataset`]:

```python
>>> from trl import extract_prompt
Expand Down Expand Up @@ -664,7 +665,7 @@ def pack_dataset(
Pack sequences in a dataset into chunks of size `seq_length`.

Args:
dataset (`Dataset` or `DatasetDict`):
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
Dataset to pack
seq_length (`int`):
Target sequence length to pack to.
Expand All @@ -679,8 +680,8 @@ def pack_dataset(
Additional keyword arguments to pass to the dataset's map method when packing examples.

Returns:
`Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may decrease as sequences
are combined.
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The dataset with packed sequences. The number of examples
may decrease as sequences are combined.

Example:
```python
Expand Down Expand Up @@ -720,15 +721,15 @@ def truncate_dataset(
Truncate sequences in a dataset to a specified `max_length`.

Args:
dataset (`Dataset` or `DatasetDict`):
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
Dataset to truncate.
max_length (`int`):
Maximum sequence length to truncate to.
map_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the dataset's map method when truncating examples.

Returns:
`Dataset` or `DatasetDict`: The dataset with truncated sequences.
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The dataset with truncated sequences.

Example:
```python
Expand Down
2 changes: 1 addition & 1 deletion trl/mergekit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def merge_models(config: MergeConfig, out_path: str):
Merge two models using mergekit

Args:
config (`MergeConfig`): The merge configuration.
config ([`MergeConfig`]): The merge configuration.
out_path (`str`): The output path for the merged model.
"""
if not is_mergekit_available():
Expand Down
48 changes: 26 additions & 22 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,17 @@


class PreTrainedModelWrapper(nn.Module):
r"""
A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the (`~transformers.PreTrained`)
class in order to keep some attributes and methods of the (`~transformers.PreTrainedModel`) class.
"""
Wrapper for a [`~transformers.PreTrainedModel`] implemented as a standard PyTorch [`torch.nn.Module`].

This class provides a compatibility layer that preserves the key attributes and methods of the original
[`~transformers.PreTrainedModel`], while exposing a uniform interface consistent with PyTorch modules. It enables
seamless integration of pretrained Transformer models into custom training, evaluation, or inference workflows.

Attributes:
pretrained_model (`transformers.PreTrainedModel`):
pretrained_model ([`~transformers.PreTrainedModel`]):
The model to be wrapped.
parent_class (`transformers.PreTrainedModel`):
parent_class ([`~transformers.PreTrainedModel`]):
The parent class of the model to be wrapped.
supported_args (`list`):
The list of arguments that are supported by the wrapper class.
Expand Down Expand Up @@ -111,19 +114,20 @@ def __init__(
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiates a new model from a pretrained model from `transformers`. The pretrained model is loaded using the
`from_pretrained` method of the `transformers.PreTrainedModel` class. The arguments that are specific to the
`transformers.PreTrainedModel` class are passed along this method and filtered out from the `kwargs` argument.
`from_pretrained` method of the [`~transformers.PreTrainedModel`] class. The arguments that are specific to the
[`~transformers.PreTrainedModel`] class are passed along this method and filtered out from the `kwargs`
argument.

Args:
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
pretrained_model_name_or_path (`str` or [`~transformers.PreTrainedModel`]):
The path to the pretrained model or its name.
*model_args (`list`, *optional*)):
*model_args (`list`, *optional*):
Additional positional arguments passed along to the underlying model's `from_pretrained` method.
**kwargs (`dict`, *optional*):
Additional keyword arguments passed along to the underlying model's `from_pretrained` method. We also
pre-process the kwargs to extract the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models. The kwargs also support
`prepare_model_for_kbit_training` arguments from `peft` library.
pre-process the kwargs to extract the arguments that are specific to the
[`~transformers.PreTrainedModel`] class and the arguments that are specific to trl models. The kwargs
also support `prepare_model_for_kbit_training` arguments from `peft` library.
"""
if kwargs is not None:
peft_config = kwargs.pop("peft_config", None)
Expand Down Expand Up @@ -507,8 +511,8 @@ def add_and_load_reward_modeling_adapter(
def push_to_hub(self, *args, **kwargs):
r"""
Push the pretrained model to the hub. This method is a wrapper around
`transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation of
`transformers.PreTrainedModel.push_to_hub` for more information.
[`~transformers.PreTrainedModel.push_to_hub`]. Please refer to the documentation of
[`~transformers.PreTrainedModel.push_to_hub`] for more information.

Args:
*args (`list`, *optional*):
Expand All @@ -521,8 +525,8 @@ def push_to_hub(self, *args, **kwargs):
def save_pretrained(self, *args, **kwargs):
r"""
Save the pretrained model to a directory. This method is a wrapper around
`transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation of
`transformers.PreTrainedModel.save_pretrained` for more information.
[`~transformers.PreTrainedModel.save_pretrained`]. Please refer to the documentation of
[`~transformers.PreTrainedModel.save_pretrained`] for more information.

Args:
*args (`list`, *optional*):
Expand Down Expand Up @@ -596,14 +600,14 @@ def create_reference_model(
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.

Args:
model (`PreTrainedModelWrapper`): The model to be copied.
model ([`PreTrainedModelWrapper`]): The model to be copied.
num_shared_layers (`int`, *optional*):
The number of initial layers that are shared between both models and kept frozen.
pattern (`str`, *optional*): The shared layers are selected with a string pattern
(e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.

Returns:
`PreTrainedModelWrapper`
[`PreTrainedModelWrapper`]
"""
if is_deepspeed_zero3_enabled():
raise ValueError(
Expand Down Expand Up @@ -665,13 +669,13 @@ def create_reference_model(


class GeometricMixtureWrapper(GenerationMixin):
r"""
"""
Geometric Mixture generation wrapper that samples from the logits of two model's geometric mixture.

Args:
model (`PreTrainedModel`): The model to be wrapped.
ref_model (`PreTrainedModel`): The reference model.
generation_config (`GenerationConfig`): The generation config.
model ([`~transformers.PreTrainedModel`]): The model to be wrapped.
ref_model ([`~transformers.PreTrainedModel`]): The reference model.
generation_config ([`~transformers.GenerationConfig`]): The generation config.
mixture_coef (`float`, *optional* - default: 0.5): The mixture coefficient.
"""

Expand Down
Loading
Loading