大语言模型(LLM)的训练通常分为三个主要阶段:预训练(Pre-training)、中训练(Mid-training)和后训练(Post-training)。 预训练阶段赋予模型通用的语言能力、世界知识和基础的推理能力。中训练阶段则向模型注入特定领域的知识,例如代码、医学文献或公司内部文档。最后的后训练阶段,旨在引导模型产生符合期望的特定行为,如遵循指令、解决数学问题或进行对话。

随着模型能力的提升,一个趋势是使用经过更强训练的更小模型,在特定领域其表现能够超越更大规模的通用模型。 使用小模型有诸多好处,比如可以本地化部署以保障隐私和安全,可以更便捷地进行持续训练和更新,并且能够节省推理成本。 为了充分利用这些优势,在后训练阶段选择合适的训练方法至关重要。

本文旨在深入解读 Kevin Lu 等人在 Thinking Machines 发表的《On-Policy Distillation》一文。文章探讨了一种名为“On-Policy Distillation”的后训练方法,该方法结合了强化学习(Reinforcement Learning, RL)的在线策略(On-policy)相关性和知识蒸馏(Knowledge Distillation)的密集奖励信号,旨在以更低的计算成本实现与前沿大模型相当的性能。我们将详细剖析该方法的技术原理、实现细节,并通过其在数学推理和模型个性化两个任务上的实验结果,展示其有效性和效率优势。

代码链接:https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/distillation

1. 后训练方法的分类与权衡

大模型的后训练方法大致可以分为两类:在线策略(On-policy)训练和离线策略(Off-policy)训练。

  • On-policy 训练:该方法从学生模型自身的生成中采样轨迹(rollouts),并给予一定的奖励。 强化学习(RL)是典型的 On-policy 训练方法。例如,在训练一个模型解决数学问题时,可以通过评估模型生成的每个解题步骤是否正确来给予奖励。 这种方式的优点在于,模型直接从自身的错误中学习,能够更直接地避免犯错。 但其主要缺点是反馈信号极其稀疏。无论模型生成的序列有多长,每个训练片段(episode)只能提供少量信息。

  • Off-policy 训练:该方法依赖于某个外部来源(例如,一个更强大的“教师”模型)提供的目标输出,学生模型通过模仿这些输出来学习。 监督微调(Supervised Fine-Tuning, SFT)是常见的 Off-policy 训练方式,即在一个包含特定任务标注样本的策划数据集上进行训练。 知识蒸馏是其中一种具体机制,通过训练学生模型来匹配教师模型的输出分布。 这种方法的缺点在于,学生模型学习的是教师模型频繁出现的上下文,而不是它自身在实际推理中会遇到的上下文。 这可能导致复合误差(compounding error):一旦学生模型在推理早期犯了一个教师模型从未犯过的错误,它就会进入一个与训练数据分布差异越来越大的状态空间,从而导致后续表现持续下降。 这个问题在处理长序列时尤为突出。此外,有研究指出,学生模型可能会学会模仿教师模型的风格和自信度,而非其事实准确性。

理想的后训练方法应该兼具二者之长,既能获得 On-policy 训练的相关性,又能利用 Off-policy 蒸馏的密集奖励信号。这就引出了本文的核心——On-Policy Distillation

2. On-Policy Distillation 的核心思想与实现

On-Policy Distillation 的核心思想是:从学生模型中采样轨迹,并使用一个高性能的教师模型来为该轨迹中的每一个 token 进行打分

回到数学问题的例子 5 + (2 * 3),On-policy Distillation 会对学生模型生成的解题步骤中的每一步进行评分,惩罚导致最终答案错误的步骤,同时强化那些执行正确的步骤。

2.1 损失函数:Reverse KL Divergence

On-policy Distillation 可以使用多种损失函数来评估学生模型的轨迹。为简单起见,作者选择了逐 token 的反向 KL 散度(Reverse KL Divergence)。该损失函数衡量在给定相同上文的情况下,学生模型()和教师模型()在下一个 token 上的分布差异。

其数学表达式为:

