Simple Projection Layers 的定义与运算
基本定义
Simple Projection Layers(简单投影层) 在深度学习领域通常指无复杂非线性变换的线性映射层,其核心运算为: $$ z = W \cdot h + b $$ 其中:
- $h \in \mathbb{R}^{d_{in}}$:输入特征向量
- $W \in \mathbb{R}^{d_{out} \times d_{in}}$:可学习权重矩阵
- $b \in \mathbb{R}^{d_{out}}$:可学习偏置向量(可选)
- $z \in \mathbb{R}^{d_{out}}$:输出投影结果
具体运算解析
1. 纯线性投影(无偏置)
| |
数学形式:
$$z_j = \sum_{i=1}^{d_{in}} W_{ji} h_i \quad \text{其中} \quad W \in \mathbb{R}^{d_{out} \times d_{in}}$$
2. 带偏置的线性投影
# PyTorch 实现(默认含偏置)
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h) # 等价于 z = h @ W.T + b
数学形式:z_j = b_j + \sum_{i=1}^{d_{in}} W_{ji} h_i
3. 特殊变体
降维投影 (d_{out} < d_{in})
常用于特征压缩:W \in \mathbb{R}^{128 \times 1024}, \quad h \in \mathbb{R}^{1024} \Rightarrow z \in \mathbb{R}^{128}
升维投影 (d_{out} > d_{in})
用于特征扩展:W \in \mathbb{R}^{512 \times 256}, \quad h \in \mathbb{R}^{256} \Rightarrow z \in \mathbb{R}^{512}
与复杂投影层的对比
| 特性 | Simple Projection | Complex Projection (e.g., MLP) |
|---|---|---|
| 结构 | 单线性层 | 多层网络 (Linear→Activation→Linear) |
| 非线性能力 | 无(仅仿射变换) | 有(通过激活函数引入) |
| 参数量 | d_{in} \times d_{out} | 显著更多(与隐藏层维度相关) |
| 典型应用场景 | FitNets式特征对齐 | 对比学习(如CRD)、跨模态蒸馏 |
| 计算开销 | 低 | 较高 |
为什么称为"Simple"?
- 无隐藏层:直接进行输入到输出的映射
- 无激活函数:不引入非线性(ReLU/GELU等)
- 参数效率高:仅需学习
W和(可选的)b
经典应用示例
FitNets中的适配层
# 学生特征维度512 → 匹配教师特征维度2048
adaptor = nn.Linear(512, 2048) # Simple Projection
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())
可视化运算过程
假设输入 $h = [1.0, 2.0, 3.0]$ , $W = \begin{bmatrix}0.1 & 0.2 & 0.3\0.4 & 0.5 & 0.6\end{bmatrix}$, $b = [0.1, 0.2]$
则:
$$z = \begin{bmatrix} 0.1 \times 1.0 + 0.2 \times 2.0 + 0.3 \times 3.0 + 0.1 \ 0.4 \times 1.0 + 0.5 \times 2.0 + 0.6 \times 3.0 + 0.2 \end{bmatrix} = \begin{bmatrix}1.5 \ 3.2\end{bmatrix}$$
总结
Simple Projection Layers的本质是可学习的线性变换矩阵,通过矩阵乘法实现特征空间的维度调整或旋转/缩放,是知识蒸馏中最基础但广泛有效的特征对齐工具。
张芷铭的个人博客
💬 评论