张芷铭的个人博客

蒸馏投影层解决师生模型特征空间不匹配问题,是知识蒸馏的关键组件。

定义

投影层是附加在中间特征层或输出层的小型神经网络模块,核心目的:

  1. 特征空间对齐:将师生特征映射到共享空间
  2. 距离度量优化:使投影后特征可准确度量知识差异

发展历程

方法年份贡献
Hinton KD2015输出层 logits 软化分布
FitNets2015首次中间层蒸馏,线性适配层
CRD2020对比学习引入蒸馏,MLP 投影头

核心原理

投影层是非线性变换:

$$z = g(h) = \sigma(W \cdot h + b)$$

蒸馏损失在投影空间计算:

$$L_{distill} = \mathcal{D}(z_s, z_t)$$

适用场景

场景说明
中间层蒸馏维度、语义差距过大时必需
对比蒸馏CRD、SSKD 的核心组件
跨架构蒸馏CNN 教师 → Transformer 学生
多模态蒸馏跨模态特征对齐

实践经验

结构选择

  • 简单对齐:单线性层
  • 复杂对齐:Linear→ReLU→Linear

维度设计:教师投影维度 $d_p^t$ 与学生 $d_p^s$ 必须相等。

归一化:余弦相似度/对比损失需 L2 归一化。

推理移除:投影层仅训练时使用,部署时移除。

代码示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            self.net = nn.Linear(input_dim, output_dim)
        else:
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, output_dim)
            )

    def forward(self, x):
        return self.net(x)

# 损失计算
l2_loss = F.mse_loss(s_proj, t_proj.detach())
t_proj_norm = F.normalize(t_proj, p=2, dim=1)
s_proj_norm = F.normalize(s_proj, p=2, dim=1)
cos_loss = 1 - (t_proj_norm * s_proj_norm).sum(dim=1).mean()

最新进展

  • 稀疏投影头:减轻遗忘,提升泛化
  • 轻量化设计:深度可分离卷积、低秩分解
  • 信息瓶颈视角:过滤噪声,提取关键知识

Comments