训练与推理优化模块 Wan针对大规模视频生成的计算与内存瓶颈,设计了并行策略、内存优化、推理加速三大模块,支撑14B参数模型的训练与部署。
3.1 并行训练策略
针对DiT模块的高计算需求,采用“2D上下文并行(CP)+全分片数据并行(FSDP)+数据并行(DP)”的混合并行架构:
- 2D上下文并行(Context Parallel):结合Ulysses和环形注意力(Ring Attention),将序列长度(s)和隐藏层维度(h)分块到不同GPU,减少注意力计算的内存占用。例如,在128 GPU配置下,Ulysses=8、Ring=2,CP规模16,通信开销从10%降至1%以下;
- FSDP:将模型参数、梯度、优化器状态分片存储,解决单GPU内存不足问题,与训练阶段保持一致;
- 模块间策略切换:VAE和文本编码器采用DP,DiT采用“DP+CP”,通过数据广播确保模块间输入一致性,减少冗余计算。
3.2 内存优化
- 激活卸载(Activation Offloading):将Transformer层的激活值卸载到CPU,与计算重叠,相比梯度检查点(GC)更能平衡内存与速度;
- 混合GC+卸载:对内存-计算比高的层优先使用GC,对长序列场景结合CPU卸载,避免内存溢出;
- 集群可靠性:基于阿里云的智能调度、故障检测与自愈能力,确保训练过程中节点故障时任务自动恢复。
3.3 推理加速技术
为降低视频生成延迟,Wan集成了扩散缓存、量化、并行推理三大技术:
- 扩散缓存(Diffusion Cache):利用采样步骤间的注意力相似性和CFG(Classifier-Free Guidance)相似性,每隔若干步缓存注意力结果和无条件生成结果,复用缓存减少计算,使14B模型推理速度提升1.62×;
- 量化优化:
- FP8 GEMM:对Transformer中的GEMM操作采用FP8精度,权重按张量量化,激活按token量化,性能比BF16提升2倍,DiT模块速度提升1.13×;
- 8位FlashAttention:混合INT8(用于QK^T计算)和FP8(用于PV计算),结合FP32跨块累积避免溢出,在NVIDIA H20 GPU上实现95%的MFU(模型浮点利用率),推理效率提升1.27×;
- 并行推理:沿用训练阶段的2D CP+FSDP策略,实现近线性的多GPU加速。
💬 评论