NanoTransformer Courses

🏠 Back to Lab 📚 All Courses

Ch6: 前沿架构探索 (MoE与MLA)

目标: 从零实现现代前沿架构的 MLA + MoE,以及进阶的 CSA + HCA + mHC 核心发现: MLA 用更少内存反而学得更好;V4 的 CSA/HCA 为百万 token 设计,在短序列上反而拖慢收敛 前置: Ch5 的 4d 乘法 100%


🧭 本章概览

Ch5 验证了标准 Transformer 可以 100% 学会 4 位数乘法。但标准架构有两个根本瓶颈:

  1. KV Cache 太大:每层每 token 缓存完整的 K 和 V(512 个 float),长序列时内存线性爆炸
  2. 参数效率低:FFN 对每个 token 使用全部参数,想提升容量就必须等比增加计算量

现代前沿架构用两个核心创新解决了这两个问题: 1. MLA (Multi-head Latent Attention):把 K、V 压缩到低维 latent 空间再缓存 2. MoE (Mixture of Experts):多个专家 FFN 替代单一 FFN,每个 token 只激活 Top-K 个


🔑 核心技术 1: MLA (Multi-head Latent Attention)

问题:KV Cache 的内存困境

回顾 Ch4,KV Cache 的原理是:推理时把每层算过的 K、V 存起来,下次直接复用,避免重复计算。

但存什么?标准 MHA 存的是完整的 K 和 V 矩阵

标准 MHA 每层每 token 缓存量:
  K: n_heads × d_head = 8 × 32 = 256 floats
  V: n_heads × d_head = 8 × 32 = 256 floats
  合计: 512 floats × 4 bytes = 2048 bytes

6 层 × 750 tokens (4d乘法):
  = 6 × 750 × 2048 = 9.2 MB

看起来不大?但想象一下 GPT-4 级别:

128 层 × 128K tokens × 2048 bytes = 33 GB ← 仅 KV Cache 就爆显存!

MLA 的核心思想:压缩-缓存-解压

MLA 的灵感来自一个简单的观察:K 和 V 的信息是高度冗余的

8 个 head 各有 32 维的 K 和 V,总共 512 维。但这 512 维里大部分信息是重复的! 就像一张 1000×1000 的图片,虽然有 100 万个像素,但用 JPEG 压缩后可能只需要 10KB——因为相邻像素高度相关。

MLA 做的事情就是对 KV 做"JPEG 压缩":

# 标准 MHA: 直接投影出完整的 K, V,然后缓存
K = W_k(x)  # [d_model] → [n_heads × d_head] = 256 维
V = W_v(x)  # [d_model] → [n_heads × d_head] = 256 维
cache = (K, V)  # 缓存 512 维 ← 太大了!

# MLA: 先压缩到低维 latent,只缓存 latent
c_kv = W_compress(x)  # [d_model] → [d_c] = 64 维 ← 压缩 4x!
cache = c_kv           # 只缓存 64 维 ← 省了很多!

# 需要 K, V 时,从 latent 临时"解压"出来
K = W_decompress_k(c_kv)  # [d_c] → [n_heads × d_head]
V = W_decompress_v(c_kv)  # [d_c] → [n_heads × d_head]

Decoupled RoPE:位置信息的分离

但这里有个问题:RoPE 位置编码需要对 K 做旋转。如果直接在 latent 上做 RoPE,位置信息就会"污染" latent,降低压缩效率(因为不同位置的 token 即使内容相同,latent 也不一样了)。

工业界 的解法很精妙:把 RoPE 信息独立出来

# 从输入额外生成一小段专门做 RoPE 的 K
K_rope = W_k_rope(x)   # [d_model] → [n_heads × d_rope] = 128 维
K_rope = apply_RoPE(K_rope)  # 位置编码只在这里做

# 最终的 K = [K_content | K_rope](拼接)
# K_content 从无位置信息的 latent 解压得到
# K_rope 独立携带位置信息

这样,latent 保持"纯内容"不含位置信息,压缩效率最高。

内存对比

标准 MHA MLA
缓存内容 完整 K + V latent + K_rope
每 token 维度 512 64 + 128 = 192
压缩比 1x 2.7x
6层×750token 9.2 MB 3.5 MB

