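import torch

# Traditional: N small GEMMs (inefficient)
outputs = [torch.matmul(x, e.weight) for e in experts]

# Group GEMM: one large GEMM (efficient)
# e.weight is assumed to have shape (d_in, d_out), as x @ e.weight implies
group_weights = torch.cat([e.weight for e in experts], dim=1)
group_out = torch.matmul(x, group_weights)
outputs = group_out.split(d_out, dim=-1)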
Kernel libraries: DeepGEMM, CUTLASS GroupGEMM, Megatron-LM
Configuration: --moe-grouped-gemm; set block_m and block_n per modality
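The equivalence behind the snippet above can be checked with a small self-contained sketch. The sizes and the random weights are illustrative assumptions, and the weights are stored as (d_in, d_out). It covers only the case shown in the snippet, where every expert multiplies the same input x; in a routed MoE layer each expert receives a different subset of tokens, which is what the dedicated grouped-GEMM kernels handle.

```python
import torch

torch.manual_seed(0)
num_experts, d_in, d_out = 4, 16, 8   # illustrative sizes (assumption)
x = torch.randn(32, d_in)

# Assumption: each expert weight is stored as (d_in, d_out).
weights = [torch.randn(d_in, d_out) for _ in range(num_experts)]

# Baseline: one small matmul per expert.
small = [x @ w for w in weights]

# Batched: concatenate along the output dimension, multiply once, split back.
big = x @ torch.cat(weights, dim=1)    # shape (32, num_experts * d_out)
batched = big.split(d_out, dim=-1)     # num_experts tensors of shape (32, d_out)

# Largest element-wise difference between the two approaches.
max_err = max((a - b).abs().max().item() for a, b in zip(small, batched))
```

Both paths compute the same products, so max_err should only reflect floating-point rounding.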
Multi-Stream Activation Offload
Core principles
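| Technique | Role |
| --- | --- |
| Gradient Checkpoint | Saves only the checkpoints and recomputes activations in the backward pass; reduces GPU memory by 30%-60% |
| Activation Offload | Offloads activations to CPU memory, freeing GPU memory |
| Multi-Stream | Runs a compute stream and a transfer stream in parallel to hide the time cost of offloading |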
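The compute-stream plus transfer-stream idea can be sketched with PyTorch's CUDA stream API. This is a minimal illustration under stated assumptions, not the implementation behind the numbers in this section: `offload_async` is a hypothetical helper, the tensor sizes are arbitrary, and on a machine without a GPU the helper falls back to a plain synchronous copy so the script still runs.

```python
import torch

def offload_async(t, copy_stream=None):
    """Copy `t` to CPU; on CUDA the copy runs on `copy_stream` (hypothetical helper)."""
    if copy_stream is None:
        # No GPU available: plain synchronous copy, nothing to overlap.
        return t.detach().clone(), None
    cpu_buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    # The transfer stream must not start before the compute stream has produced `t`.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        # Asynchronous device-to-host copy into pinned memory.
        cpu_buf.copy_(t, non_blocking=True)
    done = torch.cuda.Event()
    done.record(copy_stream)
    # Tell the caching allocator that `t` is also in use on copy_stream.
    t.record_stream(copy_stream)
    return cpu_buf, done

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
copy_stream = torch.cuda.Stream() if use_cuda else None

act = torch.randn(512, 512, device=device)
cpu_copy, done = offload_async(act, copy_stream)
next_out = act @ act      # the compute stream keeps working while the copy is in flight
if done is not None:
    done.synchronize()    # wait for the transfer before reading cpu_copy
copy_ok = torch.equal(cpu_copy, act.cpu())
```

Pinned host memory is what allows the `non_blocking=True` copy to run asynchronously; with pageable memory the copy would effectively serialize with the host.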
Benefit: end-to-end (E2E) throughput +33%
Implementation
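import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # attn, ffn, adaln1 and adaln2 are assumed to be submodules created in __init__ (not shown)
    def forward(self, x):
        # Keep only the block input; recompute the inner activations during backward
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attn(self.adaln1(x))
        x = x + self.ffn(self.adaln2(x))
        return x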
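To see that checkpointing changes memory behaviour but not the gradients, the block above can be exercised with stand-in submodules. Everything concrete here is an assumption for illustration: `DemoBlock` is a hypothetical class, `nn.LayerNorm` stands in for `adaln1`/`adaln2`, a single `nn.Linear` stands in for `attn`, and the `use_checkpoint` switch exists only to compare the two paths.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DemoBlock(nn.Module):
    # Hypothetical stand-ins: LayerNorm for adaln1/adaln2, Linear for attn, small MLP for ffn.
    def __init__(self, dim, use_checkpoint):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.adaln1, self.adaln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def _forward(self, x):
        x = x + self.attn(self.adaln1(x))
        x = x + self.ffn(self.adaln2(x))
        return x

    def forward(self, x):
        if self.use_checkpoint:
            # Inner activations are dropped after forward and recomputed in backward.
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

torch.manual_seed(0)
block = DemoBlock(32, use_checkpoint=True)
x = torch.randn(4, 10, 32, requires_grad=True)

# Gradient with checkpointing enabled.
block(x).sum().backward()
grad_ckpt = x.grad.clone()

# Same weights and input, checkpointing disabled.
x.grad = None
block.zero_grad()
block.use_checkpoint = False
block(x).sum().backward()
grads_match = torch.allclose(grad_ckpt, x.grad, atol=1e-6)
```

The trade is one extra forward pass through the block during backward in exchange for not holding its intermediate activations in memory.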