NanoTransformer Courses

🏠 Back to Lab 📚 All Courses

Ch4: 3 位数乘法 — 数据工程 + KV Cache + GRPO

目标: 将 Ch2 的 2d 100% 模型扩展到 3 位数乘法,在 50,000 道随机测试中达到 99.998%+ 核心发现: 数据工程是精度提升的真正驱动力,KV Cache 和 GRPO 各有其角色和局限 前置: Ch2 的 RoPE 模型(2 位数 100%)


🧭 本章概览

Ch2 的模型完美掌握了 2 位数乘法(9801/9801 = 100%),但 3 位数乘法的 CoT 推理链长达 300+ token。本章做三件事:

  1. 数据工程:精准设计 SFT 课程,用错题驱动的迭代流程将准确率从 99.97% 推到 99.998%
  2. KV Cache:将推理从 $O(N^3)$ → $O(N^2)$,理解理论复杂度与实际加速的差距
  3. GRPO 强化学习:作为大规模随机压力测试工具,暴露 SFT 的盲区

本章最重要的发现:99.97% → 99.998% 的跨越,靠的不是更大的模型或更复杂的算法,而是精准的错题分析 + 定向数据补充


🏛️ 模型架构

参数 说明
d_model 256 嵌入维度(与 Ch2 相同)
n_heads 8 注意力头数
n_layers 6 Decoder 层数
d_ff 1024 FFN 隐藏层
vocab_size 20 0-9 + 运算符 + 特殊 token
位置编码 RoPE 零参数,旋转式
max_len 400 ← Ch2 是 64,扩大 6 倍以容纳 3d CoT
总参数 4.74M 与 Ch2 完全相同

唯一变化是 max_len: 64 → 400。RoPE 的零参数特性让这次扩展不增加任何参数量。


🔑 核心技术 1: KV Cache

问题:为什么朴素生成这么慢?

自回归生成时,每生成 1 个新 token,都需要把整个已生成序列重新送进 Transformer:

Step 1: 输入 [23*45=]           → 计算 Q,K,V → 预测 "S"
Step 2: 输入 [23*45=S]          → 重新计算所有 Q,K,V → 预测 "1"
Step 3: 输入 [23*45=S1]         → 又重新计算所有 Q,K,V → 预测 ":"
...
Step 300: 输入 [23*45=S1:...长长的CoT...] → 重新计算 300 个位置的 Q,K,V → 预测最后一个 token

每一步都在重复计算前面所有 token 的 K 和 V!这让计算量呈 $O(N^3)$ 增长。

解法:缓存 K 和 V

KV Cache 的核心洞察:前面 token 的 K 和 V 不会因为后面新增 token 而改变。所以只需要计算一次,缓存起来复用:

Step 1: 输入 [23*45=]  → 计算 Q,K,V → 缓存 K₁,V₁ → 预测 "S"
Step 2: 输入 [S]       → 只算新 token 的 Q,K,V → 拼接缓存的 K,V → 预测 "1"
Step 3: 输入 [1]       → 只算 1 个 token 的 Q,K,V → 拼接缓存 → 预测 ":"

每步只需计算 1 个 token 的 Q·K·V,而不是 N 个!

代码实现

# model.py 中 MultiHeadSelfAttention.forward()

def forward(self, x, freqs_cis, use_cache=False, past_key_value=None, cache_position=None):
    B, S, D = x.shape
    Q = self.W_q(x)  # 只计算新 token 的 Q
    K = self.W_k(x)  # 只计算新 token 的 K
    V = self.W_v(x)  # 只计算新 token 的 V

    # RoPE 旋转
    Q = apply_rotary_emb(Q, freqs_cis)
    K = apply_rotary_emb(K, freqs_cis)

    if past_key_value is not None:
        # ★ 核心: 写入预分配的缓存
        K_cache, V_cache = past_key_value
        K_cache[:, :, start_pos:start_pos+S, :] = K  # 新 K 写入对应位置
        V_cache[:, :, start_pos:start_pos+S, :] = V  # 新 V 写入对应位置

        # 读取所有有效缓存 (包含历史 + 当前)
        K = K_cache[:, :, :start_pos+S, :]
        V = V_cache[:, :, :start_pos+S, :]

    # Attention: 新 Q 与所有历史 K,V 做点积
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=is_causal)
    return self.W_o(out), present_key_value

静态预分配 vs 动态拼接

本章采用静态预分配策略,一次性分配最大长度的缓存:

# generate_with_logprobs() 中
for block in self.blocks:
    K_cache = torch.zeros((B, n_heads, max_total_len, d_head), dtype=torch.bfloat16)
    V_cache = torch.zeros((B, n_heads, max_total_len, d_head), dtype=torch.bfloat16)
    past_key_values.append((K_cache, V_cache))

