从九九乘法表到任意两位数乘法——一个 4.76M 参数小模型的 100% 泛化之路
在上一章中,我们训练了一个 Nano Transformer 掌握了九九乘法表(0-9 × 0-9 = 100 种组合)。 那只是记忆——100 个答案,死记硬背就行。
现在我们把范围扩大到 1-99 × 1-99 = 9801 种组合。
直接背?一个 4.76M 参数的小模型根本背不下来。我们必须让它学会算法,而不是记答案。
目标: 用 30% 的数据(3000条)训练,在全部 9801 个组合上达到 100% 准确率
人类怎么心算 45 × 67?不是直接蹦出 3015,而是拆步骤:
45 × 67
= 45 × 60 + 45 × 7 ← 拆解
= 2700 + 315 ← 分别乘
= 3015 ← 加起来
我们教 Transformer 做同样的事。输入 45*67=,期望它输出:
S1:45*60=2700;S2:45*7=315;A1:2700+315=5103;Z3015
其中: - S1, S2 = 乘法子步骤(Step) - A1 = 加法步骤(Add) - Z = 最终答案(Zero-in / 终)
你可能注意到 A1 后面写的是
5103而不是3015——这是我们的秘密武器,后面会讲。
| 参数 | 值 | 说明 |
|---|---|---|
d_model |
256 | 嵌入维度 |
n_heads |
8 | 注意力头数 |
n_layers |
6 | Decoder 层数 |
d_ff |
1024 | FFN 隐藏层 |
vocab_size |
20 | 0-9 + 运算符 + 特殊token |
| 总参数 | 4.76M | 比 GPT-2 小 26000 倍 |
标准 Transformer Decoder,没有任何花哨的改动。
# 1. 生成数据
python generate_data.py --num_train 3000 --num_test 500 --num_add_practice 3000 --output_dir data_rev3
# 2. 训练(约10分钟,GPU)
python train.py \
--train_file data_rev3/train.txt \
--test_file data_rev3/test.txt \
--d_model 256 --n_heads 8 --n_layers 6 --d_ff 1024 \
--epochs 300 --batch_size 64 --lr 3e-4 \
--ckpt_dir checkpoints_rev3
# 3. 全量评估(9801个组合)
python eval_full.py --checkpoint checkpoints_rev3/latest.pt
预期结果:
🏆 最终结果: 9801/9801 = 100.00%
🎉🎉🎉 完美! 全部 9801 个组合 100% 正确!
以下记录了我们实际的调试过程。每一次失败都揭示了 Transformer 学习算术的一个关键洞察。
最初的方案很直接——生成 3000 条 CoT 训练数据,加法结果正常书写:
45*67=S1:45*60=2700;S2:45*7=315;A1:2700+315=3015;Z3015
训练 300 epoch 后,500 样本评估达到 87%。看起来不错?但看看错误:
❌ 25*97: A1:2250+175=2325 (正确: 2425)
❌ 35*58: A1:1750+280=1030 (正确: 2030)
❌ 84*23: A1:1680+252=2932 (正确: 1932)
乘法步骤全对,加法步骤全错。 模型学会了拆解和乘法,却算不对加法。
既然加法是瓶颈,就单独练加法。在训练数据中混入独立的加法练习:
2700+315=3015 ← 加法练习
1750+280=2030 ← 加法练习
45*67=S1:... ← 乘法CoT
| 加法练习数量 | 500样本准确率 |
|---|---|
| 0 | 87% |
| 500 | 94% |
| 1200 | 95% |
| 3000 | 97.2% |
| 5000 | 96.6% ↓ |
加法数据有效!但超过 3000 条反而下降——加法数据太多挤占了乘法学习。
全量 9801 评估:95.82%,410 个错误。
分析这 410 个错误:
乘法步骤错误 (S1/S2): 30 个 (7%)
加法步骤错误 (A1): 377 个 (92%) ← 依然是加法!
其他: 3 个 (1%)
典型错误模式:
❌ A1: 80+20=10 ← 应该是 100,漏了百位
❌ A1: 350+42=492 ← 应该是 392,百位多了 1
❌ A1: 640+64=604 ← 应该是 704,百位少了 1
核心矛盾:加法进位从右往左传播,但 Transformer 从左往右生成。
当模型生成 1750+280= 后面的数字时:
- 要先写千位的 2
- 但这个 2 取决于百位 7+2=9 有没有进位
- 百位有没有进位又取决于十位 5+8=13 的进位
- 模型还没算到十位,就要决定千位写什么
这不是数据量的问题,是生成方向的结构性矛盾。
解决方案优雅而简单——把加法结果的数字反过来写:
之前: A1:1750+280=2030 ← 从最高位开始,需要"预判"进位
现在: A1:1750+280=0302 ← 从最低位开始,进位自然向后传播
反转后,模型的生成顺序变成:
1. 先写个位: 0+0=0 → 写 0
2. 再写十位: 5+8=13 → 写 3,记住进位 1
3. 再写百位: 7+2+1=10 → 写 0,记住进位 1
4. 最后千位: 1+0+1=2 → 写 2
每一步只需要当前列的数字加上前一步的进位,不需要预判!
Z(最终答案)保持正序:Z2030。模型只需要学会把反转的结果翻转回来。
全量 9801 评估:99.80%,只有 20 个错误。 而且——
乘法步骤错误 (S1/S2): 19 个
加法步骤错误 (A1): 0 个 ← 加法错误归零!
反转彻底解决了加法问题。 剩下的 19 个错误全是乘法事实错误。
89*7=623, 模型输出 603 ← 乘法事实记错了
89*17: S2:89*7=603 → 错 ← 同一个错误的级联
89*27: S2:89*7=603 → 错
89*37: S2:89*7=603 → 错
... (共 19 个全是 89*7 的级联)
去训练数据里查——89 出现了 21 次,子步骤覆盖了 89×1,2,4,5,6,8,9,唯独没有 89×7 和 89×3。
3000 条随机数据,每个 a 平均出现 30 次,但受随机性影响,某些 a×d 组合可能完全缺失。
和加法练习同理——加入 891 条乘法基本功(99 × 9 = 891 种 a×d 组合):
89*7=S1:89*7=623;Z623 ← 乘法基本功(CoT 格式)
45*3=S1:45*3=135;Z135 ← 乘法基本功
45*67=S1:45*60=2700;S2:... ← 完整乘法 CoT
⚠️ 踩坑记录: 最初乘法基本功用的是简单格式
89*7=623,结果和 CoT 格式89*7=S1:89*7=623;Z623冲突——模型不知道该用哪种格式,准确率反而从 99% 降到 90%。统一用 CoT 格式后解决。
最终全量评估:
🏆 最终结果: 9801/9801 = 100.00%
错误数: 0
🎉🎉🎉 完美! 全部 9801 个组合 100% 正确!
最终方案的训练数据由三部分组成,各司其职:
┌─────────────────────────────────────────────┐
│ 完整乘法 CoT × 3000 条 │ ← 学"怎么拆、怎么串"
│ 45*67=S1:45*60=2700;S2:45*7=315; │
│ A1:2700+315=5103;Z3015 │
├─────────────────────────────────────────────┤
│ 乘法基本功 × 891 条 │ ← 学"怎么乘" (原子事实)
│ 89*7=S1:89*7=623;Z623 │
│ 覆盖全部 a×d (a=1-99, d=1-9) │
├─────────────────────────────────────────────┤
│ 加法练习 × 3000 条 │ ← 学"怎么加" (进位规则)
│ 2700+315=5103 (反转结果) │
│ 均衡覆盖0/1/2次进位 + 位数扩展 │
└─────────────────────────────────────────────┘
总计: 6840 条
模型做 45*67 时,实际执行三个原子操作:
| 原子技能 | 数量 | 训练来源 |
|---|---|---|
| 拆解 b=67 → 60 和 7 | 固定规则 | 3000 条 CoT 学格式 |
| 乘法: 45×60=2700, 45×7=315 | 891 种 | 乘法基本功全覆盖 |
| 加法: 2700+315=3015 | 通用规则 | 3000 条加法练习 |
3000 条 CoT 教"组装流程",891+3000 条练习教"原子技能"。
这就是为什么 30% 的数据能泛化到 100%——模型不是在背 9801 个答案,而是学会了 3 个可组合的技能。
没有 CoT,模型必须记忆 9801 个映射关系。有了 CoT,模型只需要学会 ~891 个乘法事实 + 加法规则 + 组装格式。
加法进位从右到左,但 Transformer 从左到右生成。反转加法结果让两个方向一致,加法错误从 377 个直降到 0。
3000 条随机数据覆盖 30% 的乘法组合,但某些 a×d 原子事实可能完全缺失。加入 891 条乘法基本功保证每个原子事实至少出现一次。
乘法练习用 89*7=623 和 CoT 用 89*7=S1:89*7=623;Z623 会冲突——同一个输入有两种输出格式,模型不知道选哪个。所有数据统一用 CoT 格式。
加法练习太少(500条)不够,太多(5000条)反而挤占乘法学习。最佳配比约 1:1(乘法 3000 : 加法 3000)。
| 版本 | 训练数据 | 500样本 | 全量9801 | 关键变化 |
|---|---|---|---|---|
| v1 朴素 CoT | 3000 乘法 | 87% | - | 基线 |
| v2 +加法500 | 3000+500 | 94% | - | 加法练习有效 |
| v3 +加法3000 | 3000+3000 | 97.2% | 95.8% | 最佳加法量 |
| v4 +加法5000 | 3000+5000 | 96.6% | - | 过多反而下降 |
| v5 反转 | 3000+3000(反转) | 99.8% | 99.8% | 加法错误→0 |
| v6 +乘法基本功 | 3000+891+3000 | 100% | 100% | 乘法错误→0 |
ch1_cot_basics/
├── generate_data.py # 数据生成器(CoT + 加法练习 + 乘法基本功)
├── train.py # 训练脚本(Teacher Forcing + 自回归评估)
├── eval_full.py # 全量 9801 评估脚本(含错误分析)
├── infer.py # 推理脚本(交互式测试)
├── data_rev3/ # 最终训练数据
│ ├── train.txt # 6840 条训练数据
│ ├── test.txt # 500 条测试数据
│ └── train_pairs.txt # 训练用的 (a,b) 对
└── checkpoints_rev3/ # 模型检查点
├── best.pt # 最佳模型
└── latest.pt # 最终模型
为什么 1000 条 CoT 不够? 我们试过只用 1000 条 CoT + 891 乘法基本功,准确率卡在 64%。CoT 不仅教原子事实,还教"组装规则"——如何拆解 b、如何连接 S→A→Z。1000 条不够学会足够多的组装模式。
反转加法结果后,Z 答案为什么还是正序? Z 是最终输出给用户的答案,保持正序可读性好。模型需要学会"把反转的数字翻转回来"——这对 Transformer 来说是一个简单的 copy+reverse 操作,比"预判进位"容易得多。
这个方法能扩展到三位数乘法吗? 可以,但需要解决连加问题——A1 的反转结果要作为 A2 的输入。这是 Ch2 要探索的内容。
Ch2: RoPE 位置编码 — 当序列变长(三位数乘法的 CoT 可达 100+ token),位置编码成为关键。我们将探索旋转位置编码如何帮助模型处理更长的推理链。