简单投影层是无复杂非线性变换的线性映射层,是知识蒸馏的基础特征对齐工具。
定义
$$z = W \cdot h + b$$
| 参数 | 说明 |
|---|
| $h \in \mathbb{R}^{d_{in}}$ | 输入特征向量 |
| $W \in \mathbb{R}^{d_{out} \times d_{in}}$ | 权重矩阵 |
| $b \in \mathbb{R}^{d_{out}}$ | 偏置向量(可选) |
| $z \in \mathbb{R}^{d_{out}}$ | 输出投影结果 |
PyTorch 实现
1
2
3
4
5
6
7
| # 无偏置
projection = nn.Linear(d_in, d_out, bias=False)
z = projection(h) # z = h @ W.T
# 带偏置
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h) # z = h @ W.T + b
|
与复杂投影层对比
| 特性 | Simple Projection | MLP |
|---|
| 结构 | 单线性层 | Linear→Activation→Linear |
| 非线性 | 无 | 有 |
| 参数量 | $d_{in} \times d_{out}$ | 显著更多 |
| 应用 | FitNets 特征对齐 | 对比学习、跨模态蒸馏 |
应用示例
1
2
3
| # 学生 512 维 → 匹配教师 2048 维
adaptor = nn.Linear(512, 2048)
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())
|
Comments