Transformer3DModel 采用 RoPE 三维位置编码、FiLM 条件调制和跳过层机制,专为视频生成设计。
核心结构
| 参数 | 作用 |
|---|---|
num_attention_heads | 多头注意力头数 |
attention_head_dim | 每个头的维度 |
positional_embedding_type | 位置编码类型(支持 rope) |
adaptive_norm | 条件归一化策略 |
初始化
# 投影层
self.patchify_proj = nn.Linear(in_channels, inner_dim)
# RoPE 位置编码
if positional_embedding_type == "rope":
self.precompute_freqs_cis()
# Transformer 块堆叠
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim, ...) for _ in range(num_layers)
])
# 输出层(条件归一化)
self.adaln_single = AdaLayerNormSingle(inner_dim)
self.proj_out = nn.Linear(inner_dim, out_channels)RoPE 位置编码
将时空坐标 (t, h, w) 转换为旋转矩阵:
def precompute_freqs_cis(self, indices_grid):
fractional_positions = indices_grid / self.positional_embedding_max_pos
freqs = indices * (fractional_positions * 2 - 1)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
return cos_freq, sin_freqFiLM 条件调制
# 时间步条件生成 scale/shift
shift, scale = self.scale_shift_table + embedded_timestep
hidden_states = hidden_states * (1 + scale) + shift前向传播
# 输入预处理
hidden_states = self.patchify_proj(hidden_states)
freqs_cis = self.precompute_freqs_cis(indices_grid)
# Transformer 块执行
for block in self.transformer_blocks:
hidden_states = block(hidden_states, freqs_cis=freqs_cis)
# 输出调制
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)关键技术
| 技术 | 作用 |
|---|---|
| RoPE | 三维时空位置编码 |
| FiLM | 时间步条件调制 |
| 梯度检查点 | 节省显存 |
| 跳过层机制 | 时空解耦生成 |
与经典 Transformer 的差异
- 三维位置编码:RoPE 支持时空坐标
- 条件归一化:注入时间步条件
- 动态计算图:
skip_layer_mask控制层选择