FlashAttention 通过分块计算和算子融合,将注意力内存复杂度从 $O(N^2)$ 降至 $O(N)$,实现 2-4 倍加速。
核心原理
内存层次优化
| 存储层级 | 容量 | 带宽 |
|---|---|---|
| HBM | 40-80GB | 1.5-2TB/s |
| SRAM | ~192KB | 19TB/s |
核心思想:将计算限制在 SRAM 中,减少 HBM 访问。
分块计算
- Q 分为 $T_r$ 块,K/V 分为 $T_c$ 块
- 每块大小 $B_r = B_c = 64$
- 局部注意力计算在 SRAM 中完成
在线 Softmax
单遍扫描更新全局统计量:
$$\begin{aligned} m_{\text{new}} &= \max(m_{\text{old}}, m_{ij}) \ l_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} + e^{m_{ij} - m_{\text{new}}} \cdot l_{ij} \end{aligned}$$
版本演进
| 版本 | 改进 |
|---|---|
| v1 | 基础分块,2-4 倍加速 |
| v2 | 优化循环顺序,效率达 GEMM 的 50-73% |
| v3 | Hopper 架构优化,支持 MQA |
性能对比
| 特性 | 标准 Attention | FlashAttention |
|---|---|---|
| 显存占用 | $O(N^2)$ | $O(N)$ |
| 支持序列长度 | <8k | >32k |
| 计算效率 | 基准 | 提升 2-4 倍 |
使用方式
| |
或使用 vLLM 自动启用:
| |
应用场景
- LLM 训练(LLaMA、ChatGLM)
- 长上下文处理(32k+ tokens)
- 高并发推理服务
- 多模态任务
张芷铭的个人博客
Comments