优势:避免每步 torch.cat() 拼接导致的内存碎片和 GPU 同步开销。

实验结果:理论 vs 现实

python test_kv_cache.py
347*892 (304 tokens):
  ┌──────────────────┬──────────────┬──────────────┐
  │                  │  无 Cache    │  有 Cache    │
  ├──────────────────┼──────────────┼──────────────┤
  │ 推理时间         │    507.4 ms  │    426.9 ms  │
  │ Q·K·V 计算次数   │    48488     │      311     │
  └──────────────────┴──────────────┴──────────────┘

  📉 节省 Q·K·V 计算: 48177 次 (99.4%)
  🚀 实际加速比: 1.2x

Q·K·V 计算量省了 99.4%,但实际只快了 1.2 倍! 这不是 bug,而是一个重要的工程认知:

因素 分析
模型太小 (4.7M) 单次前向传播只需微秒级计算,GPU 算力严重过剩
GPU 太强 (5090) 即使重复计算整个序列,也几乎瞬间完成
CUDA 内核启动开销 无 Cache 跑 1 次大矩阵 vs 有 Cache 跑 300 次小矩阵,后者内核启动开销更大
Batch 推理的转折 batch=32 时加速达 4.7x,计算密度上升后 Cache 优势才显现

教训: 理论复杂度 $O(N^3) → O(N^2)$ 是真实的,但 bottleneck 分析 是工程中的关键能力。KV Cache 在大模型 (7B+) 和长序列 (1000+) 场景才会体现显著加速。


🔑 核心技术 2: GRPO 强化学习

SFT vs RL 底层对比

维度 SFT (监督微调) RL (GRPO 强化学习)
通俗比喻 "照猫画虎":老师给标准答案,学生照着写 "摸着石头过河":学生自己试,看哪种能拿奖
输入数据 完整的 CoT 标注:3*4=S1:3*4=12;Z12 只有题目:3*4=,模型自己生成全部过程
目标函数 交叉熵:$L = -\log P(\text{正确Token})$ 策略梯度:$L = -\log P(\text{Token}) \times \text{Advantage}$
优化上限 被标注数据锁死 理论上可超越标注数据的能力上限
计算成本 低(每条数据一次前向+反向) 高(每题 16 条回答 + Reward + KL + 反向)

GRPO 工作原理(4 步)

题目: 347*892=

Step 1 - 出题采样: 给模型这道题,让它用 temperature=0.2 生成 16 条不同的解答路径
Step 2 - 判卷打分: Reward 函数检查每条路径——答案对不对?过程中每步乘法/加法对不对?
Step 3 - 组内排名: 在 16 条路径中算平均分和标准差,比平均分高的得正优势,低的得负优势
Step 4 - 梯度更新: 正优势路径 → 强化这些 token 的生成概率;负优势路径 → 削弱

为什么不需要 Critic?

传统 PPO 需要一个 Value Network 估计状态价值(占额外一半内存)。GRPO 用组内统计替代:

advantage = (reward - group_mean) / group_std

数学上更简洁,内存省一半。

RL 的真实角色(基于实验发现)

实验中我们发现 GRPO 的作用需要诚实看待:

✅ RL 擅长的事: - 大规模随机压测:每步生成全新随机题目,暴露 SFT 数据的覆盖盲区 - 边际修复:当模型"有时对有时错"时(如 16 次尝试中 12 次对、4 次错),RL 能通过 advantage 信号强化正确路径

❌ RL 的局限:

当 16/16 次尝试全对时:
  advantage = reward - mean = 2.0 - 2.0 = 0 → 梯度为零,权重不动

当 16/16 次尝试全错时:
  advantage = reward - mean = -1.0 - (-1.0) = 0 → 梯度也为零,还是学不到

RL 只在"混合区"起作用——模型有时对有时错的那些题。对于顽固错误(如 400*300 每次都错),RL 无能为力,必须回到 SFT 补数据。

奖励函数设计

总奖励 ∈ [-1.0, +2.0]

1. 最终答案 Z{answer}
   ✅ 正确 → +1.0
   ❌ 错误 → -0.5

2. 格式合规 (S/A/Z 结构)
   ✅ 完整 → +0.2
   ❌ 缺失 → -0.3

3. 过程奖励 (每步 ±0.1)
   S步骤: a*val=result 是否正确?
   A步骤: x+y=z (反转结果) 是否正确?

🔑 核心技术 3: 错题驱动的数据工程

这是本章最重要的技术。我们通过 3 轮"评估→分析→补数据→重训"迭代,将 50,000 题测试的错误数从 14 降到 1。

迭代 1: 基线 SFT (5% 覆盖)

