FlashAttention 通过分块计算和算子融合,将注意力内存复杂度从 降至 ,实现 2-4 倍加速。

核心原理

内存层次优化

存储层级容量带宽
HBM40-80GB1.5-2TB/s
SRAM~192KB19TB/s

核心思想:将计算限制在 SRAM 中,减少 HBM 访问。

分块计算

  • Q 分为 块,K/V 分为
  • 每块大小
  • 局部注意力计算在 SRAM 中完成

在线 Softmax

单遍扫描更新全局统计量:

m_{\text{new}} &= \max(m_{\text{old}}, m_{ij}) \\ l_{\text{new}} &= e^{m_{\text{old}} - m_{\text{new}}} \cdot l_{\text{old}} + e^{m_{ij} - m_{\text{new}}} \cdot l_{ij} \end{aligned}$$ ## 版本演进 | 版本 | 改进 | |------|------| | v1 | 基础分块,2-4 倍加速 | | v2 | 优化循环顺序,效率达 GEMM 的 50-73% | | v3 | Hopper 架构优化,支持 MQA | ## 性能对比 | 特性 | 标准 Attention | FlashAttention | |------|----------------|----------------| | 显存占用 | $O(N^2)$ | $O(N)$ | | 支持序列长度 | <8k | >32k | | 计算效率 | 基准 | 提升 2-4 倍 | ## 使用方式 ```python from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3-8B", use_flash_attention_2=True ) ``` 或使用 vLLM 自动启用: ```python from vllm import LLM llm = LLM(model="meta-llama/Llama-3-8B") ``` ## 应用场景 - LLM 训练(LLaMA、ChatGLM) - 长上下文处理(32k+ tokens) - 高并发推理服务 - 多模态任务 ## 资源链接 - [论文](https://arxiv.org/abs/2205.14135) - [官方实现](https://github.com/Dao-AILab/flash-attention)