张芷铭的个人博客

简单投影层是无复杂非线性变换的线性映射层,是知识蒸馏的基础特征对齐工具。

定义

$$z = W \cdot h + b$$

参数说明
$h \in \mathbb{R}^{d_{in}}$输入特征向量
$W \in \mathbb{R}^{d_{out} \times d_{in}}$权重矩阵
$b \in \mathbb{R}^{d_{out}}$偏置向量(可选)
$z \in \mathbb{R}^{d_{out}}$输出投影结果

PyTorch 实现

1
2
3
4
5
6
7
# 无偏置
projection = nn.Linear(d_in, d_out, bias=False)
z = projection(h)  # z = h @ W.T

# 带偏置
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h)  # z = h @ W.T + b

与复杂投影层对比

特性Simple ProjectionMLP
结构单线性层Linear→Activation→Linear
非线性
参数量$d_{in} \times d_{out}$显著更多
应用FitNets 特征对齐对比学习、跨模态蒸馏

应用示例

1
2
3
# 学生 512 维 → 匹配教师 2048 维
adaptor = nn.Linear(512, 2048)
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())

Comments