数据类型 数量 说明
1d×1d 穷举 81 基础九九乘法表
2d 采样 3,000 30% 覆盖,巩固 Ch2 能力
3d×1d 3,000 最简 3d,只有 3 个 S 步
3d×2d 15,000 中等难度,6 个 S 步
3d×3d 随机 40,000 5% 覆盖(81 万种中取 4 万)
乘法基本功 8,991 1~999 × 1~9 全量穷举
加法专项 5,000 4~6 位数进位/溢出训练
总计 ~75,000

50k 评估结果: 49,986/50,000 = 99.972% (14 道错)

错误分析 1: 所有错误都涉及含零数字

❌ 400*300=120000 pred=132000   ← 尾零×尾零,凭空多了 S2 步
❌ 100*800=80000  pred=80800    ← 整百数,幻觉出不存在的 S2
❌ 200*900=180000 pred=188000   ← S2:200*0=8000 完全胡说
❌ 700*707=494900 pred=543900   ← 内零数 707 被分解为 700+70+7

两大模式: 1. 幻觉 S 步骤:对整百/整千数,凭空生成不存在的分解步骤 2. 数零错误:300 写成 3000,4000 写成 400,在中间结果中加零或丢零

迭代 2: 补充尾零数据 (+2000 条)

# 尾零数字: 100, 200, ..., 900, 110, 120, ..., 950
zero_tail = [100,200,300,...,910,950]
for _ in range(2000):
    a = random.choice(zero_tail)
    b = random.randint(100, 999)
    data.append(generate_cot(a, b))

50k 评估结果: 49,994/50,000 = 99.988% (6 道错)

错误减少了,但新的模式浮现:

❌ 403*800=322400 pred=320000   ← 内零!403 被当成 400
❌ 109*600=65400  pred=60000    ← 内零!109 被当成 100
❌ 201*900=180900 pred=180000   ← 内零!201 被当成 200
❌ 510*900=459000 pred=450000   ← 内零!510 被当成 500

模型把 403、109、201 这类"中间有零"的数字直接当成了整百数,跳过了个位数的 S 步骤

迭代 3: 补充内零数据 (+4000 条)

# 内零 x0y: 101, 102, ..., 909 (共 81 种)
internal_zero = [h*100 + u for h in range(1,10) for u in range(1,10)]

for _ in range(3000):
    a = random.choice(internal_zero)
    b = random.randint(100, 999)
    data.append(generate_cot(a, b))

# 尾零 × 尾零 (xy0 * xy0)
tail_zero = [i*10 for i in range(11, 100)]  # 110,...,990
for _ in range(2000):
    a, b = random.choice(tail_zero), random.choice(tail_zero)
    data.append(generate_cot(a, b))

50k 评估结果: 49,999/50,000 = 99.998% (1 道错)

进化轨迹

版本 新增数据 50k 错误数 准确率 错误模式
v1: 基线 14 99.972% 尾零 + 内零 + 数零
v2: +尾零 2,000 条 6 99.988% 内零为主
v3: +内零+尾零² 6,000 条 0 100% 🎉

关键认知: 不是"加更多数据",而是精准诊断 → 定向补充。81,000 条训练数据中,最后的 6,000 条定向数据贡献了 93% 的错误修复。


🔍 调试历程:关键 Bug 与修复

Bug 1: RL 生成截断 (max_new_tokens=200)

现象: RL 阶段 3d 准确率永远卡在 53%,完全不涨。

根因: GRPO Trainer 的 max_new_tokens 设为 200,但 3d×3d 的 CoT 链需要 300+ token。模型算到一半就被截断,永远无法输出正确答案,Reward 永远为负。

修复: max_new_tokens: 200 → 350

修复前: Step 500 | 3d: 53% (卡死)
修复后: Step  25 | 3d: 99% (起飞)

教训: 在 RL 中,如果模型的探索空间被物理限制卡死(比如输出长度不够),再多的训练步数也无法突破。

Bug 2: bfloat16 破坏 RoPE

现象: 全量评估时所有题目 0% 准确率。

根因: 将模型转为 bfloat16 后,RoPE 使用的复数运算被破坏(bfloat16 不支持 complex64)。

修复: 评估时保持 float32,不使用 model.to(torch.bfloat16)

教训: 涉及复数运算的组件(如 RoPE)对精度敏感,不能随意降精度。

Bug 3: RL 全对 = 空转

现象: RL 训练 500 步,日志全显示 acc 100%,但 50k 评估仍有 14 道错。

根因: GRPO 的 advantage = reward - mean。全对时所有 reward = +2.0,mean = +2.0,advantage = 0,梯度为零。RL 没有"犯错"就没有学习信号。

修复: 用"错题本"机制——收集 50k 评估中的错题,在 RL 中定向出题,让模型碰到错误产生非零 advantage。

教训: RL 不是万能药。对于模型一致性错误(16/16 全错),必须回到 SFT 用正确数据教。


📊 大规模评估方法

