本文为系列第四篇。在了解了 PPO 的显存痛点和 DPO 的离线局限性后,我们终于迎来了目前大模型在线 RL 的最前沿破局者——GRPO(Group Relative Policy Optimization)。本文将详细推导 GRPO 的核心思想,看它是如何优雅地丢弃 Critic 网络,实现高效的在线强化学习的。

⬅️ 上一篇:笔记|强化学习(三):大模型对齐的另一条路:DPO (Direct Preference Optimization)

➡️ 下一篇:笔记|强化学习(五):Flow-GRPO 与图像生成应用(基于 Flux 的代码解析)

在线 RL 的不可替代性与 Critic 的累赘

正如上一篇所言,DPO 虽然简单省显存,但它只能"死记硬背"人类给出的标准答案(离线学习)。为了让模型产生"顿悟"和自我进化,我们必须回归在线强化学习(Online RL)

然而,PPO 算法中的 Critic 网络(价值网络)成为了最大的绊脚石。对于百亿参数的大模型,多维护一个 Critic 意味着显存开销直接翻倍。

核心思考出发点:既然 Critic 只是为了给出一个"及格线"(基准值 \(V(s)\)),我们能不能彻底去掉 Critic 模型,用一种更简单的方法来估计这个"及格线"?


GRPO 的核心思想:矮子里拔高个

GRPO 的思路极简:对同一个 Prompt 采样 \(G\) 个回答,用组内奖励的均值和标准差做标准化,得到每个回答的相对优势——高于均值的强化,低于均值的抑制。

\[ \hat{A}_i = \frac{r_i - \mu_R}{\sigma_R + \varepsilon}, \quad \mu_R = \frac{1}{G}\sum_{j=1}^G r_j, \quad \sigma_R = \text{std}(r_1, \dots, r_G) \]

这就是"矮子里拔高个":即使绝对水平不高,只要能分出高低,模型就有学习信号。注意分母的 \(\varepsilon\)(通常取 \(10^{-8}\)):当所有回答奖励相同时 \(\sigma_R = 0\),此时分子 \(r_i - \mu_R\) 也恰好为零,\(\hat{A}_i = 0/(0 + \varepsilon) = 0\),模型不更新——避免了无区分信号时的噪声梯度。


GRPO 的理论根源:从 REINFORCE 到组内相对优势

在深入数学推导之前,先理清 GRPO 的理论脉络——它并不是凭空发明的,而是 REINFORCE with Baseline 的一个聪明的工程变体。

经典 RL 中的 Baseline

回顾第一篇(RL 基础)中的 REINFORCE with Baseline:

\[ \nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot \big(G_t - b(s_t)\big) \right] \]

各符号含义:

符号 含义
\(\theta\) 策略网络的参数
\(J(\theta)\) 策略的目标函数(期望总回报),我们要最大化它
\(\tau \sim \pi_\theta\) 轨迹 \(\tau\) 按策略 \(\pi_\theta\) 采样(\(\sim\) 读作"服从/采样自")
\(\tau\) 一条完整轨迹(trajectory):\(s_0, a_0, r_0, s_1, a_1, r_1, \dots\)
\(\pi_\theta(a_t \mid s_t)\) 策略在状态 \(s_t\) 下选择动作 \(a_t\) 的概率
\(G_t = \sum_{k=t}^{T} r_k\) 从时刻 \(t\) 到终止的累积回报(未来总收益)
\(b(s_t)\) 基线(baseline):只依赖状态、不依赖动作的一个标量

公式的物理意义\(\nabla_\theta \log \pi_\theta(a_t \mid s_t)\) 是"让动作 \(a_t\) 更可能"的方向,\((G_t - b)\) 决定沿这个方向走多远——如果实际回报 \(G_t\) 高于基线 \(b\),就强化这个动作;低于基线就抑制

经典 RL 中的基线选择:最自然的基线是状态价值函数 \(V(s_t) = \mathbb{E}[G_t \mid s_t]\),即"从当前状态出发、按当前策略行动,未来累积回报的期望值"。经典做法是训练一个价值网络 \(V_\phi(s_t)\) 来逼近它——这就是 Actor-Critic / PPO 路线(需要额外的 Critic 模型,显存翻倍)。