这个奖励函数的目标是最小化反向 KL 散度,从而促使学生模型在它所遇到的每一个状态下都去近似教师模型的行为。当学生模型的行为与教师模型完全一致时,反向 KL 散度为零。文章中将折扣因子(discount factor)设为零,意味着在任何给定的时间步,学生模型只优化紧接着的下一个 token,而不考虑未来的 tokens。

反向 KL 散度与强化学习有天然的协同作用,后者通常优化由奖励模型引导的序列级反向 KL 散度。 与大多数奖励模型不同,反向 KL 散度是“不可攻击的”(unhackable),因为低的 KL 值总是对应于从教师模型角度看期望行为的高概率。

反向 KL 散度的另一个有用特性是它是“寻找模式的”(mode-seeking)。 这意味着它会学习教师模型的某一种特定行为,而不是将概率分布分散在多个次优选项上。 这与“覆盖模式的”(mode-covering)的正向 KL 散度(Forward KL Divergence)形成对比,后者会鼓励策略覆盖目标分布的所有区域。

细节详见:https://www.mlpod.com/732.html

这种方法还能带来显著的计算节省。因为它不需要等待一个完整的序列生成结束来计算奖励,所以可以使用更短或部分的序列进行训练。同时,查询教师模型的对数概率(log probabilities)只需要一次前向传播,而轨迹则是由更小、更廉价的学生模型生成的。此外,这种方法也不需要一个独立的奖励模型或标注模型。

2.2 伪代码与实现流程

作者在 Tinker 框架的强化学习脚本基础上实现了 On-policy Distillation,该脚本已经包含了采样、奖励计算和策略梯度式训练的功能。其核心流程如下:

  1. 初始化教师客户端:使用 Tinker API 为教师模型创建一个采样客户端。这一步不需要通过教师模型传播对数概率。
  2. 采样轨迹:与标准强化学习一样,从学生模型中采样序列(rollouts)。在采样过程中,已经计算好了学生模型的对数概率 ,用于后续的重要性采样损失计算。
  3. 计算奖励:使用 compute_logprobs 函数查询教师客户端,获取在学生模型采样的轨迹上,教师模型的对数概率 。然后,利用这两个对数概率计算反向 KL 散度。
  4. 使用 RL 进行训练:将每个 token 的优势(advantage)设置为负的反向 KL 散度,然后调用强化学习的重要性采样损失函数来更新学生模型的参数。
# 初始化教师客户端
teacher_client = service_client.create_sampling_client(
    base_model=teacher_config.base_model,
    model_path=teacher_config.load_checkpoint_path,
)

# 采样轨迹
trajectories = do_group_rollout(student_client, env_group_builder)
sampled_logprobs = trajectories.loss_fn_inputs["logprobs"]

# 计算奖励
teacher_logprobs = teacher_client.compute_logprobs(trajectories)
reverse_kl = sampled_logprobs - teacher_logprobs
trajectories["advantages"] = -reverse_kl

# 使用 RL 进行训练
training_client.forward_backward(trajectories, loss_fn="importance_sampling")

上图展示了一个例子,一个学生模型(Qwen1.5-4B-Instruct-2507)在处理 SimpleBench 中的一个问题时出现了错误,而教师模型(Qwen1.5-235B-A22B-Instruct-2507)对其进行了评分。这个问题需要模型注意到一个关键前提:煎锅里的冰块会融化,因此最终答案是“B. 0”。学生模型错误地将其当作一个纯数学问题处理。图中颜色越深的 token 代表教师模型给予的惩罚(即更高的反向 KL 散度)越大。我们可以看到,教师模型主要惩罚了那些引导学生模型走向错误方向的短语开头的 token,这些 token 直观上对应于引导推理的关键“分叉点”。而最终的错误答案本身并没有受到惩罚,因为它在给定前面错误推理序列的条件下是完全可预测的。

3. 实验验证:数学推理能力蒸馏

文章首先在数学推理任务上验证了 On-policy Distillation 的效果。实验使用 Qwen1.5-32B 作为教师模型,将数学推理能力蒸馏到 Qwen1.5-8B-Base 学生模型中。