🔑 核心技术 2: MoE (Mixture of Experts)

问题:参数效率的困境

标准 Transformer 的 FFN 层是一个巨大的全连接网络:

FFN: d_model → d_ff → d_model
     256 → 1024 → 256
     参数: 256×1024 + 1024×256 = 524,288

每个 token 都要过同一个 FFN。想要更强的模型?只能把 FFN 做大,但计算量也等比增大。

有没有办法"参数翻倍但计算量不变"?

MoE 的核心思想:条件计算

把 1 个大 FFN 拆成 N 个小 FFN("专家"),每个 token 只激活其中 K 个

# 标准 FFN: 1 个大专家,所有 token 都用
output = FFN(x)  # 100% 参数都激活

# MoE FFN: 4 个小专家,每个 token 只选 2 个
scores = softmax(W_gate(x))           # 路由分数: [4]
top2 = topk(scores, k=2)              # 选出分数最高的 2 个专家
output = score_1 × Expert_1(x) + score_2 × Expert_2(x)
标准 FFN MoE (4专家, Top-2)
总参数 524K 4 × 262K = 1,049K
每 token 激活 524K (100%) 2 × 262K = 524K (50%)
容量 1x 2x
计算量 1x 1x (不变!)

Auxiliary-Loss-Free 负载均衡

MoE 有个致命问题:路由崩塌 (Route Collapse)

如果不加控制,路由器会"偷懒"——永远只选那 1-2 个表现最好的专家,其他专家永远得不到训练,最终退化为普通 FFN。

传统方案是加一个辅助 loss 来惩罚不均衡:

# 传统方案: 在主 loss 上加 penalty
total_loss = language_loss + α × load_balance_loss
# 问题: α 很难调, 而且会"污染"主 loss 的梯度信号

前沿架构 的创新:完全不用辅助 loss,而是在梯度之外手动调 bias:

# 工业界 方案: 在 softmax 之前加一个可调的 bias
scores = softmax(W_gate(x) + gate_bias)  # gate_bias 不参与梯度!

# 每个训练 step 结束后:
for each expert:
    if expert 被过度使用:
        gate_bias[expert] -= γ    # 降低选它的概率
    if expert 被冷落:
        gate_bias[expert] += γ    # 提高选它的概率

简单粗暴但极其有效——完全不干扰主 loss 的优化过程。


📊 4 配置 A/B 对比实验

实验设计

我们用 4 种排列组合来精确判断每个创新各自的贡献:

配置 Attention FFN 说明
baseline 标准 MHA 标准 FFN Ch5 的 Standard 模型
mla_only MLA 标准 FFN 只换注意力
moe_only 标准 MHA MoE 只换 FFN
full_ds MLA MoE 完整 工业界 风格

所有配置使用相同的训练数据(135,072 条)、相同的超参数、从零初始化,训练 5 个 Epoch。

结果 1: 准确率对比

  配置        | Epoch 1 (4d) | Epoch 3 (4d) | Epoch 5 (4d) | 最终 Loss
  ------------+--------------+--------------+--------------+----------
  baseline    |     92%      |     96%      |     97%      |  0.0555
  mla_only    |     56%      |    100% ✨   |    100% ✨   |  0.0561
  moe_only    |     65%      |     89%      |     94%      |  0.0556
  full_ds     |     84%      |     99%      |    100% ✨   |  0.0560

关键发现: - MLA 是最大赢家:Epoch 3 就达到 100%,比 baseline 更快、更好。 - MoE 初期起步最慢(Epoch 1 只有 65%),因为 4 个专家需要时间"分工"。 - full_ds 综合了两者优势,最终也达到 100%。

结果 2: 内存效率

  配置        |  总参数    | KV Cache/token | 压缩比
  ------------+----------+----------------+-------
  baseline    |  4,743K  |   12,288 bytes |   1x
  mla_only    |  4,645K  |    4,608 bytes |  2.7x ← 省 63%!
  moe_only    |  7,906K  |   12,288 bytes |   1x
  full_ds     |  7,807K  |    4,608 bytes |  2.7x ← 省 63%!