语言模型场景的关键简化

在经典 RL 中,智能体在环境中走很多步(\(s_0 \to a_0 \to s_1 \to a_1 \to \cdots\)),每步都可能获得奖励,基线需要估计"从当前步到未来的累积回报"。

但在语言模型的 RLHF 场景中,整个回答是一条完整轨迹(prompt \(s\) → 生成完整回答 \(o\) → 得到一个总分 \(r(o)\)),奖励只在最后一步给出。这意味着:

\[ V(s) = \mathbb{E}_{o \sim \pi_\theta(\cdot|s)}[r(o)] \]

符号 含义
\(s\) 用户的 prompt
\(o\) 模型生成的一条完整回答
\(\pi_\theta(\cdot \mid s)\) 模型在 prompt \(s\) 下所有可能回答的概率分布
\(r(o)\) 奖励函数对回答 \(o\) 的打分
\(V(s)\) 这个 prompt 下所有可能回答的平均奖励

物理意义:因为只有一步决策(生成整个回答),"未来累积回报的期望"退化为"当前这个 prompt 下所有可能回答的平均得分"。

GRPO 的关键洞察:用采样均值替代价值网络

\(V(s) = \mathbb{E}_{o \sim \pi_\theta}[r(o)]\) 是理论上最优的基线(使策略梯度方差最小),但精确计算需要遍历所有可能回答——这不可能。

经典做法(PPO):训练一个 Critic 网络 \(V_\phi(s) \approx V(s)\),代价是多一个与策略模型同等规模的网络,显存翻倍。

GRPO 的做法:对同一个 prompt 采样 \(G\) 个回答 \(o_1, \dots, o_G\),直接用经验均值近似:

\[ \mu_R = \frac{1}{G}\sum_{i=1}^G r(o_i) \approx V(s) = \mathbb{E}_{o \sim \pi_\theta}[r(o)] \]

符号 含义
\(G\) 每个 prompt 的采样数量(通常 8~16)
\(o_i\) \(i\) 个采样回答
\(r(o_i)\) \(i\) 个回答的奖励
\(\mu_R\) 组内均值,\(V(s)\) 的蒙特卡洛估计

为什么这行得通? 道理很朴素:想知道一个班的平均成绩,随机抽几个同学算平均分就是一个合理的估计——抽的人越多越准。\(G\) 个回答的均值 \(\mu_R\) 就是对真实平均奖励 \(V(s)\) 的这种"抽样估计",\(G = 8 \sim 16\) 在实践中足够准确。

标准化稳定梯度

除以 \(\sigma_R\) 后得到 \(\hat{A}_i = \frac{r_i - \mu_R}{\sigma_R + \varepsilon}\)。这一步的作用是消除奖励尺度的影响

考虑两个不同的奖励函数:任务 A 的奖励在 \([0, 1]\) 范围(如准确率),任务 B 的奖励在 \([-100, 100]\) 范围(如 BLEU 分数乘以 100)。如果不标准化,任务 B 的梯度会比任务 A 大 100 倍,学习率需要针对每个任务单独调整。

标准化后,无论原始奖励的范围如何,优势值 \(\hat{A}_i\) 都近似服从均值为 0、标准差为 1 的分布(即 \(\hat{A}_i \in [-3, 3]\) 左右),梯度尺度统一,同一套超参数可以跨任务复用

一句话总结:经典 RL 的 baseline 是"未来累积回报的期望",在语言模型的单步场景中退化为"平均奖励",GRPO 用采样均值来近似它——不需要 Critic 网络,只需要多采几个样本

与 RLOO 的对比

RLOO(REINFORCE Leave-One-Out)是另一种去 Critic 的基线方案。两种方法的核心区别在于:计算第 \(i\) 个回答的基线时,是否包含 \(r_i\) 本身

