简单投影层是无复杂非线性变换的线性映射层,是知识蒸馏的基础特征对齐工具。

定义

参数说明
输入特征向量
权重矩阵
偏置向量(可选)
输出投影结果

PyTorch 实现

# 无偏置
projection = nn.Linear(d_in, d_out, bias=False)
z = projection(h)  # z = h @ W.T
 
# 带偏置
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h)  # z = h @ W.T + b

与复杂投影层对比

特性Simple ProjectionMLP
结构单线性层Linear→Activation→Linear
非线性
参数量显著更多
应用FitNets 特征对齐对比学习、跨模态蒸馏

应用示例

# 学生 512 维 → 匹配教师 2048 维
adaptor = nn.Linear(512, 2048)
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())