MLA 用更少的参数 + 更少的内存,学到了更好的结果。 这违反直觉,但背后的原理很深刻:低秩压缩迫使模型用最少的维度表达 K、V 的核心语义,去除了冗余,反而提升了泛化能力。

结果 3: MoE 专家负载均衡

  所有 Epoch 的专家负载: [25% 25% 25% 25%]

完美均衡! 工业界 的 auxiliary-loss-free 动态 bias 方案工作得极其出色,4 个专家被完全平等地使用。


🔍 工程踩坑

坑 1: MLA 的 Q/K 维度不对齐

标准 MHA 中,Q 和 K 的最后一维都是 d_head,可以直接用 PyTorch 内置的 scaled_dot_product_attention

但 MLA 中,Q 和 K 的最后一维是 d_head + d_rope(因为拼接了 RoPE 部分),而 V 的维度仍然是 d_head。这导致无法使用 PyTorch 融合 kernel,必须手写 attention:

# 不能用 F.scaled_dot_product_attention(Q,K,V 最后一维必须相同)
# 必须手写:
scores = Q @ K.T * scale
attn = softmax(scores)
out = attn @ V  # V 的维度和 Q,K 不同,但这里是矩阵乘法,没问题

代价:MLA 的训练速度比标准 MHA 慢 2.7 倍(1410s vs 522s per epoch)。 实际工程中需要用自定义 CUDA kernel 来解决。

坑 2: MoE 的 Token 级路由开销

MoE 需要对每个 token 独立计算路由分数并选择专家,这涉及大量的 scatter/gather 操作,在 GPU 上效率不高。我们的实现用 Python for 循环遍历专家,是最慢的方式。

工业界 实际使用了高度优化的 kernel,将路由和专家计算融合在一起。


📊 5000 题极限测试

训练时每 Epoch 只随机抽 200 题评估,噪声很大。真正的公平对比需要大样本。我们用固定 seed=999 的 5000 道随机 4d×4d 乘法题做终极测试:

  配置         | 正确    | 准确率   | KV Cache
  -------------+--------+---------+----------
  baseline     | 4843   |  96.86% | 12,288 B/tok
  mla_only     | 4988   |  99.76% |  4,608 B/tok
  moe_only     | 4844   |  96.88% | 12,288 B/tok
  full_ds      | 4989   |  99.78% |  4,608 B/tok

MLA 的优势在大样本下更加显著:错误数从 157 降到 12(↓92%),而且内存还省了 63%。


🔧 核心技术 3: 错题本微调 (Error-Driven Fine-Tuning)

问题:大模型训练后如何修补弱点?

5000 题测试后,我们精确知道每个模型错在哪里。这就像考试后拿到了错题本。

在大模型的实际开发中,基础训练(Pre-training + SFT)完成后,模型已经具备了 95%+ 的能力,但在某些特定场景(如数学推理、代码生成)还有弱点。此时不需要重新训练,只需要"定向修补"——用错题及其同类型数据做微量微调。

这就是 Error-Driven Targeted SFT,是大模型后训练阶段的核心技术之一。

微调流程总览

┌─────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌─────────┐
│ 已训练的 │───→│ 5000题   │───→│ 错题扩增 │───→│ 低LR     │───→│ 重新    │
│ 模型     │    │ 找错题   │    │ 同类型×10│    │ 微调3轮  │    │ 5000测试│
└─────────┘    └──────────┘    └──────────┘    └──────────┘    └─────────┘
   96.86%         157个错          2308条          27秒          99.96%!

Step 1: 找错题 — 精准定位弱点

# 跑一遍 5000 题,收集所有错误的 (a, b) 对
correct, errors = find_errors(model, tokenizer, device, n=5000)
# baseline: 157 个错题
# mla_only: 12 个错题

Step 2: 错题扩增 — 生成同类型数据

关键思想:不只训练原题,而是训练"同类型"的题

为什么?如果只训练那 157 道原题,模型会"死记硬背"这 157 个特定答案,但遇到相似的新题照样错。我们需要让模型学会的是处理这类数字组合的通用能力