GRPO RLOO
基线 \(\mu_R = \frac{1}{G}\sum_{j=1}^{G} r_j\) \(b_i = \frac{1}{G-1}\sum_{j \neq i} r_j\)
\(r_i\) 是否参与基线计算 \(r_i\) 在求和里) (排除了 \(r_i\)
\(r_i\) 与基线的关系 正相关(\(\text{Cov} > 0\) 独立(\(\text{Cov} = 0\)

GRPO 的"自我包含"效应:展开可得 \(r_i - \mu_R = \frac{G-1}{G}(r_i - \bar{r}_{-i})\)(其中 \(\bar{r}_{-i}\) 即 RLOO 的基线),相比 RLOO 多了 \(\frac{G-1}{G}\) 的缩放(\(G=8\) 时为 \(87.5\%\)),梯度信号被轻微压缩。但实践中差距不大,且 GRPO 的 \(\sigma_R\) 标准化会部分补偿这个缩放,加上实现更简单(一次求均值即可),因此被广泛使用。


GRPO 的数学推导与损失函数构建

1. 组内相对优势计算

给定一个输入 Prompt \(s\),策略网络 \(\pi_\theta\) 采样出 \(G\) 个输出(通常 \(G=4 \sim 16\)): \[ o_1, o_2, \dots, o_G \sim \pi_\theta(\cdot|s) \]

奖励模型(或规则判题器)对每个输出打分,得到奖励集合 \(R = \{r_1, r_2, \dots, r_G\}\)

计算组内均值和标准差: \[ \mu_R = \frac{1}{G} \sum_{i=1}^G r_i, \quad \sigma_R = \sqrt{\frac{1}{G} \sum_{i=1}^G (r_i - \mu_R)^2} \]

对于第 \(i\) 个输出 \(o_i\),其相对优势估计为: \[ \hat{A}_i = \frac{r_i - \mu_R}{\sigma_R + \epsilon} \] 其中 \(\epsilon\) 是极小常数,防止除以零(当所有回答奖励相同时 \(\sigma_R = 0\))。

极端情况分析

  • 全对 \(r = [1,1,1,1]\)\(\sigma_R = 0\)\(\hat{A}_i = 0\) → 不更新(都对了,没什么可学的)。

  • 全错 \(r = [0,0,0,0]\)\(\sigma_R = 0\)\(\hat{A}_i = 0\) → 不更新(都错了,没有正样本可以学习)。

  • 一对三错 \(r = [1,0,0,0]\)\(\hat{A}_1 = +1.73\)\(\hat{A}_{2,3,4} = -0.58\) → 大力强化唯一的正确回答。

这种"全对/全错时不更新"的行为避免了在没有区分信号时引入噪声梯度。

2. KL 散度正则化:为什么用这个特殊形式?

为了防止策略"钻空子"(Reward Hacking)或丧失语言连贯性,需要约束 \(\pi_\theta\) 不偏离参考策略 \(\pi_{\text{ref}}\) 太远。直觉上我们需要一个"距离度量"来衡量两个策略有多不同,标准 KL 散度是最自然的选择:

\[ D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}}) = \mathbb{E}_{o \sim \pi_\theta}\left[\log \frac{\pi_\theta(o|s)}{\pi_{\text{ref}}(o|s)}\right] \]

但它有两个实操困难:

  1. 期望无法直接算\(D_{\text{KL}}\) 要求对 \(\pi_\theta\) 分布下所有可能输出求期望,而我们手头只有从旧策略 \(\pi_{\theta_{\text{old}}}\) 采出来的有限样本——分布不匹配,直接用这些样本估计 \(D_{\text{KL}}\) 偏差大。

  2. 逐样本值可正可负:对单个样本 \(o\)\(\log \frac{\pi_\theta(o)}{\pi_{\text{ref}}(o)}\) 可正可负(虽然取期望后 \(D_{\text{KL}} \geq 0\))。回顾最终目标函数的结构:\(J = \text{奖励项} - \beta \cdot \hat{D}_{\text{KL}}\),KL 项被减去以惩罚偏离。如果 \(\hat{D}_{\text{KL}}\) 对某些 token 取负值,\(-\beta \cdot (\text{负数}) = \text{正数}\),反而增加了目标函数——策略在这些 token 上偏离参考模型竟然获得了奖励,与"惩罚偏离"的设计意图相悖。

