张芷铭的个人博客

Prefill(预填充)是LLM自回归推理的首个核心阶段,一次性处理完整输入序列并初始化KV Cache,直接决定首token延迟(TTFT)。

核心定位

Prefill是"一次性阅读与缓存构建"阶段,与后续Decode阶段分工明确:

维度Prefill 阶段Decode 阶段
输入完整Prompt序列(长度$n$)上一轮生成的单个token
计算方式全量并行前向传播串行自回归循环
注意力复杂度$O(n^2)$(全量注意力)$O(n)$(仅关注历史缓存)
核心产出初始化KV Cache + 首个token Logits后续每个token的Logits
资源瓶颈GPU计算(FLOPs)+ 长序列显存内存带宽(KV Cache访存)
延迟影响决定TTFT决定TPS

数学原理与执行流程

以标准Transformer Decoder架构为例,流程包含输入编码→全量掩码注意力→FFN→KV Cache缓存→首token Logits输出

输入编码

  • 分词:将输入文本转换为token ID序列 $\boldsymbol{x} = [x_1, …, x_n]$,形状 $[B, n]$
  • 嵌入层:通过嵌入矩阵 $E \in \mathbb{R}^{V \times d_{model}}$ 映射为词嵌入,叠加位置编码:

$$ \boldsymbol{h}0 = \boldsymbol{E}[\boldsymbol{x}] + \boldsymbol{PE}(1..n) \in \mathbb{R}^{B \times n \times d{model}} $$

全量掩码自注意力

对每一层执行多头自注意力,通过下三角掩码防止未来token信息泄露。

  1. Q/K/V投影

$$ \boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V} = \boldsymbol{h}_{l-1} \cdot \boldsymbol{W}Q, \boldsymbol{h}{l-1} \cdot \boldsymbol{W}K, \boldsymbol{h}{l-1} \cdot \boldsymbol{W}_V $$

  1. 注意力得分:缩放后应用下三角掩码 $M$:

$$ \boldsymbol{Attn}_{raw} = \frac{\boldsymbol{Q} \cdot \boldsymbol{K}^\top}{\sqrt{d_k}} + \boldsymbol{M} $$

  1. 归一化与加权

$$ \boldsymbol{Attn}{out} = \text{Softmax}(\boldsymbol{Attn}{raw}) \cdot \boldsymbol{V} $$

KV Cache初始化

缓存每一层的K、V矩阵,形状为 $\mathbb{R}^{B \times h \times n \times d_k}$。Decode阶段只需计算新token的Q,与历史KV Cache做注意力,复杂度从 $O(n^2)$ 降至 $O(n)$。

首token Logits输出

取最后位置的隐藏状态,通过输出线性层生成Logits: $$ \boldsymbol{\text{Logits}} = \boldsymbol{h}L[:, n, :] \cdot \boldsymbol{W}{out} \in \mathbb{R}^{B \times V} $$

关键特性与性能瓶颈

特性说明
计算密集型全量注意力$O(n^2)$复杂度,GPU算力利用率高
显存瓶颈KV Cache占用$O(L \cdot B \cdot h \cdot n \cdot d_k)$,长序列易OOM
TTFT敏感耗时直接决定用户感知的"首token等待时间"

典型优化技术

技术核心原理适用场景
FlashAttention分块计算替代全量矩阵存储,显存$O(n^2)$→$O(n)$长序列Prefill
Chunked Prefill超长Prompt拆分为固定长度块,分批执行超长输入(100k+ token)
动态KV Cache仅缓存有效上下文,减少无效占用对话式场景
模型并行(TP)参数拆分到多GPU,分摊计算与显存超大模型(70B+)
PD分离部署Prefill与Decode部署在不同GPU集群高并发推理服务

代码示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.Wq = nn.Linear(d_model, d_k)
        self.Wk = nn.Linear(d_model, d_k)
        self.Wv = nn.Linear(d_model, d_k)
        self.Wo = nn.Linear(d_k, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, mask, cache_k=None, cache_v=None):
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)

        if cache_k is None:
            cache_k, cache_v = k, v
        else:
            cache_k = torch.cat([cache_k, k], dim=1)
            cache_v = torch.cat([cache_v, v], dim=1)

        attn_scores = torch.matmul(q, cache_k.transpose(-1, -2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        attn_scores = attn_scores + mask
        attn_out = torch.matmul(torch.softmax(attn_scores, dim=-1), cache_v)

        return self.ffn(self.Wo(attn_out)), cache_k, cache_v

# Prefill执行
d_model, d_k, n = 512, 64, 100
x = torch.randn(1, n, d_model)
mask = torch.triu(torch.ones(n, n) * float('-inf'), diagonal=1).unsqueeze(0).unsqueeze(0)
decoder = TransformerDecoderLayer(d_model, d_k)
h, cache_k, cache_v = decoder(x, mask)
logits = nn.Linear(d_model, 10000)(h[:, -1, :])
print(f"KV Cache形状: {cache_k.shape}, Logits形状: {logits.shape}")

总结

Prefill阶段是LLM推理的"启动引擎",通过一次性并行计算构建上下文缓存,将后续生成的复杂度从平方级降至线性级。理解其数学原理、性能瓶颈与优化技术,是实现高效大模型推理的关键。

Comments