NanoTransformer Courses

🏠 Back to Lab 📚 All Courses

Ch1: 教 Transformer 学会两位数乘法

从九九乘法表到任意两位数乘法——一个 4.76M 参数小模型的 100% 泛化之路

🎯 目标

在上一章中,我们训练了一个 Nano Transformer 掌握了九九乘法表(0-9 × 0-9 = 100 种组合)。 那只是记忆——100 个答案,死记硬背就行。

现在我们把范围扩大到 1-99 × 1-99 = 9801 种组合

直接背?一个 4.76M 参数的小模型根本背不下来。我们必须让它学会算法,而不是记答案。

目标: 用 30% 的数据(3000条)训练,在全部 9801 个组合上达到 100% 准确率

💡 核心思想:Chain-of-Thought(思维链)

人类怎么心算 45 × 67?不是直接蹦出 3015,而是拆步骤:

45 × 67
= 45 × 60 + 45 × 7       ← 拆解
= 2700 + 315              ← 分别乘
= 3015                    ← 加起来

我们教 Transformer 做同样的事。输入 45*67=,期望它输出:

S1:45*60=2700;S2:45*7=315;A1:2700+315=5103;Z3015

其中: - S1, S2 = 乘法子步骤(Step) - A1 = 加法步骤(Add) - Z = 最终答案(Zero-in / 终)

你可能注意到 A1 后面写的是 5103 而不是 3015——这是我们的秘密武器,后面会讲。

📐 模型架构

参数 说明
d_model 256 嵌入维度
n_heads 8 注意力头数
n_layers 6 Decoder 层数
d_ff 1024 FFN 隐藏层
vocab_size 20 0-9 + 运算符 + 特殊token
总参数 4.76M 比 GPT-2 小 26000 倍

标准 Transformer Decoder,没有任何花哨的改动。

🚀 快速开始

# 1. 生成数据
python generate_data.py --num_train 3000 --num_test 500 --num_add_practice 3000 --output_dir data_rev3

# 2. 训练(约10分钟,GPU)
python train.py \
  --train_file data_rev3/train.txt \
  --test_file data_rev3/test.txt \
  --d_model 256 --n_heads 8 --n_layers 6 --d_ff 1024 \
  --epochs 300 --batch_size 64 --lr 3e-4 \
  --ckpt_dir checkpoints_rev3

# 3. 全量评估(9801个组合)
python eval_full.py --checkpoint checkpoints_rev3/latest.pt

预期结果:

🏆 最终结果: 9801/9801 = 100.00%
🎉🎉🎉 完美! 全部 9801 个组合 100% 正确!

🔬 探索过程:从 87% 到 100% 的调试之旅

以下记录了我们实际的调试过程。每一次失败都揭示了 Transformer 学习算术的一个关键洞察。

第一版:朴素 CoT(准确率 87%)

最初的方案很直接——生成 3000 条 CoT 训练数据,加法结果正常书写:

45*67=S1:45*60=2700;S2:45*7=315;A1:2700+315=3015;Z3015

训练 300 epoch 后,500 样本评估达到 87%。看起来不错?但看看错误:

❌ 25*97: A1:2250+175=2325  (正确: 2425)
❌ 35*58: A1:1750+280=1030  (正确: 2030)
❌ 84*23: A1:1680+252=2932  (正确: 1932)

乘法步骤全对,加法步骤全错。 模型学会了拆解和乘法,却算不对加法。

第二版:加入加法练习数据(87% → 97%)

既然加法是瓶颈,就单独练加法。在训练数据中混入独立的加法练习:

2700+315=3015     ← 加法练习
1750+280=2030     ← 加法练习
45*67=S1:...      ← 乘法CoT
加法练习数量 500样本准确率
0 87%
500 94%
1200 95%
3000 97.2%
5000 96.6% ↓

加法数据有效!但超过 3000 条反而下降——加法数据太多挤占了乘法学习。

全量 9801 评估:95.82%,410 个错误。

分析这 410 个错误:

