Transformer3DModel 采用 RoPE 三维位置编码、FiLM 条件调制和跳过层机制,专为视频生成设计。
核心结构
| 参数 | 作用 |
|---|
num_attention_heads | 多头注意力头数 |
attention_head_dim | 每个头的维度 |
positional_embedding_type | 位置编码类型(支持 rope) |
adaptive_norm | 条件归一化策略 |
初始化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
| # 投影层
self.patchify_proj = nn.Linear(in_channels, inner_dim)
# RoPE 位置编码
if positional_embedding_type == "rope":
self.precompute_freqs_cis()
# Transformer 块堆叠
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim, ...) for _ in range(num_layers)
])
# 输出层(条件归一化)
self.adaln_single = AdaLayerNormSingle(inner_dim)
self.proj_out = nn.Linear(inner_dim, out_channels)
|
RoPE 位置编码
将时空坐标 (t, h, w) 转换为旋转矩阵:
$$\text{RoPE}(x_m, x_n) = \text{Re} \left[ e^{i(m-n)\theta} \cdot x_m x_n^* \right]$$
1
2
3
4
5
6
| def precompute_freqs_cis(self, indices_grid):
fractional_positions = indices_grid / self.positional_embedding_max_pos
freqs = indices * (fractional_positions * 2 - 1)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
return cos_freq, sin_freq
|
FiLM 条件调制
$$y = \gamma \cdot x + \beta$$
1
2
3
| # 时间步条件生成 scale/shift
shift, scale = self.scale_shift_table + embedded_timestep
hidden_states = hidden_states * (1 + scale) + shift
|
前向传播
1
2
3
4
5
6
7
8
9
10
11
12
| # 输入预处理
hidden_states = self.patchify_proj(hidden_states)
freqs_cis = self.precompute_freqs_cis(indices_grid)
# Transformer 块执行
for block in self.transformer_blocks:
hidden_states = block(hidden_states, freqs_cis=freqs_cis)
# 输出调制
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
|
关键技术
| 技术 | 作用 |
|---|
| RoPE | 三维时空位置编码 |
| FiLM | 时间步条件调制 |
| 梯度检查点 | 节省显存 |
| 跳过层机制 | 时空解耦生成 |
- 三维位置编码:RoPE 支持时空坐标
- 条件归一化:注入时间步条件
- 动态计算图:
skip_layer_mask 控制层选择
Comments