Commit 8472da8

[algo] support RLOO algorithm (#6325)

* rloo init
* nits
* fix script
* fix script & batch norm
* update script
* typo

1 parent 414bb08 commit 8472da8

File tree

10 files changed: +410, -68 lines
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# REINFORCE Leave-One-Out (RLOO)

**Version requirement**: ms-swift>=3.10

[REINFORCE Leave-One-Out (RLOO)](https://arxiv.org/abs/2402.14740) builds on the classic REINFORCE policy-gradient method and constructs an unbiased advantage baseline via the leave-one-out technique.

## Algorithm Overview

For clarity, we explain RLOO by contrasting it with GRPO (Group Relative Policy Optimization).

### Key Differences Between GRPO and RLOO

Both GRPO and RLOO estimate advantages through intra-group comparisons, avoiding the high variance that comes with estimating a global baseline. Their core differences lie in the following two aspects:

#### Difference 1: How the Advantage Baseline Is Constructed

**1. GRPO (Group Relative Policy Optimization)**

For each prompt, GRPO generates $G$ response samples and normalizes rewards using the **mean and standard deviation of all samples in the group**:

$$
\hat{A}_{i} = \frac{R_i - \text{mean}(\{R_j\}_{j=1}^G)}{\text{std}(\{R_j\}_{j=1}^G)}
$$

where:
- $R_i$ is the reward of the $i$-th sample
- $\text{mean}(\{R_j\}_{j=1}^G) = \frac{1}{G}\sum_{j=1}^G R_j$ is the group mean
- $\text{std}(\{R_j\}_{j=1}^G)$ is the group standard deviation

**2. RLOO (REINFORCE Leave-One-Out)**

For each prompt, RLOO generates $K$ response samples and constructs the baseline via **leave-one-out**: the baseline for the $i$-th sample is the mean of the other $K-1$ samples:

$$
\hat{A}_{i} = R_i - \frac{1}{K-1}\sum_{j \neq i} R_j
$$

This can be rewritten equivalently as:

$$
\hat{A}_{i} = \frac{K}{K-1} \left(R_i - \bar{R}\right)
$$

where $\bar{R} = \frac{1}{K}\sum_{j=1}^K R_j$ is the mean reward over all samples in the group.

> **Note**: We use $K$ here to match the paper's notation; it has the same meaning as $G$ in GRPO, and both correspond to the configuration parameter `num_generations`.

**Why leave-one-out?**

The key advantage of leave-one-out is **unbiasedness**. For the $i$-th sample, its reward $R_i$ and the baseline $\frac{1}{K-1}\sum_{j \neq i} R_j$ are independent, so the advantage estimate is unbiased. In contrast, using a mean that includes the sample itself as the baseline introduces bias.
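To make the leave-one-out computation and the equivalence above concrete, here is a minimal NumPy sketch (an illustration of the formulas only, not the ms-swift implementation):

```python
import numpy as np

def rloo_advantages(rewards: np.ndarray) -> np.ndarray:
    """Leave-one-out advantages for one group of K rewards sampled from the same prompt."""
    k = rewards.shape[0]
    # Baseline for sample i: mean of the other K-1 rewards.
    loo_baseline = (rewards.sum() - rewards) / (k - 1)
    return rewards - loo_baseline

rewards = np.array([1.0, 0.0, 0.5, 1.0])  # K = 4 sampled rewards for one prompt
adv_loo = rloo_advantages(rewards)
adv_scaled = len(rewards) / (len(rewards) - 1) * (rewards - rewards.mean())
assert np.allclose(adv_loo, adv_scaled)   # both forms give identical advantages
```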
#### Difference 2: How the KL Regularization Term Is Applied

To keep the policy from drifting too far from the reference policy, both algorithms introduce KL-divergence regularization, but they apply it differently:

**GRPO**: adds the KL divergence as a separate regularization term in the [loss function](../GetStarted/GRPO.md#算法原理):

$$
\mathcal{L}(\theta) = -\mathbb{E}\left[\hat{A}_i \log \pi_\theta(a_i|s_i)\right] + \beta \cdot \text{KL}(\pi_\theta \Vert \pi_{\text{ref}})
$$

**RLOO**: folds the KL divergence directly into the reward, forming a modified reward:

$$
R'_i = R_i - \beta \cdot \text{KL}(\pi_\theta \Vert \pi_{\text{ref}})
$$

where $\beta$ is the KL coefficient (parameter `beta`) and $\pi_{\text{ref}}$ is the reference policy (typically the SFT model or the initial policy).
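The following short sketch contrasts the two placements of the KL term, assuming scalar per-sample rewards and a precomputed per-sample KL estimate (the variable names are illustrative, not taken from the trainer code):

```python
import numpy as np

beta = 0.001                                # KL coefficient (`beta`)
rewards = np.array([1.0, 0.0, 0.5, 1.0])    # raw rewards for one prompt group
kl = np.array([0.02, 0.05, 0.01, 0.03])     # per-sample KL(pi_theta || pi_ref) estimates

# RLOO style (kl_in_reward=true): subtract the KL penalty from the reward first,
# then compute leave-one-out advantages on the modified reward R'.
shaped = rewards - beta * kl
k = len(shaped)
advantages = k / (k - 1) * (shaped - shaped.mean())

# GRPO style (kl_in_reward=false): advantages are computed from the raw rewards,
# and beta * KL enters the loss as a separate regularization term instead.
```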
## Parameter Settings

RLOO training can be enabled on top of `GRPOTrainer` by setting the following parameters:

```bash
# Basic RLOO configuration
--advantage_estimator rloo   # use RLOO's leave-one-out advantage estimation
--kl_in_reward true          # fold the KL term into the reward (RLOO's default treatment)
```

For training, refer to this [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/rloo.sh).

### Key Parameters

- **`--advantage_estimator`**: selects the advantage estimation method
  - `grpo` (default): normalize with the group mean and standard deviation
  - `rloo`: construct the baseline via leave-one-out

- **`--kl_in_reward`**: controls where the KL regularization term is applied
  - `false`: KL is a separate regularization term in the loss (GRPO style)
  - `true`: KL is subtracted directly from the reward, forming a modified reward (RLOO style)

- **`--num_generations`**: number of samples generated per prompt, i.e., $K$

- **`--beta`**: KL regularization coefficient $\beta$
  - controls how conservative the policy update is

Other parameters are the same as the [GRPO arguments](../../命令行参数.md#grpo参数).

docs/source/Instruction/GRPO/GetStarted/GRPO.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ GRPOTrainer was refactored in ms-swift 3.5. If the swift version you are using is <3

 [GRPO (Group Relative Policy Optimization)](https://arxiv.org/abs/2402.03300) replaces the separate value model of PPO with intra-group relative advantage computation and adds a KL-divergence penalty directly to the loss function to improve training stability.

+## Algorithm Overview

 GRPO objective function

docs/source/Instruction/命令行参数.md

Lines changed: 8 additions & 5 deletions
@@ -551,22 +551,25 @@ Reward model parameters are used in PPO and GRPO.
 - offload_model: Whether to offload the model during vLLM inference. Default is False.
 - completion_length_limit_scope: The scope of the `max_completion_length` limit in multi-turn conversations.
 `total` limits the total output length across all turns to at most `max_completion_length`; `per_round` limits the output length of each turn.
-- num_iterations: Number of updates per batch. Default is 1.
+- num_iterations: Number of updates per data sample, the $\mu$ value in the [GRPO paper](https://arxiv.org/abs/2402.03300). Default is 1.
 - epsilon: Clip coefficient. Default is 0.2.
 - epsilon_high: Upper clip coefficient. Default is None; when set, it forms the clipping range [epsilon, epsilon_high] together with epsilon.
+- dynamic_sample: Filter out data whose in-group reward standard deviation is 0 and sample additional new data. Default is False.
+- max_resample_times: Under dynamic_sample, limit the number of resampling attempts. Default is 3.
+- overlong_filter: Skip samples truncated for excessive length so they do not participate in loss computation. Default is False.
 - delta: Two-sided GRPO upper clipping bound from the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). If set, a value greater than 1 + epsilon is recommended. Default is None.
+- importance_sampling_level: Controls how the importance-sampling ratio is computed; options are `token` and `sequence`. `token` mode keeps the original per-token log-probability ratios, while `sequence` mode averages the log-probability ratios of all valid tokens in the sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level computation to stabilize training. Default is `token`.
+- advantage_estimator: Advantage estimation function. Default is `grpo`, i.e., group-relative advantage; options are `grpo` and [`rloo`](./GRPO/AdvancedResearch/RLOO.md).
+- kl_in_reward: Controls where the KL regularization term is applied. `false` (default) keeps it as a separate regularization term in the loss; `true` folds the KL directly into the reward (subtracts it from the reward).
+- scale_rewards: Reward scaling strategy. Default is `group`, which scales rewards by the in-group standard deviation; `batch` scales by the standard deviation over the whole batch; `none` disables scaling. In ms-swift<3.10 this was a boolean: true corresponds to group, false to none.
 - sync_ref_model: Whether to periodically synchronize the ref_model. Default is False.
 - ref_model_mixup_alpha: Controls the mixing of the model and the previous ref_model during updates. Update rule: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6.
 - ref_model_sync_steps: Synchronization frequency. Default is 512.
 - move_model_batches: When moving model parameters to fast inference frameworks such as vLLM, how many batches to split the layers into. Default is None, meaning the model is not split; otherwise it is split into move_model_batches + 1 (non-layer parameters) + 1 (multimodal parameters) batches.
 - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name and add the implementation in plugin/multi_turn.py.
 - max_turns: Upper limit on the number of turns in multi-turn GRPO. Default is None (no limit).
-- dynamic_sample: Filter out data whose in-group reward standard deviation is 0 and sample additional new data. Default is False.
-- max_resample_times: Under dynamic_sample, limit the number of resampling attempts. Default is 3.
-- overlong_filter: Skip samples truncated for excessive length so they do not participate in loss computation. Default is False.
 - top_entropy_quantile: Only tokens whose entropy falls within the specified top quantile participate in the loss computation. Default is 1.0, i.e., low-entropy tokens are not filtered. See the [documentation](./GRPO/AdvancedResearch/entropy_mask.md).
 - log_entropy: Log the entropy dynamics during training. Default is False. See the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics).
-- importance_sampling_level: Controls how the importance-sampling ratio is computed; options are `token` and `sequence`. `token` mode keeps the original per-token log-probability ratios, while `sequence` mode averages the log-probability ratios of all valid tokens in the sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level computation to stabilize training. Default is `token`.

 Cosine reward parameters
 - cosine_min_len_value_wrong: Cosine reward function parameter; the reward value at the minimum length when a wrong answer is generated. Default is -0.5.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 11 additions & 10 deletions
@@ -563,26 +563,27 @@ The meanings of the following parameters can be referenced [here](https://huggin
 When set to `total`, the total output length across all turns must not exceed `max_completion_length`.
 When set to `per_round`, each individual turn's output length is limited separately.
 Defaults to `per_round`. Currently only takes effect in colocate mode.
-- top_k: Default is 50.
-- top_p: Default is 0.9.
-- repetition_penalty: Repetition penalty term. Default is 1.
-- num_iterations: number of iterations per batch. Default is 1.
+- num_iterations: The number of updates per data sample, corresponding to the $\mu$ value in the GRPO paper. Default is 1.
 - epsilon: epsilon value for clipping. Default is 0.2.
 - epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon.
+- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False.
+- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times.
+- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False.
+The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions).
 - delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
-- sync_ref_model: Whether to synchronize the reference model. Default is False。
+- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`.
+- advantage_estimator: Advantage estimator. Default is `grpo` (group-relative advantage). Options: `grpo`, [`rloo`](./GRPO/AdvancedResearch/RLOO.md).
+- kl_in_reward: Controls where the KL regularization is applied. `false` (default): KL is added as a separate term in the loss. `true`: KL is subtracted directly from the reward (integrated into the reward).
+- scale_rewards: Reward scaling strategy. Default is `group` (scale by standard deviation within each group). `batch` scales across the entire batch; `none` disables scaling. In ms-swift<3.10, this was a boolean: `true` means `group`, `false` means `none`.
+- sync_ref_model: Whether to synchronize the reference model. Default is False.
 - ref_model_mixup_alpha: The Parameter controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6.
 - ref_model_sync_steps: The parameter determines how frequently the current policy is synchronized with the reference policy. Default is 512.
 - move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches.
 - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py.
 - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit.
-- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False.
-- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times.
-- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False.
-The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions).
 - top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md).
 - log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics).
-- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`.
+


 cosine reward function arguments
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# REINFORCE Leave-One-Out (RLOO)

**Version requirement**: ms-swift>=3.10

[REINFORCE Leave-One-Out (RLOO)](https://arxiv.org/abs/2402.14740) is a reinforcement learning algorithm based on the classic REINFORCE policy-gradient method. It constructs an unbiased advantage baseline via the Leave-One-Out (LOO) technique.

## Algorithm Overview

For clarity, we explain RLOO by contrasting it with GRPO (Group Relative Policy Optimization).

### Key Differences Between GRPO and RLOO

Both GRPO and RLOO estimate advantages via intra-group comparisons to avoid the high variance of a global baseline. Their core differences are mainly in the following aspects:

#### Difference 1: How the Advantage Baseline Is Constructed

**1. GRPO (Group Relative Policy Optimization)**

For each prompt, GRPO generates $G$ response samples and normalizes rewards using the group mean and standard deviation:

$$
\hat{A}_{i} = \frac{R_i - \text{mean}(\{R_j\}_{j=1}^G)}{\text{std}(\{R_j\}_{j=1}^G)}
$$

Where:
- $R_i$ is the reward of the $i$-th sample
- $\text{mean}(\{R_j\}_{j=1}^G) = \frac{1}{G}\sum_{j=1}^G R_j$ is the group mean
- $\text{std}(\{R_j\}_{j=1}^G)$ is the group standard deviation

**2. RLOO (REINFORCE Leave-One-Out)**

For each prompt, RLOO generates $K$ response samples and constructs the baseline via Leave-One-Out, i.e., for the $i$-th sample, the baseline is the mean of the other $K-1$ samples:

$$
\hat{A}_{i} = R_i - \frac{1}{K-1}\sum_{j \neq i} R_j
$$

This can be equivalently rewritten as:

$$
\hat{A}_{i} = \frac{K}{K-1} \left(R_i - \bar{R}\right)
$$

where $\bar{R} = \frac{1}{K}\sum_{j=1}^K R_j$ is the group mean reward.
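The rewrite follows from one line of algebra, using $\sum_{j \neq i} R_j = K\bar{R} - R_i$:

$$
R_i - \frac{1}{K-1}\sum_{j \neq i} R_j
= \frac{(K-1)R_i - (K\bar{R} - R_i)}{K-1}
= \frac{K R_i - K\bar{R}}{K-1}
= \frac{K}{K-1}\left(R_i - \bar{R}\right).
$$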
> Note: We use $K$ here to match the notation in the paper. It has the same meaning as $G$ in GRPO and corresponds to the configuration parameter `num_generations`.

**Why Leave-One-Out?**

The key advantage is unbiasedness. For the $i$-th sample, its reward $R_i$ is independent of the baseline $\frac{1}{K-1}\sum_{j \neq i} R_j$, hence the advantage estimate is unbiased. In contrast, using the mean including itself as the baseline introduces bias.

#### Difference 2: How KL Regularization Is Applied

To prevent the policy from drifting too far from the reference policy, both algorithms introduce KL divergence regularization, but in different ways:

**GRPO**: Adds KL divergence as an independent regularization term to the [loss](../GetStarted/GRPO.md#algorithm-overview):

$$
\mathcal{L}(\theta) = -\mathbb{E}\left[\hat{A}_i \log \pi_\theta(a_i|s_i)\right] + \beta \cdot \text{KL}(\pi_\theta \Vert \pi_{\text{ref}})
$$

**RLOO**: Integrates KL divergence directly into the reward, constructing a modified reward:

$$
R'_i = R_i - \beta \cdot \text{KL}(\pi_\theta \Vert \pi_{\text{ref}})
$$

where $\beta$ is the KL coefficient (parameter `beta`), and $\pi_{\text{ref}}$ is the reference policy (typically an SFT model or the initial policy).
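Putting the two differences together, the per-completion advantage under RLOO with `kl_in_reward` enabled can be sketched as follows (plain NumPy for illustration; the array shapes and names are assumptions, not the `GRPOTrainer` internals):

```python
import numpy as np

def rloo_advantages(rewards: np.ndarray, kl: np.ndarray,
                    num_generations: int, beta: float) -> np.ndarray:
    """rewards, kl: shape (num_prompts * num_generations,), grouped by prompt."""
    shaped = rewards - beta * kl                      # kl_in_reward=true: R' = R - beta * KL
    groups = shaped.reshape(-1, num_generations)      # one row per prompt
    k = num_generations
    # Leave-one-out baseline: mean of the other K-1 shaped rewards in the same group.
    baseline = (groups.sum(axis=1, keepdims=True) - groups) / (k - 1)
    return (groups - baseline).reshape(-1)

# Example: 2 prompts, num_generations=4, beta=0.001
rewards = np.array([1.0, 0.0, 0.5, 1.0,  0.0, 0.0, 1.0, 0.5])
kl = np.array([0.02, 0.05, 0.01, 0.03,  0.02, 0.04, 0.02, 0.01])
advantages = rloo_advantages(rewards, kl, num_generations=4, beta=0.001)
```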
## Parameter Configuration

RLOO training can be enabled on top of `GRPOTrainer` by setting the following parameters:

```bash
# Basic RLOO configuration
--advantage_estimator rloo   # Use RLOO's leave-one-out advantage estimator
--kl_in_reward true          # Integrate KL divergence into the reward (default for RLOO)
```

You can refer to this [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/rloo.sh) for training.

### Important Parameters

- **`--advantage_estimator`**: Choose the advantage estimator
  - `grpo` (default): standardize using group mean and standard deviation
  - `rloo`: construct the baseline via Leave-One-Out

- **`--kl_in_reward`**: Controls where the KL term is applied
  - `false`: KL as a separate regularization term in the loss (GRPO style)
  - `true`: subtract KL directly from the reward to form a modified reward (RLOO style)

- **`--num_generations`**: Number of samples per prompt, i.e., $K$

- **`--beta`**: KL regularization coefficient $\beta$
  - Controls how conservatively the policy updates

Other parameters are consistent with the [GRPO arguments](../../Command-line-parameters.md#grpo-arguments).

docs/source_en/Instruction/GRPO/GetStarted/GRPO.md

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,9 @@ GRPOTrainer underwent a code refactoring in ms-swift3.5. If you are using a swif

 [GRPO (Group Relative Policy Optimization)](https://arxiv.org/abs/2402.03300) leverages intra-group relative advantage calculations to replace the independent value model in the PPO algorithm and directly incorporates KL divergence penalties into the loss function to improve training stability.

-### GRPO Objective Function
+## Algorithm Overview
+
+GRPO Objective Function is defined as
 $
 {\scriptstyle
 \begin{aligned}
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --advantage_estimator rloo \
    --kl_in_reward true \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --external_plugins examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_r1v_acc format \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.4 \
    --vllm_tensor_parallel_size 1 \
    --vllm_max_model_len 16384 \
    --train_type lora \
    --torch_dtype bfloat16 \
    --dataset 'AI-ModelScope/clevr_cogen_a_train' \
    --overlong_filter false \
    --epsilon 3e-4 \
    --epsilon_high 4e-4 \
    --max_completion_length 1024 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 4 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --save_total_limit 10 \
    --sleep_level 1 \
    --offload_model true \
    --offload_optimizer true \
    --logging_steps 1 \
    --dataloader_num_workers 4 \
    --num_generations 16 \
    --temperature 1.0 \
    --system 'examples/train/grpo/prompt.txt' \
    --deepspeed zero2 \
    --log_completions true \
    --report_to tensorboard swanlab \
    --num_iterations 1 \
    --async_generate false \
    --beta 0.001 \
    --attn_impl flash_attention_2 \
    --padding_free true \
    --loss_type grpo
