ConvNeXt 融合 Transformer 设计思想,现代化升级 CNN 架构,性能接近 ViT 同时保持高效计算。
核心特点
| 特点 | 说明 |
|---|
| 简化架构 | 去除最大池化层,优化残差块 |
| 现代设计 | LayerNorm、GELU、大卷积核 |
| 大卷积核 | 7×7 替代 3×3,扩大感受野 |
| 分层设计 | 四阶段特征提取,逐步降分辨率增通道 |
| 深度可分离卷积 | 提升计算效率 |
ConvNeXt 块
class ConvNeXtBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.dw_conv = nn.Conv2d(in_channels, in_channels, 7, padding=3, groups=in_channels)
self.norm = nn.LayerNorm(in_channels, eps=1e-6)
self.pw_conv1 = nn.Linear(in_channels, 4 * in_channels)
self.gelu = nn.GELU()
self.pw_conv2 = nn.Linear(4 * in_channels, in_channels)
def forward(self, x):
shortcut = x
x = self.dw_conv(x)
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.pw_conv2(self.gelu(self.pw_conv1(x)))
x = x.permute(0, 3, 1, 2)
return shortcut + x
与 ViT 对比
| 特性 | ConvNeXt | ViT |
|---|
| 架构 | CNN | Transformer |
| 计算效率 | 更高 | 需更多资源 |
| 全局上下文 | 宽卷积核捕获 | 自注意力机制 |
| 数据需求 | 中小规模可行 | 需大规模预训练 |
核心创新
- LayerNorm 替代 BatchNorm,提升稳定性
- 倒置瓶颈设计(通道先扩展再压缩)
- 借鉴 Transformer 的分层设计
适用场景
- 图像分类:ImageNet 高准确率
- 目标检测/分割:Mask R-CNN 骨干网络
- 边缘部署:高效计算特性