张芷铭的个人博客

大模型训练/推理的端到端吞吐优化三大核心技术:算子融合、多模态GEMM调度、显存/激活管理。

Adaln Varlen Fuse

核心原理

  • Adaln:LayerNorm + 可学习缩放/偏移参数,适配多模态输入分布差异
  • Varlen:支持变长序列,避免padding计算浪费
  • Fuse:Adaln + 激活函数 + 残差连接 + 变长mask合并为单个CUDA内核

收益:E2E吞吐+10%,减少访存瓶颈

实现示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
class AdaLayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(1, 1, dim))
        self.shift = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.scale + self.shift

工程配置:使用torch.compile、FlashAttention-2、Triton自动融合

Modality-Aware Group GEMM

核心原理

将多个小GEMM打包为一组大GEMM并行执行,按模态类型分组使用专属分块策略。

收益:E2E吞吐+20%

实现示例

1
2
3
4
5
6
7
# 传统:N个专家 → N次小GEMM(低效)
outputs = [torch.matmul(x, e.weight) for e in experts]

# Group GEMM:合并为1次大GEMM(高效)
group_weights = torch.cat([e.weight for e in experts], dim=1)
group_out = torch.matmul(x, group_weights)
outputs = group_out.split(d_out, dim=-1)

工程配置:开启--moe-grouped-gemm(Megatron)或使用DeepGEMM、CUTLASS GroupGEMM

Multi-Stream Activation Offload

核心原理

  • Gradient Checkpoint:只保存检查点,反向时重计算激活,用计算换显存
  • Activation Offload:激活值卸载到CPU内存,释放GPU显存
  • Multi-Stream:计算流与传输流并行,隐藏Offload开销

收益:E2E吞吐+33%,Batch Size翻倍

实现示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        return checkpoint(self._forward, x, use_reentrant=False)

# 多流重叠
stream_compute = torch.cuda.Stream()
stream_transfer = torch.cuda.Stream()

with torch.cuda.stream(stream_compute):
    out = self.block(x)
with torch.cuda.stream(stream_transfer):
    checkpoint_cpu = out.to('cpu', non_blocking=True)

技术对比

技术优化点收益适用场景
Adaln Varlen Fuse算子融合+变长计算+10%多模态Transformer
Modality-Aware Group GEMM分组GEMM+模态调度+20%MoE、多分支模型
Multi-Stream Activation Offload重计算+显存卸载+33%大模型、显存受限

落地优先级

  1. 先开Activation Offload:见效最快(+33%)
  2. 再做Group GEMM:多模态/MoE必选(+20%)
  3. 最后优化Adaln Varlen Fuse:极致性能调优(+10%)

工程避坑

技术注意事项
Adaln Fuse融合后校准数值稳定性
Group GEMM分组过大可能显存峰值飙升
Activation Offload多流需设置non_blocking=True

Comments