3.1 对比方法:Off-policy Distillation 与 RL
  • Off-policy Distillation (SFT) :实验首先采用监督微调(SFT)的方式进行中训练。使用的数据集是 OpenThoughts-3,这是一个由类似 Qwen1.5-32B 的模型生成的推理问答对集合。 将学生模型在 40 万个样本上进行全参数微调,在数学问题基准 AIME'24 上达到了 60% 的分数。 实验还对比了 LoRA 微调,发现在大规模数据集上,LoRA 的表现落后于全参数微调。从性能提升曲线可以看出,性能随训练数据量对数线性增长,即初期的性能提升成本较低,而后期则越来越昂贵。

作者将这个在 40 万样本上微调过的模型作为一个检查点,比较不同后训练方法将其在 AIME'24 上的分数从 60% 提升到 70% 所需的成本。通过对数线性外推,估计继续使用 Off-policy Distillation 大约需要 200 万个样本才能达到 70% 的分数。

  • Reinforcement Learning:根据 Qwen 团队的技术报告,他们在类似的 SFT 初始化基础上,通过强化学习(RL)在 AIME'24 上达到了 67.6% 的性能,花费了 17,920 GPU 小时。 作者估计,这与训练 200 万个 Off-policy 蒸馏样本的成本相当。
3.2 On-policy Distillation 的结果

作为 Off-policy Distillation 或 RL 的替代方案,文章从 40 万样本的 SFT 检查点开始运行 On-policy Distillation。结果显示,该方法在大约 150 个步骤后就达到了 70% 的 AIME'24 分数。

3.3 计算成本对比

比较不同方法的计算成本并非易事,因为训练、采样和对数概率计算的成本比例因实现而异。文章使用 FLOPs(浮点运算次数)作为衡量标准,这种方式虽然会惩罚那些可以在 GPU 上高效并行化的计算,但能更公平地反映总体计算量。

不同方法的计算成本(FLOPs)对比
不同方法的计算成本(FLOPs)对比

结果显示,在给定 SFT 数据集的情况下,On-policy Distillation 实现了 9 倍的基线成本降低。在这种情况下,不计算 Off-policy 训练中教师模型的 FLOPs 成本(因为数据集是现成的),但计算 On-policy 训练中的教师模型成本(因为需要为学生轨迹计算对数概率)。由于教师模型的计算可以高效地在多 GPU 上并行,换算成 GPU 小时,成本降低接近 18 倍

更常见的情况是,我们需要为一个没有现成 Off-policy 蒸馏数据集的新任务训练一个小模型。如果将教师模型采样生成 Off-policy 数据的成本也计算在内,那么总的成本降低大约是 30 倍

这一结果清晰地表明,On-policy Distillation 在计算效率上具有巨大优势。

4. 实验验证:模型个性化与持续学习

除了在通用任务上训练小模型达到高性能,蒸馏的另一个重要用例是个性化,例如让模型遵循特定的对话语气和输出格式,或者学习工具使用和成本预算等能力。通常需要将这种行为训练与新领域知识的学习结合起来。

同时训练两者通常很困难,轻量级的微调往往不足以实现目标,需要更大规模的中训练。在一个已经具备新知识的模型上学习后训练行为,通常需要一个复杂的后训练流程,其中可能包含专有数据和奖励模型,这对于普通实践者来说可能难以复制或成本过高。

本节展示了 On-policy Distillation 可以被有效地用于后训练特定行为。这种方法也适用于持续学习(Continual Learning)或“测试时训练”(test-time training),即在模型部署后进行更新,同时不损害其基础性能。

4.1 学习新知识导致已有能力的退化

实验目标是训练一个内部公司助理,它需要具备两个特性:

  1. 知识性:了解公司内部文档。
  2. 遵循指令:表现出强大的后训练行为,如遵循指令。

实验从 Qwen1.5-8B 开始,这是一个已经通过 RL 进行了后训练、具备良好助理技能的模型。先前的研究表明,强化学习训练的只是原始模型的一个小子网络,因此当网络在大量新数据上进一步训练时,这些能力可能很脆弱。

为了减少灾难性遗忘(catastrophic forgetting),中训练的一个常见方法是在新数据中混入来自原始模型预训练分布的“背景数据”。由于无法获取 Qwen1.5 的预训练数据,实验采用了一个更强且成本更高的基线:使用 Tulu3 提示,并用 Qwen1.5-8B 重新采样答案,作为聊天背景数据。

