Prefill(预填充)是LLM自回归推理的首个核心阶段,一次性处理完整输入序列并初始化KV Cache,直接决定首token延迟(TTFT)。

核心定位

Prefill是”一次性阅读与缓存构建”阶段,与后续Decode阶段分工明确:

维度Prefill 阶段Decode 阶段
输入完整Prompt序列(长度上一轮生成的单个token
计算方式全量并行前向传播串行自回归循环
注意力复杂度(全量注意力)(仅关注历史缓存)
核心产出初始化KV Cache + 首个token Logits后续每个token的Logits
资源瓶颈GPU计算(FLOPs)+ 长序列显存内存带宽(KV Cache访存)
延迟影响决定TTFT决定TPS

数学原理与执行流程

以标准Transformer Decoder架构为例,流程包含输入编码→全量掩码注意力→FFN→KV Cache缓存→首token Logits输出

输入编码

  • 分词:将输入文本转换为token ID序列 ,形状
  • 嵌入层:通过嵌入矩阵 映射为词嵌入,叠加位置编码:

全量掩码自注意力

对每一层执行多头自注意力,通过下三角掩码防止未来token信息泄露。

  1. Q/K/V投影
  1. 注意力得分:缩放后应用下三角掩码
  1. 归一化与加权

KV Cache初始化

缓存每一层的K、V矩阵,形状为 。Decode阶段只需计算新token的Q,与历史KV Cache做注意力,复杂度从 降至

首token Logits输出

取最后位置的隐藏状态,通过输出线性层生成Logits:

关键特性与性能瓶颈

特性说明
计算密集型全量注意力复杂度,GPU算力利用率高
显存瓶颈KV Cache占用,长序列易OOM
TTFT敏感耗时直接决定用户感知的”首token等待时间”

典型优化技术

技术核心原理适用场景
FlashAttention分块计算替代全量矩阵存储,显存长序列Prefill
Chunked Prefill超长Prompt拆分为固定长度块,分批执行超长输入(100k+ token)
动态KV Cache仅缓存有效上下文,减少无效占用对话式场景
模型并行(TP)参数拆分到多GPU,分摊计算与显存超大模型(70B+)
PD分离部署Prefill与Decode部署在不同GPU集群高并发推理服务

代码示例

import torch
import torch.nn as nn
 
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.Wq = nn.Linear(d_model, d_k)
        self.Wk = nn.Linear(d_model, d_k)
        self.Wv = nn.Linear(d_model, d_k)
        self.Wo = nn.Linear(d_k, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )
 
    def forward(self, x, mask, cache_k=None, cache_v=None):
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
 
        if cache_k is None:
            cache_k, cache_v = k, v
        else:
            cache_k = torch.cat([cache_k, k], dim=1)
            cache_v = torch.cat([cache_v, v], dim=1)
 
        attn_scores = torch.matmul(q, cache_k.transpose(-1, -2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        attn_scores = attn_scores + mask
        attn_out = torch.matmul(torch.softmax(attn_scores, dim=-1), cache_v)
 
        return self.ffn(self.Wo(attn_out)), cache_k, cache_v
 
# Prefill执行
d_model, d_k, n = 512, 64, 100
x = torch.randn(1, n, d_model)
mask = torch.triu(torch.ones(n, n) * float('-inf'), diagonal=1).unsqueeze(0).unsqueeze(0)
decoder = TransformerDecoderLayer(d_model, d_k)
h, cache_k, cache_v = decoder(x, mask)
logits = nn.Linear(d_model, 10000)(h[:, -1, :])
print(f"KV Cache形状: {cache_k.shape}, Logits形状: {logits.shape}")

总结

Prefill阶段是LLM推理的”启动引擎”,通过一次性并行计算构建上下文缓存,将后续生成的复杂度从平方级降至线性级。理解其数学原理、性能瓶颈与优化技术,是实现高效大模型推理的关键。