张芷铭的个人博客

FlashAttention 通过分块计算和算子融合,将注意力内存复杂度从 $O(N^2)$ 降至 $O(N)$,实现 2-4 倍加速。

核心原理

内存层次优化

存储层级容量带宽
HBM40-80GB1.5-2TB/s
SRAM~192KB19TB/s

核心思想:将计算限制在 SRAM 中,减少 HBM 访问。

分块计算

  • Q 分为 $T_r$ 块,K/V 分为 $T_c$ 块
  • 每块大小 $B_r = B_c = 64$
  • 局部注意力计算在 SRAM 中完成

在线 Softmax

单遍扫描更新全局统计量:

$$\begin{aligned} m_{\text{new}} &= \max(m_{\text{old}}, m_{ij}) \ l_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} + e^{m_{ij} - m_{\text{new}}} \cdot l_{ij} \end{aligned}$$

版本演进

版本改进
v1基础分块,2-4 倍加速
v2优化循环顺序,效率达 GEMM 的 50-73%
v3Hopper 架构优化,支持 MQA

性能对比

特性标准 AttentionFlashAttention
显存占用$O(N^2)$$O(N)$
支持序列长度<8k>32k
计算效率基准提升 2-4 倍

使用方式

1
2
3
4
5
6
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B",
    use_flash_attention_2=True
)

或使用 vLLM 自动启用:

1
2
from vllm import LLM
llm = LLM(model="meta-llama/Llama-3-8B")

应用场景

  • LLM 训练(LLaMA、ChatGLM)
  • 长上下文处理(32k+ tokens)
  • 高并发推理服务
  • 多模态任务

资源链接

Comments