GRPO 转而采用 Schulman (2020) 提出的一种 KL 近似估计量。令 \(u = \frac{\pi_{\text{ref}}(o_i|s)}{\pi_\theta(o_i|s)}\),定义:

\[ \hat{D}_{\text{KL}} = u - \log u - 1 = \frac{\pi_{\text{ref}}(o_i|s)}{\pi_\theta(o_i|s)} - \log \frac{\pi_{\text{ref}}(o_i|s)}{\pi_\theta(o_i|s)} - 1 \]

这个估计量具有两个关键性质,恰好解决了上述问题:

  • 逐样本非负:由不等式 \(e^x \geq x + 1\)(即 \(u - \log u - 1 \geq 0\)\(\forall u > 0\)),等号当且仅当 \(u = 1\)\(\pi_\theta = \pi_{\text{ref}}\))时成立。这保证了每个 token 的惩罚都 \(\geq 0\),不会出现"偏离反获奖励"的问题。

  • 双侧惩罚:当 \(\pi_\theta\) 塌缩(某 token 概率远小于 \(\pi_{\text{ref}}\)\(u \gg 1\))时惩罚以 \(u\) 线性增长;当 \(\pi_\theta\) 膨胀(概率远大于 \(\pi_{\text{ref}}\)\(u \to 0\))时惩罚以 \(-\log u\) 对数增长。两个方向的偏离都被约束。

数学注\(u - \log u - 1\)\(u = 1\) 附近做泰勒展开得 \(\frac{1}{2}(u-1)^2 + O((u-1)^3)\),与标准 KL 的局部行为一致(二者共享同一 Fisher 信息矩阵),因此在策略变化不大时,两者给出的惩罚几乎相同。

Token 级别的计算

上面公式中 \(\pi_\theta(o_i|s)\)整条回答的概率,但语言模型逐 token 生成。由链式法则,序列概率等于各 token 条件概率的乘积:

\[ \pi_\theta(o_i|s) = \prod_{t=1}^{T} \pi_\theta(o_i^t | s, o_i^{<t}) \]

据此,重要性比率 \(\rho_{i} = \pi_\theta(o_i|s) / \pi_{\theta_{\text{old}}}(o_i|s) = \prod_t \rho_{i,t}\) 和 KL 中的 \(u = \pi_{\text{ref}} / \pi_\theta\) 都可以分解为 token 级的乘积。但实际实现中不计算序列级乘积(长序列会数值溢出),而是在每个 token 位置 \(t\) 独立计算 \(\rho_{i,t}\),逐 token 执行 PPO 裁剪和 KL 惩罚,最后对所有 token 求平均。

3. GRPO 最终目标函数

结合 PPO 的裁剪机制和组内相对优势,GRPO 的最终目标函数(需要最大化)定义为:

\[ \mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim P(Q),\, \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left( \min \left( \rho_{i,t}(\theta)\, \hat{A}_i,\; \text{clip}(\rho_{i,t}(\theta),\, 1-\varepsilon,\, 1+\varepsilon)\, \hat{A}_i \right) - \beta\, \hat{D}_{\text{KL}}^{(i,t)} \right) \right] \]

其中:

  • \(\rho_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}\) 是第 \(i\) 个回答第 \(t\) 个 token 的重要性采样比率。

  • \(\varepsilon\) 是裁剪阈值(如 0.2),防止单步更新过大。

  • \(\beta\) 是 KL 惩罚系数,控制偏离参考策略的代价。

  • \(\hat{A}_i\) 是组内归一化优势(序列级标量,对 \(t\) 为常数,广播到每个 token)。

  • \(\hat{D}_{\text{KL}}^{(i,t)}\) 是 token 级 KL 散度近似估计(采用 \(e^{\log u} - \log u - 1\) 形式,其中 \(u = \pi_{\text{ref}} / \pi_\theta\),详见下文),并非严格的 \(D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}})\) 积分定义。

  • \(\frac{1}{|o_i|}\) 对每条回答按长度归一化(per-response normalization):先对 token 求平均,再对 \(G\) 条回答平均。这意味着不论回答长短,每条回答的权重都是 \(\frac{1}{G}\)。后续 DAPO 论文将此聚合方式改为按 token 归一化(\(\frac{1}{\sum_i |o_i|}\sum_i\sum_t\),使每个 token 等权),详见第六篇