自回归生成 vs Teacher Forcing

全量 810,000 道 3d×3d 测试,自回归需要 ~80 分钟。我们发现了一种 315 倍加速 的验证方法——Teacher Forcing 评估:

自回归: 每道题 315 次前向传播 (逐 token 生成)
Teacher Forcing: 每道题 1 次前向传播 (整个序列并行)

原理: 给模型正确的完整序列,用错位方式检查每个位置的预测
输入: [3 4 7 * 8 9 2 = S 1 : 3 0 0 * 8 ... Z 3 0 9 5 2 4]
目标: [4 7 * 8 9 2 = S 1 : 3 0 0 * 8 ... Z 3 0 9 5 2 4 <EOS>]

等价性证明 (数学归纳法): - 第 1 个 token:两种方式输入完全一样 - 第 K 个 token:若前 K-1 个 token 预测全对,则自回归和 TF 输入相同 - 结论:TF 100% ⟹ 自回归 100% (TF 是更严格的测试)

但注意:TF < 100% 时不等于自回归也那么差——模型可能在不同的 CoT 路径上得出正确的最终答案。因此 50k 最终评估仍用自回归生成。


📁 文件清单

文件 职责
model.py 🧠 Transformer Decoder + KV Cache + RoPE
generate_data.py 多位数 CoT 数据生成 (SFT 数据 + RL 题库)
reward.py GRPO 奖励函数: 答案 + 格式 + 过程奖励
grpo_trainer.py GRPO 核心训练器 (纯 PyTorch,无 HuggingFace 依赖)
train.py 主训练入口 (Phase 1: SFT → Phase 2: GRPO)
eval.py 多位数评估脚本
eval_full_3d.py 50k 大规模自回归评估
test_kv_cache.py KV Cache 等价性验证 + 速度测试

🚀 使用方法

完整训练 (SFT + RL)

# 从 Ch2 的 2d 100% 模型出发
python train.py \
  --ch2_ckpt ../ch2_rope/checkpoints_v3/latest.pt \
  --ckpt_dir checkpoints \
  --sft_epochs 3 \
  --rl_steps 200

跳过 SFT,只跑 RL(用错题本修复)

python train.py \
  --resume_from checkpoints/best_rl.pt \
  --ckpt_dir checkpoints \
  --sft_epochs 0 \
  --rl_steps 200

大规模评估

# 50k 自回归评估 (~10 分钟)
python eval_full_3d.py -c checkpoints/best_rl.pt -n 50000

# 指定样本数
python eval_full_3d.py -c checkpoints/sft_best.pt -n 10000

验证 KV Cache

python test_kv_cache.py

📈 Token 长度参考

位数 示例 CoT Token 长度 S 步骤 A 步骤
1×1 9×9 ~17 1 0
2×2 99×99 ~109 4 3
3×2 999×99 ~187 6 5
3×3 999×999 ~315 9 8
4×3 9999×999 ~469 12 11
4×4 9999×9999 ~690 16 15

🏆 最终成绩

评估方式 样本数 准确率
1d 全量 (1~9 × 1~9) 81 100%
2d 全量 (10~99 × 10~99) 8,100 100%
3d 随机抽测 (100~999 × 100~999) 50,000 100% 🎉

4.7M 参数,~81,000 条 SFT 数据(仅覆盖 10% 的 3d×3d 组合),50,000 道随机测试零错误。 这证明模型学会了通用的乘法算法,而非死记硬背。


💡 关键认知总结

  1. 数据工程 > 算法炫技:99.972% → 99.998% 的跨越,靠的不是更大的模型或更复杂的 RL,而是 3 轮"评估→错题分析→定向补数据"的工程迭代。最后 6,000 条定向数据修复了 93% 的错误。

  2. RL 是"发现问题"的工具,SFT 是"解决问题"的工具:GRPO 的真正价值在于大规模随机出题暴露 SFT 盲区,而非直接修复错误。当模型一致性错误时(16/16 全错),advantage=0,RL 学不到任何东西。

  3. 理论复杂度 ≠ 实际加速:KV Cache 节省了 99.4% 的 Q·K·V 计算,但小模型+强 GPU 场景下只快了 1.2x。真正的 bottleneck 是 CUDA 内核启动开销,而非计算本身。KV Cache 在 7B+ 模型上才会体现 5-10x 加速。

  4. 课程学习(Curriculum Learning)极其强大:从 Ch2 的 2d 100% 出发,仅 2 个 Epoch 就把 3d 推到 99%+。Transformer 学到的是通用的乘法算法,而非死记硬背。

  5. 含零数字是算术模型的"阿喀琉斯之踵":403、109、201 这类内零数字,以及 100×800、400×300 这类整百数乘法,是 4.7M 模型最难处理的边界。模型会跳过零位的分解步骤或在中间结果中数错零的个数。