「从零开始学大模型」TRL GRPOTrainer 源码导读
这是在 Hugging Face 的 TRL Quickstart 界面里,GRPO 训练的示例。今天从这里开始,探索一下 HF 的 GRPOTrainer 是怎样实现的,涵盖普通 Training loop 与 RL/GRPO 中的 Rollout、Training 环节,为之后我们手搓 GRPO Training Loop 打下基础!
from trl import GRPOTrainer
from datasets import load_dataset
from trl.rewards import accuracy_reward
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"),
reward_funcs=accuracy_reward,
)
trainer.train()继承链与分工
GRPOTrainer 继承自 _BaseTrainer,_BaseTrainer 继承自 transformers.Trainer。
transformers.Trainer -> _BaseTrainer -> GRPOTrainer在这条链上,主要的 training loop 由基类 Trainer 实现,而 GRPOTrainer 则是重写了几个关键函数,让 GRPO 的 rollout 生成、group 打分与 loss 计算得以实现。
后文我们详细拆解一下这三个类的功能,定位到 trainer.train() 这一个函数的具体执行逻辑是什么,调用了哪些其他函数,详细理解一下 RL 的 training loop 在代码上是怎么实现的。
Trainer 基类
功能模块分工
class Trainer 定义于 transformers 的库内,整个 trainer.py 有整整 4412 行,相当复杂。根据代码里的分段注释,class Trainer 的代码自上而下,实现了这些功能:
- Initialization & Validation - 初始化与参数检验
- Data Loading - 加载数据
- Optimizer & Scheduler & Learning rate - 优化、调度、学习率
- Training - 训练模块
- Training Utilites - 训练工具
- Evaluation & Prediction - 验证与推理
- Checkpoint Saving - 训练checkpoint保存
- Checkpoint Resuming - 训练checkpoint加载、恢复
- Saving & Serialization - 整个模型的保存与序列化
- Logging & Metrics - 日志与性能测试
- Hub Integration - Hugging Face 集成
- Hyperparameter Search - 自动尝试超参数组合
- Callbacks - 添加删除回调函数
- Utilities - 小工具
就算只是列举功能模块也足足有 14 项,更别说每个模块里少则二三多则十几个的函数了。好在我们目标是学习 training loop,只需要看 Training 模块就好。
Training Loop
Training 模块里面的函数也不少,有这些:
train()- 训练入口_inner_training_loop()- 实际训练循环发生的地方,逐 epoch 的 for 循环_init_training_state()- 训练状态初始化_run_epoch()- 每个 epoch 做什么,两层循环:外更新 / 内 micro-batch 拆分_finalize_training()- 收集 metric 数据、清理、生成 output 等等training_step()- 在 run_epoch() 的内层循环调用,forward + backwordcompute_loss()- 每一个 training_step() 调用一次,forward 与计算 losscompute_loss_context_manager()- 计算 loss 的辅助函数autocast_smart_context_manager()- autocast 的辅助函数_maybe_log_save_evaluate()- log 与记录
简化一下逻辑,可以这么理解:
# 外部 call trainer.train()
# call _inner_training_loop()
for epoch in range(epoches): # epoch 循环
# call _run_epoch()
for update_step in range (update_steps): # 外层:权重更新循环
for i in range(batch_samples): # 内层:micro batch
# call training_step()
inputs = self._prepare_inputs()
loss = self.compute_loss(inputs)
self.accelerator.backward(loss)
# 每个 update_step, 或整个 epoch 的最后一步
optimizer.step() # 权重更新
lr_scheduler.step() # 学习率更新
model.zero_grad() # 梯度清零简单的三层循环:epoch、update_step、batch。这就是核心的训练循环了,很朴素,除了把大 batch 拆成显存装得下的 micro batch,和其他常见的训练循环没有什么特别的。
在子类 GRPOTrainer 中,一样用的是这个循环,仅仅是重写了其中的几个函数。我们理解了基类的训练循环,就可以详细深入 GRPOTrainer,看看为什么重写那几个函数就能把普通的训练变成 GRPO 了。
_BaseTrainer 类
虽然这个模块基本上什么都没做,但既然在继承链上还是顺带提一嘴。它继承自 Trainer ,仅仅是重写了在基类的 Hub Integration 这一模块 的 create_model_card() 函数。
父类 HF Trainer 的关注点是通用 ML 模型发布:license、language、tasks、dataset 这些是 HF Hub 检索和展示需要的元数据。
TRL _BaseTrainer 的关注点是论文算法复现:TRL 里每个 trainer 对应一篇论文(PPO、DPO、GRPO…),所以卡片里要突出:
- 这是用哪个算法训练的(
_name) - 算法出自哪篇论文(
_paper) - 训练过程日志在哪看(wandb/trackio/comet URL)
- 谁也可以用什么子类属性扩展(
_tag_names、_template_file)
对理解训练 GRPO 没什么用,知道一下就行。
GRPOTrainer 类
GRPOTrainer 类也不是一个善茬,grop_trainer.py 文件有整整 2822 行。好在很多都是工具函数,或为了工程化做的补丁,我们也直击最重要的两个函数:_prepare_inputs() 与 compute_loss().
GRPOTrainer 重写了这两个函数,让 RL 循环嵌入了基类 Trainer 的训练循环中。而 GRPO 的 rollout、打分与计算 loss 都在这两个函数的内部实现,对 Trainer 的训练循环没有侵入式的影响。
鉴于读者一个已经掌握了 RL 的基础算法知识,了解 GRPO 是在做什么,这里就按 GRPO 的几个模块来分类:rollout → reward → advantage → loss。其中,rollout、reward、advantage 部分都在 _prepare_inputs() 内部完成,而计算 GRPO 的 loss 在 compute_loss() 里完成
_prepare_input()
这个函数比较简单,做这么几件事:
- 每
generate_anystep:- 真正的 rollout 一次(由
_generate_and_score_completions()处理) - 打乱结果顺序(避免同 prompt 的 G 个 completion 落在同一 micro-batch)
- 切分成 micro batch,作为
inputs给之前的最内层循环使用 - 更新缓存
- 真正的 rollout 一次(由
- 从缓存中读当前 step 要用的那一份
它在工程实现上海做了对于文本和图像的双适配,两种 prompt 都走这个路径,不过我们只关注文本就行。在 eval 时,这个函数不做缓冲,每一次调用都重新由 _generate_and_score_completions() 生成。
_generate_and_score_completions() - Rollout 侧
这个函数就复杂了,整整 487 行。它的功能囊括了生成 completions、算 reward 和算 advantage,零零碎碎的。流程是这样的:
- 预处理 prompts ,多模态与 RL 环境交互适配,格式化
_generate(),tokenize + rollout,生成 completions- pad 和 mask 输出,成固定 shape (B,P+C)
_get_per_token_logps_and_entropies()算两套 logprobs,之后算 loss 用- old:用当前 $\pi_\theta$ 重跑,用作 importance sampling 基准
- ref:用 ref_model 算,给 KL penalty 用
- 计算 rewards 与 advantage
- 调用
_calculate_rewards()得到输出 rewards_per_func - 先是二维矩阵,
[completion, reward_func],对应每个 completion 的多个评价函数值 - 在 reward_func 维度,加权求和得到 rewards,维度
[completion] - 计算
mean_grouped_reward,组内求平均,但整个维度还是[completion]不变 - 计算 advantage 并归一化
- 调用
- 收集输出、保存 log 等等的杂活
其中 5.5 的计算组内 advantage 方式如下:
advantages = (rewards - mean_grouped_rewards) / (std_rewards + 1e-4)简单来说,一次 generate_and_score_completions 产出"一批 completion 的全部 RL 训练素材":token、mask、advantage、参考 logprob、多模态辅料;同时把 reward/采样/漂移的各种指标写进日志管道。它是 GRPO 数据流水线的唯一源头,下游所有 update step 都吃它的输出。
它调用的几个函数在这里直接解释一下,不复杂就不单开章节细讲了:
_generate():调用 vllm 或自定义的 rollout_func,生成 completion,可以当黑盒看待
_calculate_rewards():调用由用户注入的 reward_funcs 函数,逐 completion 逐 func 打分。Trainer 在这里只负责编排使用,不关心它的具体实现。reward_func 可以是 Python 函数、reward model、远程服务等等。
_get_per_token_logps_and_entropies() 算 logprobs
回顾一下 logps 咋算的: 我们实际要算的是“在 $s_t$ 下,做 $a_t$ 动作概率的 log。在 LLM 里:
- state $s_t$ = prompt + 已经生成的 tokens $y_{< t}$
- action $a_t$ = 下一个 token $y_t$
这个函数做了什么:
- forward 获得
logits = model(input_ids, attention_mask).logitsinput_idsshape 是[B, prompt_len]的,是“具体的数(token id)”logitsshape 是[B, prompt_len + completion_len, vocab_size]的,是“分布”
- logits 对齐与整形
- 去掉
input_ids第一个,logits最后一个,对齐 $s_t$ 与 $a_t$ - 只保留 completion 段(算 prompt 段没有意义,不关模型的事)
- 对 temperature 归一化
- 去掉
- 算 log_softmax + gather 数学上: $$\log \pi(y_t \mid \cdot) = \text{logits}[y_t] - \log \sum_v \exp(\text{logits}[v])$$ 代码上等价于:
log_probs = logits.log_softmax(dim=-1) # [B, L, V]
per_token_logps = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
# shape: [B, L]只需要对每个位置算一次 logsumexp 和取一次 z_{y_t}。
compute_loss() - Training 侧
总算看完_prepare_input() 了,到第二个大模块,loss 计算。然后其实 compute_loss() 是对 _compute_loss() 的一个超薄封装,如果 self.use_liger_kernel 为 True 则路由到另一个 liger grpo loss 计算,正常的话直接调用的是 _compute_loss()。后面就直接讲 _compute_loss() 内的东西了。
我们先梳理一下现在手上有哪些东西:
prompt_ids,prompt_mask- prompt 段 token id 与有效位掩码
_prepare_inputs()的 input 处理环节产生
completions_ids,completion_mask- 采样得到的 completion 段与掩码
- 在 rollout 环节
_generate()时生成
old_per_token_logps- 旧策略下每 token 的对数概率
- 在 rollout 里
_get_per_token_logps_...()算的 $\log \pi_\text{old}(a_t \mid s_t)$
ref_per_token_logps- 参考策略下每 token 的对数概率
- 在 rollout 里
_get_per_token_logps_...()算的 $\log \pi_{ref}(a_t \mid s_t)$
advanages- group 标准化后的优势
- 在 rollout 里调用
_calculate_rewards()后计算得到组内 $\hat A_i = (r_i - \bar r)/\sigma_r$
要计算 loss,还缺了一个 per_token_logps:
- 当前策略对数概率 $\log \pi_\theta(y_t \mid x, y_{< t})$
- 在
_compute_loss()里面自己算 - 还是用
_get_per_token_logps_and_entropies()
然后我们再看看需要用到的几个中间量:
importance ratio 重要性采样 $r_t$:
$$r_t = \frac{\pi_\theta(a_t \mid s_t)}{\pi_\text{old}(a_t \mid s_t)} = \exp( \log \pi_\theta - \log \pi_\text{old})$$根据 PPO 的 clip 计算出来的 Policy Gradient 的 surrogate loss:
$$\mathcal{L}^{PG}_t = \min (r_t \hat A_t, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \hat A_t ) $$KL 正则项 (这个也是 per_token 的哦),在 beta==0 的时候会跳过不算:
但是在 GRPO 中不是用的朴素 $\hat k_1 = \log ( \pi_\theta / \pi_\text{ref})$ (方差大且可能为负),用的是:
$$\hat k_3 = \frac{\pi_\text{ref}}{\pi_\theta} - \log \frac{\pi_\text{ref}}{\pi_\theta} - 1 = e^{-\hat k_1} - (-\hat k_1 + 1) $$同样无偏(在期望意义下),且 $\geq 0$ ,方差更小。
在这里合并,得到 per_token_loss:
$$ loss_t = \mathcal{L}^{PG}_t - \beta \cdot \mathbb{D}_{KL} \mid_t$$然后是合并 per_token_loss 得到最终 loss:但是这里 TRL 给了三种不同实现:
"grpo":原论文:先 per-sequence 平均,再 per-batch 平均"dr_grpo":Dr.GRPO:用固定的 max_completion_length 归一化,去掉长度偏差"dapo":也叫 token-level,所有有效 token 一起平均,长 completion 权重更大
随后就是 loss 回传,backward 算梯度,更新权重了。从 loss 回传回 weight 的路径:
loss
└─ per_token_loss
└─ coef_1 = exp(per_token_logps - old_per_token_logps.detach())
↑ 梯度只从这里流
per_token_logps
└─ logits.log_softmax().gather(input_ids)
└─ model(input_ids) forward
└─ model.parameters() ← 梯度到这只有在 _compute_loss() 里面临时算的 per_token_logps 这一路径,其他都是 detach 的或者单一的数值,不参与梯度计算。
读后感想
觉得最有意思的,是看完 TRL 代码就知道为什么要做 Rollout / Training 分离了。Rollout 侧看似做了很多的 foward 产出了很多,但都是中间数据,只有最后在训练侧的 compute_loss() 里面做的那一次 forward 才真正对应着梯度回传的路径。这是两个看起来都在用同一个模型,但是实际上截然不同的两件事,如果做了 R/T 分离,各司其职,直觉上都能让 RL loop 快很多。
此外,代码风格上,感觉 TRL 特别喜欢 no-op,整个代码处处都有 no-op 来简化代码路径。比如说多模态全部塞进一个函数,函数内部如果发现如果没有图片(是单模态)就直接返回,多模态再处理。又比如说做 completion 数组切片,如果是已经切好的,就自然是一个 no-op。让多模态、图片处理、不同的 rollout 后端支持能在明面上尽量走同一个代码路径,看起来很舒服。