张芷铭的个人博客

Simple Projection Layers 的定义与运算

Simple Projection Layers 的定义与运算

基本定义

Simple Projection Layers(简单投影层) 在深度学习领域通常指无复杂非线性变换的线性映射层,其核心运算为: $$ z = W \cdot h + b $$ 其中:

  • $h \in \mathbb{R}^{d_{in}}$:输入特征向量
  • $W \in \mathbb{R}^{d_{out} \times d_{in}}$:可学习权重矩阵
  • $b \in \mathbb{R}^{d_{out}}$:可学习偏置向量(可选)
  • $z \in \mathbb{R}^{d_{out}}$:输出投影结果

具体运算解析

1. 纯线性投影(无偏置)

1
2
3
# PyTorch 实现
projection = nn.Linear(in_features=d_in, out_features=d_out, bias=False)
z = projection(h)  # 等价于 z = h @ W.T

数学形式:
$$z_j = \sum_{i=1}^{d_{in}} W_{ji} h_i \quad \text{其中} \quad W \in \mathbb{R}^{d_{out} \times d_{in}}$$

2. 带偏置的线性投影

# PyTorch 实现(默认含偏置)
projection = nn.Linear(d_in, d_out, bias=True)
z = projection(h)  # 等价于 z = h @ W.T + b

数学形式:
z_j = b_j + \sum_{i=1}^{d_{in}} W_{ji} h_i

3. 特殊变体

降维投影 (d_{out} < d_{in})

常用于特征压缩:
W \in \mathbb{R}^{128 \times 1024}, \quad h \in \mathbb{R}^{1024} \Rightarrow z \in \mathbb{R}^{128}

升维投影 (d_{out} > d_{in})

用于特征扩展:
W \in \mathbb{R}^{512 \times 256}, \quad h \in \mathbb{R}^{256} \Rightarrow z \in \mathbb{R}^{512}

与复杂投影层的对比

特性Simple ProjectionComplex Projection (e.g., MLP)
结构单线性层多层网络 (Linear→Activation→Linear)
非线性能力无(仅仿射变换)有(通过激活函数引入)
参数量d_{in} \times d_{out}显著更多(与隐藏层维度相关)
典型应用场景FitNets式特征对齐对比学习(如CRD)、跨模态蒸馏
计算开销较高

为什么称为"Simple"?

  1. 无隐藏层​:直接进行输入到输出的映射
  2. 无激活函数​:不引入非线性(ReLU/GELU等)
  3. 参数效率高​:仅需学习 W 和(可选的)b

经典应用示例

FitNets中的适配层

# 学生特征维度512 → 匹配教师特征维度2048
adaptor = nn.Linear(512, 2048)  # Simple Projection
loss = F.mse_loss(adaptor(student_feat), teacher_feat.detach())

可视化运算过程

假设输入 $h = [1.0, 2.0, 3.0]$ , $W = \begin{bmatrix}0.1 & 0.2 & 0.3\0.4 & 0.5 & 0.6\end{bmatrix}$, $b = [0.1, 0.2]$

则:

$$z = \begin{bmatrix} 0.1 \times 1.0 + 0.2 \times 2.0 + 0.3 \times 3.0 + 0.1 \ 0.4 \times 1.0 + 0.5 \times 2.0 + 0.6 \times 3.0 + 0.2 \end{bmatrix} = \begin{bmatrix}1.5 \ 3.2\end{bmatrix}$$

总结

Simple Projection Layers的本质是可学习的线性变换矩阵,通过矩阵乘法实现特征空间的维度调整或旋转/缩放,是知识蒸馏中最基础但广泛有效的特征对齐工具。

💬 评论