目标: 从零实现现代前沿架构的 MLA + MoE,以及进阶的 CSA + HCA + mHC 核心发现: MLA 用更少内存反而学得更好;V4 的 CSA/HCA 为百万 token 设计,在短序列上反而拖慢收敛 前置: Ch5 的 4d 乘法 100%
Ch5 验证了标准 Transformer 可以 100% 学会 4 位数乘法。但标准架构有两个根本瓶颈:
现代前沿架构用两个核心创新解决了这两个问题: 1. MLA (Multi-head Latent Attention):把 K、V 压缩到低维 latent 空间再缓存 2. MoE (Mixture of Experts):多个专家 FFN 替代单一 FFN,每个 token 只激活 Top-K 个
回顾 Ch4,KV Cache 的原理是:推理时把每层算过的 K、V 存起来,下次直接复用,避免重复计算。
但存什么?标准 MHA 存的是完整的 K 和 V 矩阵:
标准 MHA 每层每 token 缓存量:
K: n_heads × d_head = 8 × 32 = 256 floats
V: n_heads × d_head = 8 × 32 = 256 floats
合计: 512 floats × 4 bytes = 2048 bytes
6 层 × 750 tokens (4d乘法):
= 6 × 750 × 2048 = 9.2 MB
看起来不大?但想象一下 GPT-4 级别:
128 层 × 128K tokens × 2048 bytes = 33 GB ← 仅 KV Cache 就爆显存!
MLA 的灵感来自一个简单的观察:K 和 V 的信息是高度冗余的。
8 个 head 各有 32 维的 K 和 V,总共 512 维。但这 512 维里大部分信息是重复的! 就像一张 1000×1000 的图片,虽然有 100 万个像素,但用 JPEG 压缩后可能只需要 10KB——因为相邻像素高度相关。
MLA 做的事情就是对 KV 做"JPEG 压缩":
# 标准 MHA: 直接投影出完整的 K, V,然后缓存
K = W_k(x) # [d_model] → [n_heads × d_head] = 256 维
V = W_v(x) # [d_model] → [n_heads × d_head] = 256 维
cache = (K, V) # 缓存 512 维 ← 太大了!
# MLA: 先压缩到低维 latent,只缓存 latent
c_kv = W_compress(x) # [d_model] → [d_c] = 64 维 ← 压缩 4x!
cache = c_kv # 只缓存 64 维 ← 省了很多!
# 需要 K, V 时,从 latent 临时"解压"出来
K = W_decompress_k(c_kv) # [d_c] → [n_heads × d_head]
V = W_decompress_v(c_kv) # [d_c] → [n_heads × d_head]
但这里有个问题:RoPE 位置编码需要对 K 做旋转。如果直接在 latent 上做 RoPE,位置信息就会"污染" latent,降低压缩效率(因为不同位置的 token 即使内容相同,latent 也不一样了)。
工业界 的解法很精妙:把 RoPE 信息独立出来。
# 从输入额外生成一小段专门做 RoPE 的 K
K_rope = W_k_rope(x) # [d_model] → [n_heads × d_rope] = 128 维
K_rope = apply_RoPE(K_rope) # 位置编码只在这里做
# 最终的 K = [K_content | K_rope](拼接)
# K_content 从无位置信息的 latent 解压得到
# K_rope 独立携带位置信息
这样,latent 保持"纯内容"不含位置信息,压缩效率最高。
| 标准 MHA | MLA | |
|---|---|---|
| 缓存内容 | 完整 K + V | latent + K_rope |
| 每 token 维度 | 512 | 64 + 128 = 192 |
| 压缩比 | 1x | 2.7x |
| 6层×750token | 9.2 MB | 3.5 MB |
标准 Transformer 的 FFN 层是一个巨大的全连接网络:
FFN: d_model → d_ff → d_model
256 → 1024 → 256
参数: 256×1024 + 1024×256 = 524,288
每个 token 都要过同一个 FFN。想要更强的模型?只能把 FFN 做大,但计算量也等比增大。
有没有办法"参数翻倍但计算量不变"?
把 1 个大 FFN 拆成 N 个小 FFN("专家"),每个 token 只激活其中 K 个:
# 标准 FFN: 1 个大专家,所有 token 都用
output = FFN(x) # 100% 参数都激活
# MoE FFN: 4 个小专家,每个 token 只选 2 个
scores = softmax(W_gate(x)) # 路由分数: [4]
top2 = topk(scores, k=2) # 选出分数最高的 2 个专家
output = score_1 × Expert_1(x) + score_2 × Expert_2(x)
| 标准 FFN | MoE (4专家, Top-2) | |
|---|---|---|
| 总参数 | 524K | 4 × 262K = 1,049K |
| 每 token 激活 | 524K (100%) | 2 × 262K = 524K (50%) |
| 容量 | 1x | 2x |
| 计算量 | 1x | 1x (不变!) |
MoE 有个致命问题:路由崩塌 (Route Collapse)。
如果不加控制,路由器会"偷懒"——永远只选那 1-2 个表现最好的专家,其他专家永远得不到训练,最终退化为普通 FFN。
传统方案是加一个辅助 loss 来惩罚不均衡:
# 传统方案: 在主 loss 上加 penalty
total_loss = language_loss + α × load_balance_loss
# 问题: α 很难调, 而且会"污染"主 loss 的梯度信号
前沿架构 的创新:完全不用辅助 loss,而是在梯度之外手动调 bias:
# 工业界 方案: 在 softmax 之前加一个可调的 bias
scores = softmax(W_gate(x) + gate_bias) # gate_bias 不参与梯度!
# 每个训练 step 结束后:
for each expert:
if expert 被过度使用:
gate_bias[expert] -= γ # 降低选它的概率
if expert 被冷落:
gate_bias[expert] += γ # 提高选它的概率
简单粗暴但极其有效——完全不干扰主 loss 的优化过程。
我们用 4 种排列组合来精确判断每个创新各自的贡献:
| 配置 | Attention | FFN | 说明 |
|---|---|---|---|
| baseline | 标准 MHA | 标准 FFN | Ch5 的 Standard 模型 |
| mla_only | MLA | 标准 FFN | 只换注意力 |
| moe_only | 标准 MHA | MoE | 只换 FFN |
| full_ds | MLA | MoE | 完整 工业界 风格 |
所有配置使用相同的训练数据(135,072 条)、相同的超参数、从零初始化,训练 5 个 Epoch。
配置 | Epoch 1 (4d) | Epoch 3 (4d) | Epoch 5 (4d) | 最终 Loss
------------+--------------+--------------+--------------+----------
baseline | 92% | 96% | 97% | 0.0555
mla_only | 56% | 100% ✨ | 100% ✨ | 0.0561
moe_only | 65% | 89% | 94% | 0.0556
full_ds | 84% | 99% | 100% ✨ | 0.0560
关键发现: - MLA 是最大赢家:Epoch 3 就达到 100%,比 baseline 更快、更好。 - MoE 初期起步最慢(Epoch 1 只有 65%),因为 4 个专家需要时间"分工"。 - full_ds 综合了两者优势,最终也达到 100%。
配置 | 总参数 | KV Cache/token | 压缩比
------------+----------+----------------+-------
baseline | 4,743K | 12,288 bytes | 1x
mla_only | 4,645K | 4,608 bytes | 2.7x ← 省 63%!
moe_only | 7,906K | 12,288 bytes | 1x
full_ds | 7,807K | 4,608 bytes | 2.7x ← 省 63%!
MLA 用更少的参数 + 更少的内存,学到了更好的结果。 这违反直觉,但背后的原理很深刻:低秩压缩迫使模型用最少的维度表达 K、V 的核心语义,去除了冗余,反而提升了泛化能力。
所有 Epoch 的专家负载: [25% 25% 25% 25%]
完美均衡! 工业界 的 auxiliary-loss-free 动态 bias 方案工作得极其出色,4 个专家被完全平等地使用。
标准 MHA 中,Q 和 K 的最后一维都是 d_head,可以直接用 PyTorch 内置的 scaled_dot_product_attention。
但 MLA 中,Q 和 K 的最后一维是 d_head + d_rope(因为拼接了 RoPE 部分),而 V 的维度仍然是 d_head。这导致无法使用 PyTorch 融合 kernel,必须手写 attention:
# 不能用 F.scaled_dot_product_attention(Q,K,V 最后一维必须相同)
# 必须手写:
scores = Q @ K.T * scale
attn = softmax(scores)
out = attn @ V # V 的维度和 Q,K 不同,但这里是矩阵乘法,没问题
代价:MLA 的训练速度比标准 MHA 慢 2.7 倍(1410s vs 522s per epoch)。 实际工程中需要用自定义 CUDA kernel 来解决。
MoE 需要对每个 token 独立计算路由分数并选择专家,这涉及大量的 scatter/gather 操作,在 GPU 上效率不高。我们的实现用 Python for 循环遍历专家,是最慢的方式。
工业界 实际使用了高度优化的 kernel,将路由和专家计算融合在一起。
训练时每 Epoch 只随机抽 200 题评估,噪声很大。真正的公平对比需要大样本。我们用固定 seed=999 的 5000 道随机 4d×4d 乘法题做终极测试:
配置 | 正确 | 准确率 | KV Cache
-------------+--------+---------+----------
baseline | 4843 | 96.86% | 12,288 B/tok
mla_only | 4988 | 99.76% | 4,608 B/tok
moe_only | 4844 | 96.88% | 12,288 B/tok
full_ds | 4989 | 99.78% | 4,608 B/tok
MLA 的优势在大样本下更加显著:错误数从 157 降到 12(↓92%),而且内存还省了 63%。
5000 题测试后,我们精确知道每个模型错在哪里。这就像考试后拿到了错题本。
在大模型的实际开发中,基础训练(Pre-training + SFT)完成后,模型已经具备了 95%+ 的能力,但在某些特定场景(如数学推理、代码生成)还有弱点。此时不需要重新训练,只需要"定向修补"——用错题及其同类型数据做微量微调。
这就是 Error-Driven Targeted SFT,是大模型后训练阶段的核心技术之一。
┌─────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌─────────┐
│ 已训练的 │───→│ 5000题 │───→│ 错题扩增 │───→│ 低LR │───→│ 重新 │
│ 模型 │ │ 找错题 │ │ 同类型×10│ │ 微调3轮 │ │ 5000测试│
└─────────┘ └──────────┘ └──────────┘ └──────────┘ └─────────┘
96.86% 157个错 2308条 27秒 99.96%!
# 跑一遍 5000 题,收集所有错误的 (a, b) 对
correct, errors = find_errors(model, tokenizer, device, n=5000)
# baseline: 157 个错题
# mla_only: 12 个错题
关键思想:不只训练原题,而是训练"同类型"的题。
为什么?如果只训练那 157 道原题,模型会"死记硬背"这 157 个特定答案,但遇到相似的新题照样错。我们需要让模型学会的是处理这类数字组合的通用能力。
def generate_targeted_data(errors, expand_factor=10):
"""每道错题 → 10 道同类型题"""
targeted = []
for a, b in errors:
targeted.append(generate_cot(a, b)) # 原题
for _ in range(expand_factor):
# 分析 b 的数字特征
b_str = str(b)
has_trailing_zero = b_str.endswith('0') # 如 4010
has_internal_zero = '0' in b_str[1:-1] # 如 5008
has_one = '1' in b_str # 如 1015
# 生成同特征的新数字
if has_trailing_zero and has_one:
# X010 类型 — MLA 最弱的模式
new_b = random.randint(1,9)*1000 + random.choice([10,20,30])
elif has_trailing_zero:
# XX00 / X0X0 类型
new_b = random.randint(1,9)*1000 + random.randint(0,9)*100
# ... 其他模式 ...
targeted.append(generate_cot(new_a, new_b))
return targeted # 157 题 → 1727 条
这是微调中最关键的技巧之一:灾难性遗忘 (Catastrophic Forgetting)。
如果只用错题数据微调,模型会"学了新的、忘了旧的"。解决方案很简单——混入少量通用数据:
anti_forget = generate_anti_forgetting_data(n=581) # ~20% 的混合比例
all_data = targeted + anti_forget # 2308 条
random.shuffle(all_data)
比例选择经验:通用数据占总微调数据的 15-25%。太少会遗忘,太多会稀释修补效果。
def targeted_finetune(model, data, tokenizer, device,
epochs=3, lr=1e-5, batch_size=32):
# 关键 1: 冻结 Embedding 层
# 数字 0-9 的基础表示已经学好了,动它只会添乱
for p in model.tok_emb.parameters():
p.requires_grad = False
# 关键 2: 极低学习率 (1e-5 vs 训练时的 3e-4,差 30 倍)
# 我们要的是"微调"不是"重训"——轻轻推一把,不要把已有能力搞崩
optimizer = torch.optim.AdamW(trainable, lr=1e-5)
# 关键 3: 只跑 3 个 Epoch
# 微调数据量很小 (2308 vs 训练时的 135K)
# 跑太多会过拟合到这些特定模式上
for epoch in range(3):
for batch in loader:
loss = F.cross_entropy(...)
loss.backward()
torch.nn.utils.clip_grad_norm_(trainable, 1.0)
optimizer.step()
三个参数的关系:
| 参数 | 训练阶段 | 微调阶段 | 为什么? |
|---|---|---|---|
| 学习率 | 3e-4 | 1e-5 | 微调只需要微小的权重调整 |
| 数据量 | 135,072 | ~2,300 | 只需要修补弱点 |
| Epochs | 5 | 3 | 数据量小,多跑会过拟合 |
| Embedding | 训练 | 冻结 | 基础表示已经学好 |
配置 | 微调前 | 微调后 | 提升 | 数据量 | 耗时
-------------+---------+----------+----------+--------+-----
baseline | 96.86% | 99.96% | +155题 | 2308条 | 27s
mla_only | 99.76% | 100.00% | +12题 | 713条 | 18s
moe_only | 96.88% | 99.96% | +154题 | 2297条 | 33s
full_ds | 99.78% | 99.76% | -1题 ⚠️ | 702条 | 21s
惊人的效率! baseline 从 96.86% 跳到 99.96%,只用了: - 训练数据的 1.7%(2308 / 135072) - 训练时间的 1%(27 秒 / 43 分钟) - MLA 直接达到 100%! 零错题!
注意 full_ds 的结果:修了 11 个旧错题,但新冒出来 12 个!准确率反而下降了 0.02%。
这是大模型微调中经典的水床效应:像按水床一样,按下去一个地方,另一个地方就鼓起来。
full_ds 微调前: 错 {A, B, C, D, E, ...} 共 11 题
full_ds 微调后: 错 {F, G, H, I, J, ...} 共 12 题 ← 旧的修好了, 新的冒出来了!
原因分析:full_ds 是 MLA + MoE 的组合结构。微调改变了 attention 权重后,MoE 路由器的分配也跟着变了——原本路由到 Expert 1 的 token 现在被送到了 Expert 2。这种路由漂移导致了新的错误。
解法(工业界实践): 1. 多轮迭代微调:第一轮修完后重新找错题,再做第二轮,如此反复直到收敛 2. 增大防遗忘数据比例:从 20% 提高到 50%,用更多通用数据"锚定"模型 3. 冻结 MoE 路由器:只调专家权重,不动路由决策
DeepSeek-V4 论文: "Towards Highly Efficient Million-Token Context Intelligence"
V3 MLA 压缩 head 维度 (256->64),序列长度不变。V4 转向压缩 序列维度——上下文 1M tokens 时,即使每 token KV 已经很小,总量依然爆炸。
V3 MLA: 压缩 head 维度 256维->64维 KV 省 2.7x
V4 CSA: 压缩序列维度 4 tokens->1 KV 再省 4x
V4 HCA: 超重压缩序列 128 tokens->1 KV 再省 128x
V4-Pro @ 1M tokens: KV Cache 仅为标准 MHA 的 10%!
把远距离 KV 按组压缩,只挑最相关的几组做 attention:
def _compress_kv(self, K, V):
K_g = K.reshape(B, H, n_groups, gs, D) # 按 group_size=4 分组
w_k = softmax(self.gate_k(K_g)) # 学习每 token 重要性
K_c = (w_k * K_g).sum(dim=3) # 加权压缩: S/4 条目
# Lightning Indexer 选 Top-k 最相关块 + 滑动窗口做 attention
比 CSA 更激进: 16 tokens -> 1 (V4 原版 128:1)。压缩后序列极短,直接 dense attention。作用: 全局上下文摘要。
替代标准残差 x = x + Layer(x): 维护多个"通道",每层用正交矩阵混合,防止深层信号衰减/爆炸。
Layer 0: HCA (全局摘要) → Layer 1: CSA (细粒度选择)
Layer 2: HCA → Layer 3: CSA
Layer 4: HCA → Layer 5: CSA
配置 | 参数量 | Ep1 loss | Ep3 loss | 速度
-------------+----------+----------+----------+---------
baseline | 4,743K | 0.32 | 0.06 | 8 min/ep
v4_hybrid | 7,991K | 0.48 | 0.03 | 33 min/ep
核心教训: V4 的 CSA/HCA 是为 百万 token 设计的。在 750 token 的算术任务上,压缩路径根本用不上(回退到标准 attention),真正影响训练的只有 mHC 超连接——它改变了收敛动力学,使训练变慢。架构创新必须匹配问题规模。
| 文件 | 职责 |
|---|---|
model.py |
🧠 V3 (MLA+MoE) + V4 (CSA+HCA+mHC) 统一解码器 |
train.py |
📊 V3 4配置 A/B 对比训练 |
train_v4.py |
📊 V4 配置训练 + 5000题测试 |
eval_5000.py |
🎯 5000 题极限测试 |
error_analysis.py |
🔬 错题分析 |
finetune_errors.py |
🔧 错题本微调 |
# Part 1: V3
python train.py # V3 4配置训练 (~3h)
python eval_5000.py # 5000 题大测
python finetune_errors.py # 错题本微调
# Part 2: V4
python train_v4.py # V4 配置训练