模型迁移训练时,新增模块应零初始化,确保修改后的ckpt载入后对原任务输入产生相同输出。
核心原则
目标:保留模型原本能力,同时适配新任务。
方法:添加新模块而非修改原有结构,新增权重零初始化。
原理
设原模型参数,新增模块参数初始化为零:
- 对于原任务输入,输出
- 新模块对原任务输出无影响,保证原有能力
实践示例
import torch.nn as nn
# 原有模块
original_layer = nn.Linear(768, 768)
# 新增适配模块,零初始化
new_adapter = nn.Linear(768, 768)
nn.init.zeros_(new_adapter.weight)
nn.init.zeros_(new_adapter.bias)
# 前向传播
def forward(x):
return original_layer(x) + new_adapter(x) # 初始时new_adapter输出为零迁移流程
- 加载预训练checkpoint
- 添加新模块,权重零初始化
- 冻结原有参数(可选)
- 在新任务数据上微调