目标: 从零实现 FlashAttention 分块算法,并用 A/B 实验验证 Kimi AttnRes 的层间聚合机制 核心发现: AttnRes 让梯度流更平滑(5-7 倍),并学会了 U 形聚合模式 前置: Ch4 的 KV Cache + 3d 乘法 100%
Ch4 实现了 KV Cache 加速推理。但还有两个重要问题没解决:
本章做两件事: 1. 从零手写 FlashAttention:用 tiling + online softmax 避免实例化 N×N 矩阵 2. 实现 Kimi AttnRes:用可学习权重替代固定残差,对比训练效果
标准 Self-Attention 需要实例化完整的 N×N 注意力矩阵:
# 标准实现
attn_weights = Q @ K.T / sqrt(d) # [N, N] ← 显存!
attn_probs = softmax(attn_weights) # [N, N] ← 又一份显存!
output = attn_probs @ V
| 序列长度 | 注意力矩阵大小 | 显存 (fp32, 8 heads) |
|---|---|---|
| 300 (3d CoT) | 300×300 | 2.7 MB |
| 700 (4d CoT) | 700×700 | 15 MB |
| 1500 (5d CoT) | 1500×1500 | 69 MB |
| 8192 (GPT-4) | 8192×8192 | 2 GB |
乘以 batch_size × n_layers,显存迅速爆炸。
FlashAttention 的核心思想:永远不实例化完整的 N×N 矩阵,而是分块计算并用在线算法累积结果。
标准: 一次算完 300×300 的大矩阵 → 显存 O(N²)
Flash: 分成 5×5=25 个 64×64 的小块,逐块计算累积 → 显存 O(N)
但 softmax 需要全局信息(分母是所有 exp 的求和),怎么分块算?
标准 softmax 需要先算完所有值才能归一化:
$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
Online Softmax 的关键洞察:可以一边算一边更新。
处理每个新的 KV 块时,维护三个累积量:
- m: 已见过的最大值(数值稳定性)
- l: 分母累积 Σexp(x - m)
- O: 输出累积 Σexp(x - m)·V
# 处理第 k 个 KV 块
m_new = max(m_old, max(scores_k)) # 更新最大值
rescale = exp(m_old - m_new) # 修正之前的累积
l = l * rescale + sum(exp(scores_k - m_new)) # 更新分母
O = O * rescale + exp(scores_k - m_new) @ V_k # 更新输出
m = m_new
最后 O / l 就是精确的 attention 输出——与标准实现数学等价。
# flash_attention.py — 完整的分块实现
def flash_attention_forward(Q, K, V, block_size=64, is_causal=False):
B, H, S_q, D = Q.shape
scale = 1.0 / math.sqrt(D)
output_blocks = []
for q_idx in range(n_q_blocks):
Q_block = Q[:, :, q_start:q_end, :]
# 每个 Q 块独立的累积器
O_block = torch.zeros(B, H, Bq, D)
l_block = torch.zeros(B, H, Bq, 1)
m_block = torch.full((B, H, Bq, 1), -inf)
for kv_idx in range(n_kv_blocks):
K_block = K[:, :, kv_start:kv_end, :]
V_block = V[:, :, kv_start:kv_end, :]
# 局部注意力分数
scores = (Q_block @ K_block.T) * scale # [Bq, Bkv] ← 小矩阵!
# Online Softmax 更新
m_new = max(m_block, scores.max())
rescale = exp(m_block - m_new)
P = exp(scores - m_new)
l_block = l_block * rescale + P.sum()
O_block = O_block * rescale + P @ V_block
m_block = m_new
output_blocks.append(O_block / l_block)
return torch.cat(output_blocks, dim=2)
python flash_attention.py
🔬 手写 FlashAttention 等价性验证
block_size= 32 | max_diff=8.34e-07 | ✅
block_size= 64 | max_diff=7.15e-07 | ✅
block_size=128 | max_diff=9.54e-07 | ✅
📊 显存分析 (S=256):
标准 Attention: 实例化 256×256 矩阵 = 4096 KB
FlashAttention: 只需累积器 = 1536 KB
节省: 62%
三种 block_size 全部等价(误差 < 1e-6),显存节省 62%。
我们的纯 PyTorch 实现展示了算法原理,但实际加速有限(甚至更慢),因为:
| 层面 | 标准实现 | 手写 FlashAttention | CUDA FlashAttention |
|---|---|---|---|
| 内存层级 | HBM(慢但大) | HBM(同样慢) | SRAM(快但小) |
| 内核调用 | 1 次大矩阵 | N/block_size 次小矩阵 | 1 次融合 kernel |
| IO 次数 | O(N²) 读写 HBM | O(N²) 读写 HBM | O(N) 读写 HBM |
真正的 FlashAttention 用 Triton/CUDA 把整个分块循环写成一个 GPU kernel,数据在 SRAM(快 10 倍)中计算完毕后才写回 HBM。这是 IO-aware 的精髓——算得多但搬得少,反而更快。
PyTorch 2.0+ 的 F.scaled_dot_product_attention 已经内置了 FlashAttention 后端,实际工程中直接用即可。
标准 Pre-LN Transformer 的残差连接是固定权重 1.0:
# 每层都是: x_l = x_{l-1} + Attention(LayerNorm(x_{l-1}))
x_1 = x_0 + attn_1(LN(x_0))
x_2 = x_1 + attn_2(LN(x_1))
...
x_6 = x_5 + attn_6(LN(x_5))
展开后:
x_6 = x_0 + attn_1 + attn_2 + attn_3 + attn_4 + attn_5 + attn_6
↑
原始 embedding 被 6 个 attention 输出"淹没"
底层的 token embedding 信息在深层被稀释。对于乘法 CoT 这类需要精确追踪中间数字的任务,这可能导致深层"忘记"早期的 S 步骤结果。
Moonshot (Kimi) 提出的 AttnRes,核心改动很小:
class AttnResBlock(nn.Module):
def __init__(self, layer_idx, d_model):
# 每层一个可学习的聚合权重向量
self.agg_weights = nn.Parameter(torch.zeros(layer_idx + 1))
def forward(self, x, all_prev_outputs):
# all_prev_outputs = [h_0, h_1, ..., h_{l-1}]
# 学习每个前序层的聚合权重
alpha = softmax(self.agg_weights)
# 加权聚合(替代简单的 x + attn(x))
aggregated = Σ α_i × h_i
# 当前层的变换
h = aggregated + Attention(LN(aggregated)) + FFN(...)
return h
参数开销:6 层模型只多 1+2+3+4+5+6 = 21 个参数(0.0004%),几乎可以忽略。
为了真正测试标准架构与 AttnRes 的区别,我们将难度拉满,进行 4位数乘法 (4d×4d) 训练。 4d 乘法在我们的 CoT 模板下,需要输出高达 700+ 个 token 的长推理链(包含 16 个 S 步乘法和 15 个 A 步加法)。在这种长程推理下,底层的原始数字信息极易在深层被稀释。
| 控制变量 | 值 |
|---|---|
| 模型参数 | 标准: 4,743,168 vs AttnRes: 4,743,189 (+21) |
| 训练数据 | 135,072 条(包含极难的 4d 专项数据) |
| 最大序列长度 | 800 |
| Epochs | 5 |
| 初始化 | 两者都从零开始(无预训练权重) |
Epoch | 标准 Acc (4d) | AttnRes Acc (4d)
------+-------------------+-------------------
1 | 88% | 79%
2 | 98% | 94%
3 | 99% | 98%
4 | 100% | 100%
5 | 100% | 100%
发现: - 两者最终都达到了惊人的 100% 准确率!这意味着仅 4.7M 的小模型彻底掌握了四位数乘法算法! - AttnRes 初期(Epoch 1)起步稍慢,这是因为它初始化为均匀聚合,偏离了传统的残差模式(只看上一层),需要时间来学习各层的最优分配。但它很快在 Epoch 4 追平并达到完美。
Layer | 标准 (3d) | AttnRes (4d) |
------+-----------------+---------------+
0 | 0.0117 | 0.0011 | ← 小了 10 倍!
3 | 0.0096 | 0.0008 |
5 | 0.0112 | 0.0032 |
即便在更难的 4d 任务中,AttnRes 的梯度比标准模型小了近 10 倍,且极其平滑。 AttnRes 实质上是一种自适应的梯度调节器。对于更深的模型(12层、24层),这种避免梯度爆炸/消失的稳定性优势会成为决定模型生死的关键。
完成 4d 训练后,AttnRes 自动学习到的层间权重分配如下:
layer_0: [1.000]
layer_1: [0.650, 0.350] ← 严重偏重 embedding
layer_2: [0.480, 0.223, 0.297]
layer_3: [0.391, 0.172, 0.171, 0.267]
layer_4: [0.330, 0.152, 0.127, 0.164, 0.227]
layer_5: [0.316, 0.125, 0.098, 0.109, 0.136, 0.216] ← 极致 U 形
极其夸张的“底层穿透”:
注意最深层 layer_5。它分配给最底层 layer_0(原始数字 Embedding)的权重高达 31.6%,甚至超过了对上一层 layer_4(21.6%)的关注!
为什么会这样? 当进行长达 800 token 的 4d 复杂推理时,模型随时可能“算晕”。它必须时刻回头看最初输入的那两个 4 位数到底长什么样。标准的残差连接在经历 6 层非线性变换后,原始数字信息早就模糊了。而 AttnRes 允许最后一层直接“抄近路”调取原始的 Token Embedding,这就是它能在复杂 CoT 推理中稳住阵脚的核心秘密。
现象:标准模型训练正常,AttnRes 模型 loss.backward() 报错 RuntimeError: modified by an inplace operation
根因:
1. AttnRes 需要保存所有前序层输出 (all_prev_outputs)
2. FlashAttention 原始实现用 slice 写入 O[:, :, q_start:q_end, :] = ...
3. PyTorch autograd 认为这是 inplace 操作,破坏了保存的张量的版本号
修复:
# 错误: 全局 O 张量 + slice 写入 (inplace)
O = torch.zeros(B, H, S_q, D)
for q_idx in range(n_q_blocks):
O[:, :, q_start:q_end, :] = result # ← inplace!
# 正确: 每个 Q 块独立累积,最后 cat
output_blocks = []
for q_idx in range(n_q_blocks):
O_block = torch.zeros(B, H, Bq, D) # ← 独立张量
...
output_blocks.append(O_block / l_block)
return torch.cat(output_blocks, dim=2) # ← 非 inplace
教训:当计算图中有长距离依赖(如 AttnRes 跨层引用)时,所有中间操作必须避免 inplace 修改。
| 模型 | Epoch 耗时 | 原因 |
|---|---|---|
| 标准 | 120s | PyTorch 内置 scaled_dot_product_attention |
| AttnRes | 212s | 手写 FlashAttention + .clone() 开销 |
AttnRes 慢了 1.8 倍,主要是:
1. 手写 Python 循环 vs PyTorch 融合 CUDA kernel
2. 每层 x.clone() 保存到 all_prev_outputs 的拷贝开销
3. torch.stack + 加权求和的计算
实际工程中,AttnRes 的聚合操作可以用 CUDA kernel 优化,开销可忽略。
为了验证模型是真的“学会”了乘法,而不是“背下”了训练集,我们在完全脱离训练集分布的情况下,随机生成了 5000 道全新的 4d×4d 乘法题 进行自回归推理测试。 随后,为了追求极致的 100%,我们对 AttnRes 降低学习率额外精调了 2 个 Epoch(Loss 继续下降),并再次进行 5000 题测试。
结果出现了极具教学意义的现象:
架构 | 5 Epochs 测试准确率 | 7 Epochs 测试准确率 (精调后)
-----------+----------------------+-----------------------------
Standard | 100% | 100%
AttnRes | 99.92% (错4题) | 99.64% (错18题) ↓
为什么 AttnRes 越训练越错?
当我们查看 AttnRes 错误输出的完整 CoT 推理序列时发现,它其实已经算出了正确答案并且输出了 <EOS>,但它没有停下,而是陷入了类似 Z10443000<EOS>30000=00037401...111111111 的死循环中(循环输出历史算式中的数字片段)。
这就是深层残差网络中经典的 特征坍缩 (Feature Collapse) 与 回声效应 (Echo Effect): 1. AttnRes 因为有高达 31.6% 的权重直接短接回了最底层 (Layer 0)。 2. 在模型轻微过拟合时,最底层的“残影”持续高强度地刺激输出层。 3. 导致模型在完成复杂算法后,被历史特征干扰,陷入了重复生成之前 token 的幻觉。
| 文件 | 职责 |
|---|---|
flash_attention.py |
🔬 手写 FlashAttention(tiling + online softmax) |
model.py |
🧠 标准/AttnRes 统一模型 (use_attn_res 开关) |
train.py |
📊 A/B 对比训练 + 梯度分析 + 权重可视化 |
eval_5000_both.py |
🎯 5000题极限泛化能力与特征坍缩测试脚本 |
python flash_attention.py
# 从零开始,公平对比
python train.py --epochs 5 --batch_size 64
# 生成 5000 题,输出错题的完整 CoT 序列以供分析
python eval_5000_both.py
下一章 (Ch6),我们将引入 DeepSeek-V4 的前沿架构(MLA 与 MoE 雏形),探索如何在大规模参数下突破 Transformer 的计算效率极限!