目录

「从零开始学大模型」TRL GRPOTrainer 源码导读

这是在 Hugging Face 的 TRL Quickstart 界面里,GRPO 训练的示例。今天从这里开始,探索一下 HF 的 GRPOTrainer 是怎样实现的,涵盖普通 Training loop 与 RL/GRPO 中的 Rollout、Training 环节,为之后我们手搓 GRPO Training Loop 打下基础!

from trl import GRPOTrainer
from datasets import load_dataset
from trl.rewards import accuracy_reward

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"),
    reward_funcs=accuracy_reward,
)
trainer.train()

继承链与分工

GRPOTrainer 继承自 _BaseTrainer_BaseTrainer 继承自 transformers.Trainer

transformers.Trainer -> _BaseTrainer -> GRPOTrainer

在这条链上,主要的 training loop 由基类 Trainer 实现,而 GRPOTrainer 则是重写了几个关键函数,让 GRPO 的 rollout 生成、group 打分与 loss 计算得以实现。

后文我们详细拆解一下这三个类的功能,定位到 trainer.train() 这一个函数的具体执行逻辑是什么,调用了哪些其他函数,详细理解一下 RL 的 training loop 在代码上是怎么实现的。

Trainer 基类

功能模块分工

class Trainer 定义于 transformers 的库内,整个 trainer.py 有整整 4412 行,相当复杂。根据代码里的分段注释,class Trainer 的代码自上而下,实现了这些功能:

  1. Initialization & Validation - 初始化与参数检验
  2. Data Loading - 加载数据
  3. Optimizer & Scheduler & Learning rate - 优化、调度、学习率
  4. Training - 训练模块
  5. Training Utilites - 训练工具
  6. Evaluation & Prediction - 验证与推理
  7. Checkpoint Saving - 训练checkpoint保存
  8. Checkpoint Resuming - 训练checkpoint加载、恢复
  9. Saving & Serialization - 整个模型的保存与序列化
  10. Logging & Metrics - 日志与性能测试
  11. Hub Integration - Hugging Face 集成
  12. Hyperparameter Search - 自动尝试超参数组合
  13. Callbacks - 添加删除回调函数
  14. Utilities - 小工具

就算只是列举功能模块也足足有 14 项,更别说每个模块里少则二三多则十几个的函数了。好在我们目标是学习 training loop,只需要看 Training 模块就好。

Training Loop

Training 模块里面的函数也不少,有这些:

  1. train() - 训练入口
  2. _inner_training_loop() - 实际训练循环发生的地方,逐 epoch 的 for 循环
  3. _init_training_state() - 训练状态初始化
  4. _run_epoch() - 每个 epoch 做什么,两层循环:外更新 / 内 micro-batch 拆分
  5. _finalize_training() - 收集 metric 数据、清理、生成 output 等等
  6. training_step() - 在 run_epoch() 的内层循环调用,forward + backword
  7. compute_loss() - 每一个 training_step() 调用一次,forward 与计算 loss
  8. compute_loss_context_manager() - 计算 loss 的辅助函数
  9. autocast_smart_context_manager() - autocast 的辅助函数
  10. _maybe_log_save_evaluate() - log 与记录

简化一下逻辑,可以这么理解:

# 外部 call trainer.train()
# call _inner_training_loop()
for epoch in range(epoches): # epoch 循环
	# call _run_epoch()
	for update_step in range (update_steps):  # 外层:权重更新循环
		for i in range(batch_samples): # 内层:micro batch
			# call training_step()
			inputs = self._prepare_inputs()
			loss = self.compute_loss(inputs)
			self.accelerator.backward(loss)
		# 每个 update_step, 或整个 epoch 的最后一步
		optimizer.step()    # 权重更新
		lr_scheduler.step() # 学习率更新
		model.zero_grad()   # 梯度清零

简单的三层循环:epoch、update_step、batch。这就是核心的训练循环了,很朴素,除了把大 batch 拆成显存装得下的 micro batch,和其他常见的训练循环没有什么特别的。

在子类 GRPOTrainer 中,一样用的是这个循环,仅仅是重写了其中的几个函数。我们理解了基类的训练循环,就可以详细深入 GRPOTrainer,看看为什么重写那几个函数就能把普通的训练变成 GRPO 了。

_BaseTrainer

虽然这个模块基本上什么都没做,但既然在继承链上还是顺带提一嘴。它继承自 Trainer ,仅仅是重写了在基类的 Hub Integration 这一模块 的 create_model_card() 函数。

父类 HF Trainer 的关注点是通用 ML 模型发布:license、language、tasks、dataset 这些是 HF Hub 检索和展示需要的元数据。