乘法步骤错误 (S1/S2):  30 个 (7%)
加法步骤错误 (A1):    377 个 (92%)  ← 依然是加法!
其他:                   3 个 (1%)

深入分析:加法到底错在哪?

典型错误模式:

❌ A1: 80+20=10     ← 应该是 100,漏了百位
❌ A1: 350+42=492   ← 应该是 392,百位多了 1
❌ A1: 640+64=604   ← 应该是 704,百位少了 1

核心矛盾:加法进位从右往左传播,但 Transformer 从左往右生成。

当模型生成 1750+280= 后面的数字时: - 要先写千位的 2 - 但这个 2 取决于百位 7+2=9 有没有进位 - 百位有没有进位又取决于十位 5+8=13 的进位 - 模型还没算到十位,就要决定千位写什么

这不是数据量的问题,是生成方向的结构性矛盾。

第三版:反转加法结果(关键突破!97% → 99.8%)

解决方案优雅而简单——把加法结果的数字反过来写

之前: A1:1750+280=2030     ← 从最高位开始,需要"预判"进位
现在: A1:1750+280=0302     ← 从最低位开始,进位自然向后传播

反转后,模型的生成顺序变成: 1. 先写个位: 0+0=0 → 写 0 2. 再写十位: 5+8=13 → 写 3,记住进位 1 3. 再写百位: 7+2+1=10 → 写 0,记住进位 1 4. 最后千位: 1+0+1=2 → 写 2

每一步只需要当前列的数字加上前一步的进位,不需要预判!

Z(最终答案)保持正序:Z2030。模型只需要学会把反转的结果翻转回来。

全量 9801 评估:99.80%,只有 20 个错误。 而且——

乘法步骤错误 (S1/S2): 19 个
加法步骤错误 (A1):     0 个  ← 加法错误归零!

反转彻底解决了加法问题。 剩下的 19 个错误全是乘法事实错误。

分析最后的 19 个乘法错误

89*7=623, 模型输出 603    ← 乘法事实记错了
89*17: S2:89*7=603 → 错   ← 同一个错误的级联
89*27: S2:89*7=603 → 错
89*37: S2:89*7=603 → 错
... (共 19 个全是 89*7 的级联)

去训练数据里查——89 出现了 21 次,子步骤覆盖了 89×1,2,4,5,6,8,9,唯独没有 89×7 和 89×3。

3000 条随机数据,每个 a 平均出现 30 次,但受随机性影响,某些 a×d 组合可能完全缺失。

第四版:加入乘法基本功(99.8% → 100%!)

和加法练习同理——加入 891 条乘法基本功(99 × 9 = 891 种 a×d 组合):

89*7=S1:89*7=623;Z623       ← 乘法基本功(CoT 格式)
45*3=S1:45*3=135;Z135       ← 乘法基本功
45*67=S1:45*60=2700;S2:...  ← 完整乘法 CoT

⚠️ 踩坑记录: 最初乘法基本功用的是简单格式 89*7=623,结果和 CoT 格式 89*7=S1:89*7=623;Z623 冲突——模型不知道该用哪种格式,准确率反而从 99% 降到 90%。统一用 CoT 格式后解决。

最终全量评估:

🏆 最终结果: 9801/9801 = 100.00%
   错误数: 0
🎉🎉🎉 完美! 全部 9801 个组合 100% 正确!

🏗️ 训练数据的三层结构

最终方案的训练数据由三部分组成,各司其职:

┌─────────────────────────────────────────────┐
│          完整乘法 CoT × 3000 条              │  ← 学"怎么拆、怎么串"
│  45*67=S1:45*60=2700;S2:45*7=315;           │
│         A1:2700+315=5103;Z3015              │
├─────────────────────────────────────────────┤
│          乘法基本功 × 891 条                 │  ← 学"怎么乘" (原子事实)
│  89*7=S1:89*7=623;Z623                      │
│  覆盖全部 a×d (a=1-99, d=1-9)              │
├─────────────────────────────────────────────┤
│          加法练习 × 3000 条                  │  ← 学"怎么加" (进位规则)
│  2700+315=5103     (反转结果)               │
│  均衡覆盖0/1/2次进位 + 位数扩展            │
└─────────────────────────────────────────────┘
              总计: 6840 条

