嵌套表示学习(MRL)通过模块化、层次化结构组织表示空间,提升模型的泛化能力与可解释性。
核心思想
MRL 是一类以模块化结构组织表示空间的学习框架:
- 模块性:表示由独立功能单元组成
- 层次性:模块存在嵌套关系,低层捕获局部特征,高层组合形成抽象语义
- 可组合性:模块按需组合应对不同任务
- 稀疏激活:仅激活相关模块,提升效率
发展历程
| 时间 | 里程碑 |
|---|---|
| 2016 | 神经模块网络(NMN):VQA 任务分解为可组合模块 |
| 1991/2017 | MoE 架构:门控选择专家子网络 |
| 2017 | Capsule Network:部分-整体关系的向量表示 |
| 近年 | Switch Transformer、GLaM 等稀疏大模型 |
数学原理
模块化表示
其中 为模块函数, 为组合函数。
模块间可形成 DAG 结构:
门控机制
硬门控 实现稀疏激活。
损失函数
负载均衡损失防止所有样本只激活少数专家。
关键性质
| 性质 | 说明 |
|---|---|
| 可扩展性 | 新增模块无需重训整个系统 |
| 可解释性 | 模块可对应语义概念 |
| 计算效率 | 稀疏激活减少 FLOPs |
| 泛化能力 | 模块重组泛化到未见任务组合 |
应用场景
| 场景 | 说明 |
|---|---|
| 多任务学习 | 共享底层模块,高层任务专用 |
| 持续学习 | 新任务引入新模块,避免灾难性遗忘 |
| VQA | 问题解析为模块执行序列 |
| 大模型 | Switch Transformer 使用 MoE 架构 |
PyTorch 实现
class MRLMoE(nn.Module):
def __init__(self, num_experts, input_dim, hidden_dim, output_dim, k=2):
super().__init__()
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
) for _ in range(num_experts)
])
self.gate = nn.Linear(input_dim, num_experts)
self.k = k
def forward(self, x):
gate_logits = self.gate(x)
topk_vals, topk_indices = torch.topk(gate_logits, self.k, dim=1)
topk_weights = F.softmax(topk_vals, dim=1)
expert_outputs = torch.stack([e(x) for e in self.experts], dim=1)
batch_idx = torch.arange(x.size(0)).unsqueeze(1)
selected = expert_outputs[batch_idx, topk_indices]
return torch.sum(topk_weights.unsqueeze(-1) * selected, dim=1)实践建议
- 模块粒度:从任务语义出发划分
- 路由策略:硬路由省计算但不可导,软路由可端到端训练
- 避免模块坍缩:多样性正则、专家 dropout
- 评估指标:监控模块激活熵、OOD 泛化性能