张芷铭的个人博客

模型迁移训练时,新增模块应零初始化,确保修改后的ckpt载入后对原任务输入产生相同输出。

核心原则

目标:保留模型原本能力,同时适配新任务。

方法:添加新模块而非修改原有结构,新增权重零初始化。

原理

设原模型参数$\theta$,新增模块参数$\theta_{new}$初始化为零:

  • 对于原任务输入$x$,输出$f(x; \theta) + f_{new}(x; \theta_{new}=0) = f(x; \theta)$
  • 新模块对原任务输出无影响,保证原有能力

实践示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
import torch.nn as nn

# 原有模块
original_layer = nn.Linear(768, 768)

# 新增适配模块,零初始化
new_adapter = nn.Linear(768, 768)
nn.init.zeros_(new_adapter.weight)
nn.init.zeros_(new_adapter.bias)

# 前向传播
def forward(x):
    return original_layer(x) + new_adapter(x)  # 初始时new_adapter输出为零

迁移流程

  1. 加载预训练checkpoint
  2. 添加新模块,权重零初始化
  3. 冻结原有参数(可选)
  4. 在新任务数据上微调

Comments