Prefill(预填充)是LLM自回归推理的首个核心阶段,一次性处理完整输入序列并初始化KV Cache,直接决定首token延迟(TTFT)。
核心定位
Prefill是”一次性阅读与缓存构建”阶段,与后续Decode阶段分工明确:
| 维度 | Prefill 阶段 | Decode 阶段 |
|---|---|---|
| 输入 | 完整Prompt序列(长度) | 上一轮生成的单个token |
| 计算方式 | 全量并行前向传播 | 串行自回归循环 |
| 注意力复杂度 | (全量注意力) | (仅关注历史缓存) |
| 核心产出 | 初始化KV Cache + 首个token Logits | 后续每个token的Logits |
| 资源瓶颈 | GPU计算(FLOPs)+ 长序列显存 | 内存带宽(KV Cache访存) |
| 延迟影响 | 决定TTFT | 决定TPS |
数学原理与执行流程
以标准Transformer Decoder架构为例,流程包含输入编码→全量掩码注意力→FFN→KV Cache缓存→首token Logits输出。
输入编码
- 分词:将输入文本转换为token ID序列 ,形状
- 嵌入层:通过嵌入矩阵 映射为词嵌入,叠加位置编码:
全量掩码自注意力
对每一层执行多头自注意力,通过下三角掩码防止未来token信息泄露。
- Q/K/V投影:
- 注意力得分:缩放后应用下三角掩码 :
- 归一化与加权:
KV Cache初始化
缓存每一层的K、V矩阵,形状为 。Decode阶段只需计算新token的Q,与历史KV Cache做注意力,复杂度从 降至 。
首token Logits输出
取最后位置的隐藏状态,通过输出线性层生成Logits:
关键特性与性能瓶颈
| 特性 | 说明 |
|---|---|
| 计算密集型 | 全量注意力复杂度,GPU算力利用率高 |
| 显存瓶颈 | KV Cache占用,长序列易OOM |
| TTFT敏感 | 耗时直接决定用户感知的”首token等待时间” |
典型优化技术
| 技术 | 核心原理 | 适用场景 |
|---|---|---|
| FlashAttention | 分块计算替代全量矩阵存储,显存→ | 长序列Prefill |
| Chunked Prefill | 超长Prompt拆分为固定长度块,分批执行 | 超长输入(100k+ token) |
| 动态KV Cache | 仅缓存有效上下文,减少无效占用 | 对话式场景 |
| 模型并行(TP) | 参数拆分到多GPU,分摊计算与显存 | 超大模型(70B+) |
| PD分离部署 | Prefill与Decode部署在不同GPU集群 | 高并发推理服务 |
代码示例
import torch
import torch.nn as nn
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, d_k):
super().__init__()
self.Wq = nn.Linear(d_model, d_k)
self.Wk = nn.Linear(d_model, d_k)
self.Wv = nn.Linear(d_model, d_k)
self.Wo = nn.Linear(d_k, d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model)
)
def forward(self, x, mask, cache_k=None, cache_v=None):
q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
if cache_k is None:
cache_k, cache_v = k, v
else:
cache_k = torch.cat([cache_k, k], dim=1)
cache_v = torch.cat([cache_v, v], dim=1)
attn_scores = torch.matmul(q, cache_k.transpose(-1, -2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
attn_scores = attn_scores + mask
attn_out = torch.matmul(torch.softmax(attn_scores, dim=-1), cache_v)
return self.ffn(self.Wo(attn_out)), cache_k, cache_v
# Prefill执行
d_model, d_k, n = 512, 64, 100
x = torch.randn(1, n, d_model)
mask = torch.triu(torch.ones(n, n) * float('-inf'), diagonal=1).unsqueeze(0).unsqueeze(0)
decoder = TransformerDecoderLayer(d_model, d_k)
h, cache_k, cache_v = decoder(x, mask)
logits = nn.Linear(d_model, 10000)(h[:, -1, :])
print(f"KV Cache形状: {cache_k.shape}, Logits形状: {logits.shape}")总结
Prefill阶段是LLM推理的”启动引擎”,通过一次性并行计算构建上下文缓存,将后续生成的复杂度从平方级降至线性级。理解其数学原理、性能瓶颈与优化技术,是实现高效大模型推理的关键。