:上式使用 token 级记号 \(\rho_{i,t}\),与 DeepSeekMath 原论文(arXiv:2402.03300)从 PPO 继承的 \(\frac{1}{|o|}\sum_t\) 结构一致。GRPO 的所有计算(IS ratio、裁剪、KL)均在 token 级逐位执行——此处"序列级"仅指聚合方式(每条回答等权),不是计算粒度。DAPO 改变的是聚合方式(从 per-response 到 per-token),而非引入 token 级计算。


GRPO 的完整实现

以下是 GRPO 的完整 PyTorch 实现伪代码,包括数据准备、模型定义、采样、优势计算和训练循环。

Step 1: 模型与数据定义

与 DPO 不同,GRPO 是在线算法:不需要预先收集偏好对,只需要 Prompt 集合和一个能打分的奖励函数。模型方面与 DPO 一样只需要两个(策略 + 参考),但保留了在线 RL 的探索能力。

python
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    AutoModelForSequenceClassification,
)

# 数据: 只需要 Prompt 集合,无需预标注偏好对
prompts = load_dataset("math_problems")

# 模型 1: 待训练的策略模型 (π_θ)
actor = AutoModelForCausalLM.from_pretrained("sft_checkpoint")
# 模型 2: 冻结的参考模型 (π_ref), KL 正则锚点
ref_model = AutoModelForCausalLM.from_pretrained("sft_checkpoint")
ref_model.requires_grad_(False)

tokenizer = AutoTokenizer.from_pretrained("sft_checkpoint")
# decoder-only 模型 batch 生成时必须左填充,
# 确保所有序列的最后一个真实 token 右对齐在同一列,
# 否则 padding 会插在序列末尾, 破坏自回归生成的连续性
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

optimizer = torch.optim.AdamW(actor.parameters(), lr=1e-6)

# 奖励函数: 可以是规则判题器(数学题判对错)或训练好的奖励模型
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "reward_model_checkpoint"
)
reward_model.requires_grad_(False)

def reward_fn(prompt, response_text):
    """用奖励模型给回答打分,返回标量奖励"""
    inputs = tokenizer(
        prompt + response_text,
        return_tensors="pt", truncation=True
    ).to(reward_model.device)
    with torch.no_grad():
        score = reward_model(**inputs).logits.squeeze()
    return score.item()

# 超参数
G = 8             # 每个 Prompt 采样的回答数量
clip_range = 0.2  # PPO 裁剪阈值 ε
beta = 0.04       # KL 惩罚系数
K_epochs = 2      # 每批数据的更新轮数

Step 2: 在线采样 + 奖励收集

这是 GRPO 与 DPO 的核心区别——GRPO 用当前策略在线生成回答并即时评分:

python
def collect_group_rollouts(actor, prompts_batch, G, reward_fn):
    """
    对每个 Prompt 采样 G 个回答并打分。

    Returns:
        all_prompt_ids:   (B×G, L_p)   prompt token 序列
        all_response_ids: (B×G, L_r)   回答 token 序列(不含 prompt)
        all_rewards:      List[float]  标量奖励, 长度 B×G
        old_log_probs:    (B×G, L_r)   采样时各 token 的 log π_old
        ref_log_probs:    (B×G, L_r)   参考策略各 token 的 log π_ref
    """
    actor.eval()
    with torch.no_grad():
        # Step 2a: 编码所有 prompt 并复制 G 份
        encoded = tokenizer(
            prompts_batch,
            return_tensors="pt", padding=True
        ).to(actor.device)
        # input_ids: (B, L_p), attention_mask: (B, L_p)

        # repeat_interleave 让相邻 G 行属于同一 prompt: (B, L_p) → (B×G, L_p)
        all_prompt_ids = encoded.input_ids.repeat_interleave(G, dim=0)
        all_attn_mask = encoded.attention_mask.repeat_interleave(G, dim=0)

        # Step 2b: batch 生成所有回答
        # 必须传 attention_mask, 否则模型把 padding 当真实 token 处理
        full_ids = actor.generate(
            all_prompt_ids,
            attention_mask=all_attn_mask,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7
        )  # (B×G, L_p + L_r)

        # 分离 response: generate() 返回完整序列, 截掉 prompt 前缀
        prompt_len = all_prompt_ids.shape[1]
        all_response_ids = full_ids[:, prompt_len:]  # (B×G, L_r)

        # Step 2c: 只解码 response 部分, 避免 prompt 文本被重复送入奖励函数
        all_response_texts = tokenizer.batch_decode(
            all_response_ids, skip_special_tokens=True
        )  # List[str], 长度 B×G

        prompts_repeated = [p for p in prompts_batch for _ in range(G)]
        all_rewards = [
            reward_fn(p, t)
            for p, t in zip(prompts_repeated, all_response_texts)
        ]  # List[float], 长度 B×G

        # Step 2d: 计算 token 级 log 概率 (后面算 ρ 和 KL 要用)
        old_log_probs = compute_token_log_probs(
            actor, all_prompt_ids, all_response_ids
        )  # (B×G, L_r): π_old 下每个 token 的 log 概率
        ref_log_probs = compute_token_log_probs(
            ref_model, all_prompt_ids, all_response_ids
        )  # (B×G, L_r): π_ref 下每个 token 的 log 概率

    actor.train()
    return (all_prompt_ids, all_response_ids,
            all_rewards, old_log_probs, ref_log_probs)

Step 3: 组内相对优势计算

python
def compute_group_advantages(rewards, G):
    """
    组内相对优势计算: 用组内均值替代 Critic 网络。

    Args:
        rewards: List[float], 长度 = batch_size × G
        G:       int, 每个 prompt 的采样数
    Returns:
        Tensor (batch_size × G,), 标准化后的相对优势 Â_i
    """
    # 显式指定 float32, 避免整数奖励 (如 0/1 判对错) 被推断为 int 导致除法截断
    rewards = torch.tensor(rewards, dtype=torch.float32).reshape(-1, G)
    # rewards: (batch_size, G)

    mean_r = rewards.mean(dim=1, keepdim=True)   # (batch_size, 1)

    # correction=0 → 总体标准差 σ = √(1/G Σ(r-μ)²), 与文中公式一致
    # PyTorch 默认 correction=1 是样本标准差 (除以 G-1), G 较小时差异明显
    std_r = rewards.std(dim=1, keepdim=True, correction=0)  # (batch_size, 1)

    # 全对/全错时 σ=0, 分子也为 0, ε 防除零, Â_i=0 → 不更新
    advantages = (rewards - mean_r) / (std_r + 1e-8)
    return advantages.reshape(-1)   # (batch_size × G,)

Step 4: 完整训练循环