然后,实验在不同比例的内部文档和聊天数据的混合数据上对 Qwen1.5-8B 进行微调。结果显示,增加文档数据的比例能直接提升模型的知识水平(通过内部 QA 评估)。然而,尽管混入至少 30% 的聊天数据有助于保留大部分指令遵循能力(通过 IF-eval 评估),但没有任何混合比例能够完全维持其原始性能。

不同数据混合比例下的模型性能
不同数据混合比例下的模型性能

进一步的观察发现,在任何混合比例下,IF-eval 的性能在微调过程中都会下降。这损害了使用更长时间训练来进一步特化模型的能力。

Mid-training 过程中的 IF-eval 性能变化
Mid-training 过程中的 IF-eval 性能变化

另一种常用方法是使用 LoRA 来约束参数更新,从而减少灾难性遗忘。然而,实验发现这种方法仍然不足以保留 IF-eval 性能,同时 LoRA 会导致模型学到的新知识也更少。

使用 LoRA 进行 Mid-training 的性能变化
使用 LoRA 进行 Mid-training 的性能变化
4.2 On-policy Distillation 恢复后训练行为

接下来,实验尝试在对内部文档进行微调后,恢复模型的指令遵循能力。这种能力最初是通过昂贵且脆弱的 RL 训练得到的。实验使用模型的早期版本 Qwen1.5-8B 作为教师,在 Tulu3 提示上运行 On-policy Distillation。需要注意的是,这个训练阶段与内部文档数据完全无关,其唯一目的是恢复指令遵循能力。

On-policy Distillation 恢复模型能力的结果
On-policy Distillation 恢复模型能力的结果

结果显示,在一个 70% 内部文档和 30% 聊天数据的混合数据上进行微调后,On-policy Distillation 几乎完全恢复了模型在 IF-eval 上的性能,同时没有损失任何通过中训练学到的知识。甚至还观察到聊天能力和模型在内部 QA 评估中的“知识”表现之间存在一些正向迁移。

使用模型的早期版本作为教师来“重新调用”在微调过程中丢失的能力,使得 On-policy Distillation 成为一个有前景的持续学习工具。我们可以交替进行新数据的微调阶段和恢复行为的蒸馏阶段,从而让模型在不断学习新知识的同时保持最新的能力。

5. 深入讨论与分析

5.1 密集监督显著提升计算效率

强化学习和 On-policy Distillation 都通过反向 KL 散度进行学习,剪除基础策略中存在的行为空间。它们的区别在于奖励的密度。强化学习每个 episode 只教授 O(1) 比特的信息,而蒸馏每个 episode 能教授 O(N) 比特的信息,其中 N 是 token 的数量。

为了量化密集奖励带来的训练效率提升,实验进行了一个直接对比:

  1. 从 Qwen1.5-8B-Base 开始。
  2. 在 DeepMath 数据集上运行 RL,得到一个 RL 训练后的模型,作为蒸馏的教师。
  3. 将这个 RL 训练好的模型通过 On-policy Distillation 蒸馏回基础模型。
RL 与 On-policy Distillation 学习效率对比
RL 与 On-policy Distillation 学习效率对比

结果显示,在模型架构匹配的情况下(LoRA rank 128),蒸馏达到教师模型性能水平的速度比 RL 快了大约 7-10 倍。反向 KL 散度在不到 10 个梯度步骤内就下降到接近零,AIME 分数也得以恢复,而 RL 则需要 70 步。

累积来看,所需的计算量减少了 50-100 倍。这主要源于:

  • RL 需要在与评估上下文长度相当的序列上进行训练,而蒸馏在较短的上下文长度下也能有效学习。
  • 当 SFT 初始化较强时,On-policy Distillation 可以使用更小的批量大小,因为它每个 episode 提供了更多的信息比特,从而降低了梯度噪声。
5.2 蒸馏有效复用训练数据,提升数据效率

对于实践者来说,收集大量的训练提示可能很困难且耗时。因此,能够多次复用提示进行训练至关重要。对于 RL 来说,在同一个提示上进行多轮训练往往会导致模型简单地记住最终答案,尤其是在大模型上。相比之下,On-policy Distillation 通过最小化反向 KL 散度来学习近似教师模型的完整分布,而不仅仅是记住一个单一的答案。这使得我们可以从同一个提示中训练出多个样本。

