嵌套表示学习(MRL)通过模块化、层次化结构组织表示空间,提升模型的泛化能力与可解释性。
核心思想
MRL 是一类以模块化结构组织表示空间的学习框架:
- 模块性:表示由独立功能单元组成
- 层次性:模块存在嵌套关系,低层捕获局部特征,高层组合形成抽象语义
- 可组合性:模块按需组合应对不同任务
- 稀疏激活:仅激活相关模块,提升效率
发展历程
| 时间 | 里程碑 |
|---|---|
| 2016 | 神经模块网络(NMN):VQA 任务分解为可组合模块 |
| 1991/2017 | MoE 架构:门控选择专家子网络 |
| 2017 | Capsule Network:部分-整体关系的向量表示 |
| 近年 | Switch Transformer、GLaM 等稀疏大模型 |
数学原理
模块化表示
$$f(x) = g(h_1(x), h_2(x), \dots, h_K(x))$$
其中 ${h_k}$ 为模块函数,$g$ 为组合函数。
模块间可形成 DAG 结构:
$$h_k(x) = \phi_k({h_j(x)}_{j \in \text{Pa}(k)})$$
门控机制
$$f(x) = \sum_{k=1}^K a_k(x) \cdot h_k(x)$$
硬门控 $a_k(x) \in {0,1}$ 实现稀疏激活。
损失函数
$$\mathcal{L}{\text{total}} = \mathcal{L}{\text{task}} + \lambda \cdot \mathcal{L}_{\text{balance}}$$
负载均衡损失防止所有样本只激活少数专家。
关键性质
| 性质 | 说明 |
|---|---|
| 可扩展性 | 新增模块无需重训整个系统 |
| 可解释性 | 模块可对应语义概念 |
| 计算效率 | 稀疏激活减少 FLOPs |
| 泛化能力 | 模块重组泛化到未见任务组合 |
应用场景
| 场景 | 说明 |
|---|---|
| 多任务学习 | 共享底层模块,高层任务专用 |
| 持续学习 | 新任务引入新模块,避免灾难性遗忘 |
| VQA | 问题解析为模块执行序列 |
| 大模型 | Switch Transformer 使用 MoE 架构 |
PyTorch 实现
| |
实践建议
- 模块粒度:从任务语义出发划分
- 路由策略:硬路由省计算但不可导,软路由可端到端训练
- 避免模块坍缩:多样性正则、专家 dropout
- 评估指标:监控模块激活熵、OOD 泛化性能
张芷铭的个人博客
Comments