python
for step in range(total_steps):

    # ================================================================
    # 阶段 1: 在线采样 (GRPO 独有, DPO 没有这一步)
    # ================================================================
    prompts_batch = sample_prompts(prompts, batch_size=8)

    (prompt_ids, response_ids, rewards,
     old_log_probs, ref_log_probs) = collect_group_rollouts(
        actor, prompts_batch, G, reward_fn
    )
    # old_log_probs, ref_log_probs: (B×G, L_r) token 级

    # ================================================================
    # 阶段 2: 计算组内相对优势 Â_i
    # ================================================================
    advantages = compute_group_advantages(rewards, G).to(actor.device)
    # advantages: (B×G,)

    # ================================================================
    # 阶段 3: 多 epoch 更新
    # 同一批采样数据复用 K 轮, 靠 PPO 裁剪防止更新过大
    # ================================================================
    for epoch in range(K_epochs):
        for idx in minibatch_indices(len(response_ids), batch_size=16):

            # --- 3a: 重新计算当前 π_θ 的 token 级 log 概率 ---
            # π_θ 每个 minibatch 后都在更新, 必须用最新 θ 重算
            new_log_probs = compute_token_log_probs(
                actor, prompt_ids[idx], response_ids[idx]
            )  # (M, L_r): 每个 token 的 log π_θ

            # response 中的 padding token 不应参与损失计算
            response_mask = (response_ids[idx] != tokenizer.pad_token_id)
            # response_mask: (M, L_r), True 为真实 token

            # --- 3b: token 级重要性采样比率 ---
            # 在 token 级别算比率, 避免序列级乘积导致数值爆炸
            log_ratio = new_log_probs - old_log_probs[idx]  # (M, L_r)
            ratio = torch.exp(log_ratio)                     # (M, L_r)

            # --- 3c: PPO 裁剪 (token 级) ---
            # 优势是序列级标量, 广播到每个 token
            adv = advantages[idx].unsqueeze(-1)  # (M,) → (M, 1)
            surr1 = ratio * adv                              # (M, L_r)
            surr2 = torch.clamp(
                ratio, 1.0 - clip_range, 1.0 + clip_range
            ) * adv                                          # (M, L_r)
            # 只对真实 token 求均值, 忽略 padding
            clipped_obj = torch.min(surr1, surr2) * response_mask
            policy_loss = -clipped_obj.sum() / response_mask.sum()

            # --- 3d: KL 散度惩罚 (token 级 f-散度) ---
            log_u = ref_log_probs[idx] - new_log_probs  # log(π_ref/π_θ)
            kl_per_token = torch.exp(log_u) - log_u - 1.0
            kl_penalty = (kl_per_token * response_mask).sum() / response_mask.sum()

            # --- 3e: 总损失 ---
            loss = policy_loss + beta * kl_penalty

            # --- 3f: 梯度更新 ---
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0)
            optimizer.step()

与 DPO 训练循环的关键对比

  • DPO 只有"前向传播 + 损失计算 + 反向传播",与标准 SFT 训练几乎一样。
  • GRPO 多了"在线采样"阶段(调用 actor.generate()),这是计算开销的主要来源,但也是在线 RL 探索能力的来源。
  • DPO 的 batch 是固定的偏好对;GRPO 的 batch 是模型自己实时生成的,每步训练都能看到新的探索结果。

GRPO 与 PPO / DPO 的全景对比

维度 PPO (RLHF) DPO GRPO
模型数量 4 (Actor+Critic+Ref+RM) 2 (Actor+Ref) 2 (Actor+Ref) + 外部奖励函数
训练方式 在线 RL 离线监督学习 在线 RL
基线估计 Critic 网络 \(V_\phi(s)\) 无需基线 组内经验均值 \(\mu_R\)
显存开销 极高 (4 个大模型) 低 (2 个大模型) 低 (2 个大模型)
计算开销 中等 (每 Prompt 采样 1 次) 最低 (纯前向传播) 较高 (每 Prompt 采样 G 次)
探索能力 弱 (离线数据)
核心优势 经典稳定 极简高效 省显存 + 在线探索

开源代码参考: GRPO 随 DeepSeek 开源而爆火,Hugging Face TRL 库 (trl.GRPOTrainer) 提供了生产级实现。

GRPO 证明了在生成式大模型时代,简单的经验统计(组内均值)往往比复杂的神经网络预测(Critic)更加鲁棒和高效。

参考资料:

  1. Shao, Z., Wang, P., Zhu, Q., Hao, K., Bugliarello, B., ... & Liu, Y. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. arXiv:2402.03300.
  2. Ahmadian, A., Cremer, C., Gallé, M., Fadaee, S., ... & Vulić, A. (2024). Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs. arXiv:2402.14740.

下一篇:笔记|强化学习(五):Flow-GRPO 与图像生成应用(基于 Flux 的代码解析)