def generate_targeted_data(errors, expand_factor=10):
    """每道错题 → 10 道同类型题"""
    targeted = []

    for a, b in errors:
        targeted.append(generate_cot(a, b))  # 原题

        for _ in range(expand_factor):
            # 分析 b 的数字特征
            b_str = str(b)
            has_trailing_zero = b_str.endswith('0')   # 如 4010
            has_internal_zero = '0' in b_str[1:-1]    # 如 5008
            has_one = '1' in b_str                     # 如 1015

            # 生成同特征的新数字
            if has_trailing_zero and has_one:
                # X010 类型 — MLA 最弱的模式
                new_b = random.randint(1,9)*1000 + random.choice([10,20,30])
            elif has_trailing_zero:
                # XX00 / X0X0 类型
                new_b = random.randint(1,9)*1000 + random.randint(0,9)*100
            # ... 其他模式 ...

            targeted.append(generate_cot(new_a, new_b))

    return targeted  # 157 题 → 1727 条

Step 3: 防遗忘混合 — 少量通用数据

这是微调中最关键的技巧之一:灾难性遗忘 (Catastrophic Forgetting)

如果只用错题数据微调,模型会"学了新的、忘了旧的"。解决方案很简单——混入少量通用数据:

anti_forget = generate_anti_forgetting_data(n=581)  # ~20% 的混合比例
all_data = targeted + anti_forget  # 2308 条
random.shuffle(all_data)

比例选择经验:通用数据占总微调数据的 15-25%。太少会遗忘,太多会稀释修补效果。

Step 4: 低学习率微调 — 三个关键参数

def targeted_finetune(model, data, tokenizer, device,
                      epochs=3, lr=1e-5, batch_size=32):
    # 关键 1: 冻结 Embedding 层
    # 数字 0-9 的基础表示已经学好了,动它只会添乱
    for p in model.tok_emb.parameters():
        p.requires_grad = False

    # 关键 2: 极低学习率 (1e-5 vs 训练时的 3e-4,差 30 倍)
    # 我们要的是"微调"不是"重训"——轻轻推一把,不要把已有能力搞崩
    optimizer = torch.optim.AdamW(trainable, lr=1e-5)

    # 关键 3: 只跑 3 个 Epoch
    # 微调数据量很小 (2308 vs 训练时的 135K)
    # 跑太多会过拟合到这些特定模式上
    for epoch in range(3):
        for batch in loader:
            loss = F.cross_entropy(...)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()

三个参数的关系:

参数 训练阶段 微调阶段 为什么?
学习率 3e-4 1e-5 微调只需要微小的权重调整
数据量 135,072 ~2,300 只需要修补弱点
Epochs 5 3 数据量小,多跑会过拟合
Embedding 训练 冻结 基础表示已经学好

微调结果

  配置         | 微调前   | 微调后    | 提升      | 数据量  | 耗时
  -------------+---------+----------+----------+--------+-----
  baseline     |  96.86% |  99.96%  | +155题   | 2308条 | 27s
  mla_only     |  99.76% | 100.00%  | +12题    |  713条 | 18s
  moe_only     |  96.88% |  99.96%  | +154题   | 2297条 | 33s
  full_ds      |  99.78% |  99.76%  |  -1题 ⚠️ |  702条 | 21s

惊人的效率! baseline 从 96.86% 跳到 99.96%,只用了: - 训练数据的 1.7%(2308 / 135072) - 训练时间的 1%(27 秒 / 43 分钟) - MLA 直接达到 100%! 零错题!

水床效应 (Waterbed Effect) — full_ds 的教训

注意 full_ds 的结果:修了 11 个旧错题,但新冒出来 12 个!准确率反而下降了 0.02%。

这是大模型微调中经典的水床效应:像按水床一样,按下去一个地方,另一个地方就鼓起来。

full_ds 微调前: 错 {A, B, C, D, E, ...}  共 11 题
full_ds 微调后: 错 {F, G, H, I, J, ...}  共 12 题 ← 旧的修好了, 新的冒出来了!

原因分析:full_ds 是 MLA + MoE 的组合结构。微调改变了 attention 权重后,MoE 路由器的分配也跟着变了——原本路由到 Expert 1 的 token 现在被送到了 Expert 2。这种路由漂移导致了新的错误。

解法(工业界实践): 1. 多轮迭代微调:第一轮修完后重新找错题,再做第二轮,如此反复直到收敛 2. 增大防遗忘数据比例:从 20% 提高到 50%,用更多通用数据"锚定"模型 3. 冻结 MoE 路由器:只调专家权重,不动路由决策