TRL _BaseTrainer 的关注点是论文算法复现:TRL 里每个 trainer 对应一篇论文(PPO、DPO、GRPO…),所以卡片里要突出:

  • 这是用哪个算法训练的(_name
  • 算法出自哪篇论文(_paper
  • 训练过程日志在哪看(wandb/trackio/comet URL)
  • 谁也可以用什么子类属性扩展(_tag_names_template_file

对理解训练 GRPO 没什么用,知道一下就行。

GRPOTrainer

GRPOTrainer 类也不是一个善茬,grop_trainer.py 文件有整整 2822 行。好在很多都是工具函数,或为了工程化做的补丁,我们也直击最重要的两个函数:_prepare_inputs()compute_loss().

GRPOTrainer 重写了这两个函数,让 RL 循环嵌入了基类 Trainer 的训练循环中。而 GRPO 的 rollout、打分与计算 loss 都在这两个函数的内部实现,对 Trainer 的训练循环没有侵入式的影响。

鉴于读者一个已经掌握了 RL 的基础算法知识,了解 GRPO 是在做什么,这里就按 GRPO 的几个模块来分类:rollout → reward → advantage → loss。其中,rollout、reward、advantage 部分都在 _prepare_inputs() 内部完成,而计算 GRPO 的 loss 在 compute_loss() 里完成

_prepare_input()

这个函数比较简单,做这么几件事:

  • generate_any step:
    1. 真正的 rollout 一次(由 _generate_and_score_completions() 处理)
    2. 打乱结果顺序(避免同 prompt 的 G 个 completion 落在同一 micro-batch)
    3. 切分成 micro batch,作为 inputs 给之前的最内层循环使用
    4. 更新缓存
  • 从缓存中读当前 step 要用的那一份

它在工程实现上海做了对于文本和图像的双适配,两种 prompt 都走这个路径,不过我们只关注文本就行。在 eval 时,这个函数不做缓冲,每一次调用都重新由 _generate_and_score_completions() 生成。

_generate_and_score_completions() - Rollout 侧

这个函数就复杂了,整整 487 行。它的功能囊括了生成 completions、算 reward 和算 advantage,零零碎碎的。流程是这样的:

  1. 预处理 prompts ,多模态与 RL 环境交互适配,格式化
  2. _generate() ,tokenize + rollout,生成 completions
  3. pad 和 mask 输出,成固定 shape (B,P+C)
  4. _get_per_token_logps_and_entropies() 算两套 logprobs,之后算 loss 用
    1. old:用当前 $\pi_\theta$ 重跑,用作 importance sampling 基准
    2. ref:用 ref_model 算,给 KL penalty 用
  5. 计算 rewards 与 advantage
    1. 调用 _calculate_rewards() 得到输出 rewards_per_func
    2. 先是二维矩阵,[completion, reward_func],对应每个 completion 的多个评价函数值
    3. 在 reward_func 维度,加权求和得到 rewards,维度[completion]
    4. 计算 mean_grouped_reward,组内求平均,但整个维度还是 [completion] 不变
    5. 计算 advantage 并归一化
  6. 收集输出、保存 log 等等的杂活

其中 5.5 的计算组内 advantage 方式如下:

advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4)
$$ A_i = \frac{r_i - \mu_{\text{group}}}{\sigma_{\text{group}} + \varepsilon} $$

简单来说,一次 generate_and_score_completions 产出"一批 completion 的全部 RL 训练素材":token、mask、advantage、参考 logprob、多模态辅料;同时把 reward/采样/漂移的各种指标写进日志管道。它是 GRPO 数据流水线的唯一源头,下游所有 update step 都吃它的输出。

它调用的几个函数在这里直接解释一下,不复杂就不单开章节细讲了:

_generate():调用 vllm 或自定义的 rollout_func,生成 completion,可以当黑盒看待

_calculate_rewards():调用由用户注入的 reward_funcs 函数,逐 completion 逐 func 打分。Trainer 在这里只负责编排使用,不关心它的具体实现。reward_func 可以是 Python 函数、reward model、远程服务等等。

_get_per_token_logps_and_entropies() 算 logprobs

回顾一下 logps 咋算的: 我们实际要算的是“在 $s_t$ 下,做 $a_t$ 动作概率的 log。在 LLM 里:

  • state $s_t$ = prompt + 已经生成的 tokens $y_{< t}$
  • action $a_t$ = 下一个 token $y_t$
$$ \log \pi_\theta(a_t \mid s_t) = \log \pi_\theta (y_t \mid x, y_{< t} ) $$

这个函数做了什么:

  1. forward 获得 logits = model(input_ids, attention_mask).logits
    • input_ids shape 是 [B, prompt_len] 的,是“具体的数(token id)”
    • logits shape 是 [B, prompt_len + completion_len, vocab_size] 的,是“分布”
  2. logits 对齐与整形
    1. 去掉 input_ids 第一个,logits 最后一个,对齐 $s_t$ 与 $a_t$
    2. 只保留 completion 段(算 prompt 段没有意义,不关模型的事)
    3. 对 temperature 归一化
  3. 算 log_softmax + gather 数学上: $$\log \pi(y_t \mid \cdot) = \text{logits}[y_t] - \log \sum_v \exp(\text{logits}[v])$$ 代码上等价于:
log_probs = logits.log_softmax(dim=-1)              # [B, L, V]
per_token_logps = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
# shape: [B, L]
$$\log \pi(y_t) = z_{y_t} - \text{logsumexp}(z)$$

只需要对每个位置算一次 logsumexp 和取一次 z_{y_t}

compute_loss() - Training 侧

总算看完_prepare_input() 了,到第二个大模块,loss 计算。然后其实 compute_loss() 是对 _compute_loss() 的一个超薄封装,如果 self.use_liger_kernel 为 True 则路由到另一个 liger grpo loss 计算,正常的话直接调用的是 _compute_loss()。后面就直接讲 _compute_loss() 内的东西了。

我们先梳理一下现在手上有哪些东西:

  • prompt_ids, prompt_mask
    • prompt 段 token id 与有效位掩码
    • _prepare_inputs() 的 input 处理环节产生
  • completions_ids, completion_mask
    • 采样得到的 completion 段与掩码
    • 在 rollout 环节 _generate() 时生成
  • old_per_token_logps
    • 旧策略下每 token 的对数概率
    • 在 rollout 里 _get_per_token_logps_...() 算的 $\log \pi_\text{old}(a_t \mid s_t)$
  • ref_per_token_logps
    • 参考策略下每 token 的对数概率
    • 在 rollout 里 _get_per_token_logps_...()算的 $\log \pi_{ref}(a_t \mid s_t)$
  • advanages
    • group 标准化后的优势
    • 在 rollout 里调用 _calculate_rewards() 后计算得到组内 $\hat A_i = (r_i - \bar r)/\sigma_r$

要计算 loss,还缺了一个 per_token_logps:

  • 当前策略对数概率 $\log \pi_\theta(y_t \mid x, y_{< t})$
  • _compute_loss() 里面自己算
  • 还是用 _get_per_token_logps_and_entropies()

然后我们再看看需要用到的几个中间量:

importance ratio 重要性采样 $r_t$:

$$r_t = \frac{\pi_\theta(a_t \mid s_t)}{\pi_\text{old}(a_t \mid s_t)} = \exp( \log \pi_\theta - \log \pi_\text{old})$$

根据 PPO 的 clip 计算出来的 Policy Gradient 的 surrogate loss

$$\mathcal{L}^{PG}_t = \min (r_t \hat A_t, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \hat A_t ) $$

KL 正则项 (这个也是 per_token 的哦),在 beta==0 的时候会跳过不算:

$$ \mathbb{D}_{\mathrm{KL}}[\pi_\theta \,\|\, \pi_{\mathrm{ref}}] = \mathbb{E}_{y \sim \pi_\theta}\! \left[ \log \frac{\pi_\theta(y)}{\pi_{\mathrm{ref}}(y)} \right] $$

但是在 GRPO 中不是用的朴素 $\hat k_1 = \log ( \pi_\theta / \pi_\text{ref})$ (方差大且可能为负),用的是:

$$\hat k_3 = \frac{\pi_\text{ref}}{\pi_\theta} - \log \frac{\pi_\text{ref}}{\pi_\theta} - 1 = e^{-\hat k_1} - (-\hat k_1 + 1) $$

同样无偏(在期望意义下),且 $\geq 0$ ,方差更小。

在这里合并,得到 per_token_loss

$$ loss_t = \mathcal{L}^{PG}_t - \beta \cdot \mathbb{D}_{KL} \mid_t$$

然后是合并 per_token_loss 得到最终 loss:但是这里 TRL 给了三种不同实现:

  • "grpo":原论文:先 per-sequence 平均,再 per-batch 平均
  • "dr_grpo":Dr.GRPO:用固定的 max_completion_length 归一化,去掉长度偏差
  • "dapo":也叫 token-level,所有有效 token 一起平均,长 completion 权重更大

随后就是 loss 回传,backward 算梯度,更新权重了。从 loss 回传回 weight 的路径:

loss 
 └─ per_token_loss
     └─ coef_1 = exp(per_token_logps - old_per_token_logps.detach())
                     ↑ 梯度只从这里流
                 per_token_logps
                     └─ logits.log_softmax().gather(input_ids)
                         └─ model(input_ids) forward
                             └─ model.parameters()  ← 梯度到这

只有在 _compute_loss() 里面临时算的 per_token_logps 这一路径,其他都是 detach 的或者单一的数值,不参与梯度计算。

读后感想

觉得最有意思的,是看完 TRL 代码就知道为什么要做 Rollout / Training 分离了。Rollout 侧看似做了很多的 foward 产出了很多,但都是中间数据,只有最后在训练侧的 compute_loss() 里面做的那一次 forward 才真正对应着梯度回传的路径。这是两个看起来都在用同一个模型,但是实际上截然不同的两件事,如果做了 R/T 分离,各司其职,直觉上都能让 RL loop 快很多。

此外,代码风格上,感觉 TRL 特别喜欢 no-op,整个代码处处都有 no-op 来简化代码路径。比如说多模态全部塞进一个函数,函数内部如果发现如果没有图片(是单模态)就直接返回,多模态再处理。又比如说做 completion 数组切片,如果是已经切好的,就自然是一个 no-op。让多模态、图片处理、不同的 rollout 后端支持能在明面上尽量走同一个代码路径,看起来很舒服。