实验重复了在 DeepMath 上训练 Qwen1.5-8B-Base 的过程,但这次只使用数据集中随机选择的一个提示。模型在这个提示上连续训练了 20 步,每步使用 256 个 rollouts,总共使用了 5120 个评分序列。尽管这种方式计算效率较低,但模型最终的表现几乎与使用完整数据集训练的教师模型相当。

仅使用单个 prompt 进行 On-policy Distillation 的结果
仅使用单个 prompt 进行 On-policy Distillation 的结果
5.3 RL 在语义策略空间中进行搜索

On-policy Distillation 能够用少得多的训练步骤复制 RL 的学习效果,一个解释是,与预训练不同,RL 的大部分计算开销并非花在梯度更新上,而是花在搜索上——即执行策略并进行信用分配。

  • 预训练是通过随机梯度下降在高维参数空间中进行探索。
  • RL 则是在语义策略空间中进行探索。在每一步,RL 都是对过去发现的某个策略进行微小的修改。它并非在参数空间中探索,而是通过随机采样已有的权重组合,“偶然发现”新的策略。

一旦找到一个好的策略,蒸馏就提供了一条学习它的捷径。On-policy Distillation 不需要对 RL 课程中的中间策略进行建模,而只需学习最终的策略。如果只对最终策略感兴趣(这在生产用例中很常见),我们就不必花费计算资源去建模所有的中间策略。

5.4 On-policy 学习作为持续学习的工具

在个性化蒸馏部分,我们探讨了 On-policy Distillation 重新引入模型专业行为的能力。这可以推广到更广泛的持续学习任务,即在不降低先前能力的情况下获取新知识。

先前的工作发现 On-policy 学习(RL)比 Off-policy 学习忘记得更少。然而,RL 只能塑造行为,不能很好地传授新知识,因此不足以作为持续学习的完整解决方案。

实验表明,SFT(包括 Off-policy Distillation)在构建持续学习流程时会因为降低已有行为能力而失败。文章通过一个直接的例子进一步研究了这一点。他们通过从 Qwen1.5-32B 以 temperature=1.0 采样构建了一个数据集,这个数据集与 Qwen1.5-32B 的 KL 散度期望为零。

当在这个模型自己的样本数据集上运行 SFT 时,任何大于零的实际学习率都会导致在指令遵循评估上的性能下降。

在模型自身样本上运行 SFT 导致性能下降
在模型自身样本上运行 SFT 导致性能下降

一个可能的解释是,虽然期望的 KL 散度为零,但每个有限的批次(batch)在实践中都会表现出略微不同的分布。在这些有限批次上进行训练会导致一个非零的梯度更新,从而使更新后的模型策略偏离其原始状态。这个过程随着时间的推移,会将对自己样本的训练转变为 Off-policy 训练,从而导致与标准 Off-policy 训练中相同的误差累积和长序列发散问题。

而 On-policy Distillation 始终保持在 On-policy 状态,由于教师模型是固定的,学生模型会收敛到教师模型的期望行为上,而不会像自蒸馏的 SFT 那样出现性能衰退。这使得 On-policy Distillation 成为一个非常有前景的持续学习工具。

6. 结论

本文深入探讨了 On-policy Distillation 在训练用于数学推理的小型模型和持续学习的助理模型等应用中的作用。通过与 Off-policy Distillation 和 On-policy RL 这两种后训练方法的比较,文章发现 On-policy Distillation 结合了 On-policy 训练的可靠性能和密集奖励信号的成本效益。

后训练是实现前沿模型能力的关键环节。通过利用来自学生模型的 On-policy 采样和来自教师模型的密集监督,On-policy Distillation 以远低于前沿高算力 RL 运行的成本,达到了同等的能力水平。

这项工作的实现可以在 Tinker cookbook 中找到。文章探索了 On-policy Distillation 的简单直接的实例化,以清晰地展示其优势,并希望未来能继续研究蒸馏的新应用、改进教师监督的新方法以及提高数据效率和持续学习能力的方法。


往期文章: