张芷铭的个人博客

嵌套表示学习(MRL)通过模块化、层次化结构组织表示空间,提升模型的泛化能力与可解释性。

核心思想

MRL 是一类以模块化结构组织表示空间的学习框架:

  • 模块性:表示由独立功能单元组成
  • 层次性:模块存在嵌套关系,低层捕获局部特征,高层组合形成抽象语义
  • 可组合性:模块按需组合应对不同任务
  • 稀疏激活:仅激活相关模块,提升效率

发展历程

时间里程碑
2016神经模块网络(NMN):VQA 任务分解为可组合模块
1991/2017MoE 架构:门控选择专家子网络
2017Capsule Network:部分-整体关系的向量表示
近年Switch Transformer、GLaM 等稀疏大模型

数学原理

模块化表示

$$f(x) = g(h_1(x), h_2(x), \dots, h_K(x))$$

其中 ${h_k}$ 为模块函数,$g$ 为组合函数。

模块间可形成 DAG 结构:

$$h_k(x) = \phi_k({h_j(x)}_{j \in \text{Pa}(k)})$$

门控机制

$$f(x) = \sum_{k=1}^K a_k(x) \cdot h_k(x)$$

硬门控 $a_k(x) \in {0,1}$ 实现稀疏激活。

损失函数

$$\mathcal{L}{\text{total}} = \mathcal{L}{\text{task}} + \lambda \cdot \mathcal{L}_{\text{balance}}$$

负载均衡损失防止所有样本只激活少数专家。

关键性质

性质说明
可扩展性新增模块无需重训整个系统
可解释性模块可对应语义概念
计算效率稀疏激活减少 FLOPs
泛化能力模块重组泛化到未见任务组合

应用场景

场景说明
多任务学习共享底层模块,高层任务专用
持续学习新任务引入新模块,避免灾难性遗忘
VQA问题解析为模块执行序列
大模型Switch Transformer 使用 MoE 架构

PyTorch 实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class MRLMoE(nn.Module):
    def __init__(self, num_experts, input_dim, hidden_dim, output_dim, k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(input_dim, num_experts)
        self.k = k

    def forward(self, x):
        gate_logits = self.gate(x)
        topk_vals, topk_indices = torch.topk(gate_logits, self.k, dim=1)
        topk_weights = F.softmax(topk_vals, dim=1)

        expert_outputs = torch.stack([e(x) for e in self.experts], dim=1)
        batch_idx = torch.arange(x.size(0)).unsqueeze(1)
        selected = expert_outputs[batch_idx, topk_indices]
        return torch.sum(topk_weights.unsqueeze(-1) * selected, dim=1)

实践建议

  • 模块粒度:从任务语义出发划分
  • 路由策略:硬路由省计算但不可导,软路由可端到端训练
  • 避免模块坍缩:多样性正则、专家 dropout
  • 评估指标:监控模块激活熵、OOD 泛化性能

Comments