张芷铭的个人博客

Transformer3DModel 采用 RoPE 三维位置编码、FiLM 条件调制和跳过层机制,专为视频生成设计。

核心结构

参数作用
num_attention_heads多头注意力头数
attention_head_dim每个头的维度
positional_embedding_type位置编码类型(支持 rope
adaptive_norm条件归一化策略

初始化

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# 投影层
self.patchify_proj = nn.Linear(in_channels, inner_dim)

# RoPE 位置编码
if positional_embedding_type == "rope":
    self.precompute_freqs_cis()

# Transformer 块堆叠
self.transformer_blocks = nn.ModuleList([
    BasicTransformerBlock(inner_dim, ...) for _ in range(num_layers)
])

# 输出层(条件归一化)
self.adaln_single = AdaLayerNormSingle(inner_dim)
self.proj_out = nn.Linear(inner_dim, out_channels)

RoPE 位置编码

将时空坐标 (t, h, w) 转换为旋转矩阵:

$$\text{RoPE}(x_m, x_n) = \text{Re} \left[ e^{i(m-n)\theta} \cdot x_m x_n^* \right]$$

1
2
3
4
5
6
def precompute_freqs_cis(self, indices_grid):
    fractional_positions = indices_grid / self.positional_embedding_max_pos
    freqs = indices * (fractional_positions * 2 - 1)
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    return cos_freq, sin_freq

FiLM 条件调制

$$y = \gamma \cdot x + \beta$$

1
2
3
# 时间步条件生成 scale/shift
shift, scale = self.scale_shift_table + embedded_timestep
hidden_states = hidden_states * (1 + scale) + shift

前向传播

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
# 输入预处理
hidden_states = self.patchify_proj(hidden_states)
freqs_cis = self.precompute_freqs_cis(indices_grid)

# Transformer 块执行
for block in self.transformer_blocks:
    hidden_states = block(hidden_states, freqs_cis=freqs_cis)

# 输出调制
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)

关键技术

技术作用
RoPE三维时空位置编码
FiLM时间步条件调制
梯度检查点节省显存
跳过层机制时空解耦生成

与经典 Transformer 的差异

  1. 三维位置编码:RoPE 支持时空坐标
  2. 条件归一化:注入时间步条件
  3. 动态计算图skip_layer_mask 控制层选择

Comments