🔬 核心技术 4: CSA + HCA (前沿架构, 2026.4.24)

DeepSeek-V4 论文: "Towards Highly Efficient Million-Token Context Intelligence"

V3 vs V4: 压缩方向不同

V3 MLA 压缩 head 维度 (256->64),序列长度不变。V4 转向压缩 序列维度——上下文 1M tokens 时,即使每 token KV 已经很小,总量依然爆炸。

V3 MLA: 压缩 head 维度    256维->64维     KV 省 2.7x
V4 CSA: 压缩序列维度      4 tokens->1     KV 再省 4x
V4 HCA: 超重压缩序列      128 tokens->1   KV 再省 128x

V4-Pro @ 1M tokens: KV Cache 仅为标准 MHA 的 10%!

CSA (Compressed Sparse Attention)

把远距离 KV 按组压缩,只挑最相关的几组做 attention:

def _compress_kv(self, K, V):
    K_g = K.reshape(B, H, n_groups, gs, D)  # 按 group_size=4 分组
    w_k = softmax(self.gate_k(K_g))          # 学习每 token 重要性
    K_c = (w_k * K_g).sum(dim=3)             # 加权压缩: S/4 条目

# Lightning Indexer 选 Top-k 最相关块 + 滑动窗口做 attention

HCA (Heavily Compressed Attention)

比 CSA 更激进: 16 tokens -> 1 (V4 原版 128:1)。压缩后序列极短,直接 dense attention。作用: 全局上下文摘要。

mHC (Manifold-Constrained Hyper-Connections)

替代标准残差 x = x + Layer(x): 维护多个"通道",每层用正交矩阵混合,防止深层信号衰减/爆炸。

V4 层排列: HCA/CSA 交替

Layer 0: HCA (全局摘要)  →  Layer 1: CSA (细粒度选择)
Layer 2: HCA             →  Layer 3: CSA
Layer 4: HCA             →  Layer 5: CSA

初步实验结果 (训练中)

  配置         |  参数量    | Ep1 loss | Ep3 loss | 速度
  -------------+----------+----------+----------+---------
  baseline     |  4,743K  |   0.32   |   0.06   |  8 min/ep
  v4_hybrid    |  7,991K  |   0.48   |   0.03   | 33 min/ep

核心教训: V4 的 CSA/HCA 是为 百万 token 设计的。在 750 token 的算术任务上,压缩路径根本用不上(回退到标准 attention),真正影响训练的只有 mHC 超连接——它改变了收敛动力学,使训练变慢。架构创新必须匹配问题规模。


📁 文件清单

文件 职责
model.py 🧠 V3 (MLA+MoE) + V4 (CSA+HCA+mHC) 统一解码器
train.py 📊 V3 4配置 A/B 对比训练
train_v4.py 📊 V4 配置训练 + 5000题测试
eval_5000.py 🎯 5000 题极限测试
error_analysis.py 🔬 错题分析
finetune_errors.py 🔧 错题本微调

🚀 使用方法

# Part 1: V3
python train.py                # V3 4配置训练 (~3h)
python eval_5000.py            # 5000 题大测
python finetune_errors.py      # 错题本微调

# Part 2: V4
python train_v4.py             # V4 配置训练

💡 关键认知总结

  1. MLA = 有损压缩换效率: 压缩迫使模型学更本质的特征,去冗余提泛化
  2. MoE = 条件计算: 2x 参数 1x 计算量,模型自动学分工
  3. 负载均衡不需要辅助 loss: 梯度外手动调 bias 更干净有效
  4. 错题本微调三板斧: 错题扩增 + 防遗忘混合 + 极低 LR,1.7% 数据修复 96.86% -> 99.96%
  5. 水床效应: MLA+MoE 微调导致路由漂移,修旧错冒新错
  6. V3 压 head,V4 压序列: 两个正交方向,V4 把 1M token KV Cache 压到 10%
  7. 架构必须匹配规模: CSA/HCA 在 750 token 上用不上,用核弹打蚊子适得其反
  8. mHC 改变收敛动力学: 更复杂的架构不等于更快的学习