为什么这样设计?

模型做 45*67 时,实际执行三个原子操作:

原子技能 数量 训练来源
拆解 b=67 → 60 和 7 固定规则 3000 条 CoT 学格式
乘法: 45×60=2700, 45×7=315 891 种 乘法基本功全覆盖
加法: 2700+315=3015 通用规则 3000 条加法练习

3000 条 CoT 教"组装流程",891+3000 条练习教"原子技能"。

这就是为什么 30% 的数据能泛化到 100%——模型不是在背 9801 个答案,而是学会了 3 个可组合的技能。


🔑 关键洞察总结

1. CoT 是泛化的基础

没有 CoT,模型必须记忆 9801 个映射关系。有了 CoT,模型只需要学会 ~891 个乘法事实 + 加法规则 + 组装格式。

2. 生成方向必须匹配计算方向

加法进位从右到左,但 Transformer 从左到右生成。反转加法结果让两个方向一致,加法错误从 377 个直降到 0。

3. 原子事实必须全覆盖

3000 条随机数据覆盖 30% 的乘法组合,但某些 a×d 原子事实可能完全缺失。加入 891 条乘法基本功保证每个原子事实至少出现一次。

4. 格式一致性很重要

乘法练习用 89*7=623 和 CoT 用 89*7=S1:89*7=623;Z623 会冲突——同一个输入有两种输出格式,模型不知道选哪个。所有数据统一用 CoT 格式。

5. 数据配比影响效果

加法练习太少(500条)不够,太多(5000条)反而挤占乘法学习。最佳配比约 1:1(乘法 3000 : 加法 3000)。


📊 实验全记录

版本 训练数据 500样本 全量9801 关键变化
v1 朴素 CoT 3000 乘法 87% - 基线
v2 +加法500 3000+500 94% - 加法练习有效
v3 +加法3000 3000+3000 97.2% 95.8% 最佳加法量
v4 +加法5000 3000+5000 96.6% - 过多反而下降
v5 反转 3000+3000(反转) 99.8% 99.8% 加法错误→0
v6 +乘法基本功 3000+891+3000 100% 100% 乘法错误→0

📁 文件说明

ch1_cot_basics/
├── generate_data.py    # 数据生成器(CoT + 加法练习 + 乘法基本功)
├── train.py            # 训练脚本(Teacher Forcing + 自回归评估)
├── eval_full.py        # 全量 9801 评估脚本(含错误分析)
├── infer.py            # 推理脚本(交互式测试)
├── data_rev3/          # 最终训练数据
│   ├── train.txt       # 6840 条训练数据
│   ├── test.txt        # 500 条测试数据
│   └── train_pairs.txt # 训练用的 (a,b) 对
└── checkpoints_rev3/   # 模型检查点
    ├── best.pt         # 最佳模型
    └── latest.pt       # 最终模型

🤔 思考题

  1. 为什么 1000 条 CoT 不够? 我们试过只用 1000 条 CoT + 891 乘法基本功,准确率卡在 64%。CoT 不仅教原子事实,还教"组装规则"——如何拆解 b、如何连接 S→A→Z。1000 条不够学会足够多的组装模式。

  2. 反转加法结果后,Z 答案为什么还是正序? Z 是最终输出给用户的答案,保持正序可读性好。模型需要学会"把反转的数字翻转回来"——这对 Transformer 来说是一个简单的 copy+reverse 操作,比"预判进位"容易得多。

  3. 这个方法能扩展到三位数乘法吗? 可以,但需要解决连加问题——A1 的反转结果要作为 A2 的输入。这是 Ch2 要探索的内容。


📖 下一章预告

Ch2: RoPE 位置编码 — 当序列变长(三位数乘法的 CoT 可达 100+ token),位置编码成为关键。我们将探索旋转位置编码如何帮助模型处理更长的推理链。