
Commit eacc0a1

[script] provide on-policy distillation script & update GKD doc (#6334)
1 parent 76f471e · commit eacc0a1

File tree: 5 files changed, +88 -24 lines changed

docs/source/Instruction/GKD.md

Lines changed: 23 additions & 13 deletions
@@ -33,22 +33,22 @@ $$
 #### Forward KL
 
 $$
-\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
+\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
 $$
 
-**Characteristics**: Mode-seeking
-- The expectation is computed under the student distribution
-- The student model tends to concentrate on the teacher's peak (high-probability) regions
+**Characteristics**: Mode-covering
+- The expectation is computed under the teacher distribution
+- The student model tends to cover the teacher's entire distribution (including low-probability regions)
 
-#### Reverse KL
 
+#### Reverse KL
 $$
-\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
+\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
 $$
 
-**Characteristics**: Mode-covering
-- The expectation is computed under the teacher distribution
-- The student model tends to cover the teacher's entire distribution (including low-probability regions)
+**Characteristics**: Mode-seeking
+- The expectation is computed under the student distribution
+- The student model tends to concentrate on the teacher's peak (high-probability) regions
 
 ### Generalized Jensen-Shannon Divergence (Generalized JSD)

@@ -78,8 +78,8 @@ $$
 where $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$
 
 > For the extreme cases ($\beta = 0$ or $\beta = 1$), a single KL divergence is computed directly:
-> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Reverse KL, Mode-covering)
-> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Forward KL, Mode-seeking)
+> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
+> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
 > - When $0 < \beta < 1$: interpolate using the mixture-distribution formula above
 
 By adjusting the $\beta$ parameter, you can interpolate between different divergence measures; when $\beta = 0.5$, the divergence is the standard symmetric JSD.
@@ -142,8 +142,8 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
 | Parameter | Type | Default | Range | Description |
 |------|------|--------|---------|------|
 | `--teacher_model` | str | Required | - | Teacher model path or model ID |
-| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Reverse KL (mode-covering, more diverse)<br>• 0.5: JSD (balanced, **recommended**)<br>• 1.0: Forward KL (mode-seeking, more focused) |
-| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: pure Off-Policy<br>• 0.5: mixed strategy (**recommended**)<br>• 1.0: pure On-Policy |
+| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
+| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: pure Off-Policy<br>• 0.5: mixed strategy<br>• 1.0: pure On-Policy |
 | `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: use the dataset when not on-policy<br>• True: use teacher generation when not on-policy |
 | `--temperature` | float | 0.9 | > 0 | Sampling temperature for generation, controls randomness |
 | `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
@@ -200,3 +200,13 @@ swift rlhf \
 ```
 
 For a reference training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh)
+
+## On-Policy Distillation
+
+We can reproduce the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training from the Thinking Machines Lab blog by setting the following parameters:
+```bash
+--lmbda 1 # on-policy
+--beta 1 # reverse
+```
+
+For the corresponding script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh)

docs/source/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@ Advanced Research
 DAPO.md
 deepeyes.md
 GSPO.md
+RLOO.md
 CHORD.md

docs/source_en/Instruction/GKD.md

Lines changed: 22 additions & 11 deletions
@@ -33,22 +33,22 @@ In knowledge distillation, there are two choices depending on the order of the t
 #### Forward KL
 
 $$
-\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
+\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
 $$
 
-**Characteristics**: Mode-seeking
-- Expectation is computed under the student distribution
-- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model
+**Characteristics**: Mode-covering
+- Expectation is computed under the teacher distribution
+- The student model tends to cover the entire teacher distribution (including low-probability regions)
 
 #### Reverse KL
 
 $$
-\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
+\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
 $$
 
-**Characteristics**: Mode-covering
-- Expectation is computed under the teacher distribution
-- The student model tends to cover the entire teacher distribution (including low-probability regions)
+**Characteristics**: Mode-seeking
+- Expectation is computed under the student distribution
+- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model
 
 ### Generalized Jensen-Shannon Divergence (Generalized JSD)
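
As a quick numerical check of the corrected definitions, the sketch below (plain Python, an editor's illustration rather than code from this repository) evaluates both directions of the KL divergence for a bimodal teacher and a student that keeps only one of the teacher's modes: the forward KL punishes the dropped mode heavily (mode-covering pressure), while the reverse KL stays comparatively small (mode-seeking is tolerated).

```python
import math

def kl(p, q):
    # KL(p || q) = sum_v p(v) * log(p(v) / q(v)) over a discrete vocabulary
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy 4-token vocabulary: the teacher is bimodal, the student collapses onto one mode.
teacher = [0.45, 0.05, 0.45, 0.05]
student = [0.93, 0.05, 0.01, 0.01]

forward_kl = kl(teacher, student)  # expectation under the teacher
reverse_kl = kl(student, teacher)  # expectation under the student

print(f"Forward KL(teacher || student) = {forward_kl:.3f}")  # large: the missing mode is punished
print(f"Reverse KL(student || teacher) = {reverse_kl:.3f}")  # smaller: student stays inside teacher support
```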

@@ -78,8 +78,8 @@ $$
 Where $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$
 
 > For extreme cases ($\beta = 0$ or $\beta = 1$), directly compute a single KL divergence:
-> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Reverse KL, Mode-covering)
-> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Forward KL, Mode-seeking)
+> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
+> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
 > - When $0 < \beta < 1$: use the above mixture distribution formula for interpolation
 
 By adjusting the $\beta$ parameter, interpolation can be performed between different divergence metrics. When $\beta = 0.5$, the divergence is the standard symmetric JSD.
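
The $\beta$ interpolation, including the endpoint special-casing described in the corrected note, can be sketched as follows. This is a minimal editor's illustration assuming the weighted form $D = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M)$ with the mixture $M$ defined above; it is not the trainer's implementation.

```python
import math

def kl(p, q):
    # KL(p || q) over a discrete vocabulary
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(teacher, student, beta):
    # The mixture formula degenerates to 0 at the endpoints (M equals one of the inputs),
    # hence the special cases: beta = 0 -> Forward KL, beta = 1 -> Reverse KL, as noted above.
    if beta == 0.0:
        return kl(teacher, student)
    if beta == 1.0:
        return kl(student, teacher)
    m = [beta * t + (1 - beta) * s for t, s in zip(teacher, student)]
    return beta * kl(teacher, m) + (1 - beta) * kl(student, m)

teacher = [0.45, 0.05, 0.45, 0.05]
student = [0.70, 0.10, 0.10, 0.10]
for beta in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"beta={beta}: {generalized_jsd(teacher, student, beta):.4f}")
```
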
@@ -142,7 +142,7 @@ We can perform GKD training by setting the following parameters:
 | Parameter | Type | Default | Range | Description |
 |------|------|--------|---------|------|
 | `--teacher_model` | str | Required | - | Teacher model path or model ID |
-| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Reverse KL (mode-covering, more diverse)<br>• 0.5: JSD (balanced, **recommended**)<br>• 1.0: Forward KL (mode-seeking, more focused) |
+| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Off-Policy<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
 | `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
 | `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
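
For the `--lmbda` row, "trigger probability" means a per-batch choice: with probability $\lambda$ the batch is generated on-policy by the student, otherwise the dataset (or teacher generations when `--seq_kd true`) is used. The snippet below is a schematic of that selection logic, assuming a simple Bernoulli draw per batch; it is illustrative rather than the actual trainer code.

```python
import random
from collections import Counter

def batch_source(lmbda: float, seq_kd: bool) -> str:
    # With probability lmbda, train on a completion sampled from the student (on-policy).
    # Otherwise fall back to the dataset completion, or a teacher-generated one if seq_kd is set.
    if random.random() < lmbda:
        return "student_generated"
    return "teacher_generated" if seq_kd else "dataset"

random.seed(0)
counts = Counter(batch_source(lmbda=0.5, seq_kd=False) for _ in range(10_000))
print(counts)  # roughly an even split between on-policy and dataset batches
```
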
@@ -201,3 +201,14 @@ swift rlhf \
 ```
 
 Training script reference [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh)
+
+
+## On-Policy Distillation
+We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training described in the Thinking Machines Lab blog by setting the following parameters:
+
+```bash
+--lmbda 1 # on-policy
+--beta 1 # reverse
+```
+
+For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
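
Setting `--lmbda 1 --beta 1` therefore means every batch is a student-generated completion scored by the teacher under the reverse KL. Below is a self-contained PyTorch sketch of that per-token objective; the tensor names and shapes are the editor's assumptions for illustration, not the ms-swift/TRL implementation.

```python
import torch
import torch.nn.functional as F

def per_token_reverse_kl(student_logits, teacher_logits, completion_mask):
    """Reverse KL per sampled token: KL(P_student || P_teacher), averaged over completion tokens.

    student_logits / teacher_logits: [batch, seq_len, vocab], both scored on the SAME
    student-generated completions (this is what makes the setup on-policy).
    completion_mask: [batch, seq_len] with 1 on completion tokens, 0 on prompt/padding.
    """
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # KL(student || teacher) = sum_v p_s(v) * (log p_s(v) - log p_t(v))
    kl = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(-1)
    return (kl * completion_mask).sum() / completion_mask.sum().clamp(min=1)

# Toy shapes: 2 sequences, 5 positions, vocabulary of 16 (hypothetical numbers).
student_logits = torch.randn(2, 5, 16, requires_grad=True)
teacher_logits = torch.randn(2, 5, 16)
mask = torch.ones(2, 5)
loss = per_token_reverse_kl(student_logits, teacher_logits, mask)
loss.backward()  # gradients flow only into the student
print(float(loss))
```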

docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@ Advanced Research
 DAPO.md
 deepeyes.md
 GSPO.md
+RLOO.md
 CHORD.md
examples/train/on_policy_distillation.sh

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
+# On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/
+
+# CUDA_VISIBLE_DEVICES=7 \
+# swift rollout \
+#     --model Qwen/Qwen3-8B-Base \
+#     --vllm_max_model_len 24192
+
+NPROC_PER_NODE=7 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \
+swift rlhf \
+    --rlhf_type gkd \
+    --model Qwen/Qwen3-8B-Base \
+    --teacher_model Qwen/Qwen3-32B \
+    --train_type full \
+    --dataset open-thoughts/OpenThoughts3-1.2M#10000 \
+    --seq_kd false \
+    --lmbda 1 \
+    --beta 1 \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --learning_rate 1e-5 \
+    --gradient_accumulation_steps 1 \
+    --save_steps 1000 \
+    --save_total_limit 2 \
+    --logging_steps 1 \
+    --max_length 16000 \
+    --max_completion_length 8192 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --save_only_model true \
+    --dataloader_num_workers 64 \
+    --dataset_num_proc 4 \
+    --deepspeed zero2 \
+    --teacher_deepspeed zero3 \
+    --attn_impl flash_attn \
+    --use_vllm true \
+    --vllm_mode server \
+    --vllm_server_host 127.0.0.1 \
+    --vllm_server_port 8000
