简单投影层是无复杂非线性变换的线性映射层,是知识蒸馏的基础特征对齐工具。
定义
| 参数 | 说明 |
|---|---|
| 输入特征向量 | |
| 权重矩阵 | |
| 偏置向量(可选) | |
| 输出投影结果 |
PyTorch 实现
# 无偏置
projection = nn.Linear(d_in, d_out, bias=False)
z = projection(h) # z = h @ W.T
# 带偏置
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h) # z = h @ W.T + b与复杂投影层对比
| 特性 | Simple Projection | MLP |
|---|---|---|
| 结构 | 单线性层 | Linear→Activation→Linear |
| 非线性 | 无 | 有 |
| 参数量 | 显著更多 | |
| 应用 | FitNets 特征对齐 | 对比学习、跨模态蒸馏 |
应用示例
# 学生 512 维 → 匹配教师 2048 维
adaptor = nn.Linear(512, 2048)
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())