蒸馏投影层解决师生模型特征空间不匹配问题,是知识蒸馏的关键组件。
定义
投影层是附加在中间特征层或输出层的小型神经网络模块,核心目的:
- 特征空间对齐:将师生特征映射到共享空间
- 距离度量优化:使投影后特征可准确度量知识差异
发展历程
| 方法 | 年份 | 贡献 |
|---|---|---|
| Hinton KD | 2015 | 输出层 logits 软化分布 |
| FitNets | 2015 | 首次中间层蒸馏,线性适配层 |
| CRD | 2020 | 对比学习引入蒸馏,MLP 投影头 |
核心原理
投影层是非线性变换:
$$z = g(h) = \sigma(W \cdot h + b)$$
蒸馏损失在投影空间计算:
$$L_{distill} = \mathcal{D}(z_s, z_t)$$
适用场景
| 场景 | 说明 |
|---|---|
| 中间层蒸馏 | 维度、语义差距过大时必需 |
| 对比蒸馏 | CRD、SSKD 的核心组件 |
| 跨架构蒸馏 | CNN 教师 → Transformer 学生 |
| 多模态蒸馏 | 跨模态特征对齐 |
实践经验
结构选择:
- 简单对齐:单线性层
- 复杂对齐:
Linear→ReLU→Linear
维度设计:教师投影维度 $d_p^t$ 与学生 $d_p^s$ 必须相等。
归一化:余弦相似度/对比损失需 L2 归一化。
推理移除:投影层仅训练时使用,部署时移除。
代码示例
| |
最新进展
- 稀疏投影头:减轻遗忘,提升泛化
- 轻量化设计:深度可分离卷积、低秩分解
- 信息瓶颈视角:过滤噪声,提取关键知识
张芷铭的个人博客
Comments