TResNet 通过反混叠降采样、SpaceToDepth、Inplace-ABN 等优化,在相同 FLOPs 下显著提升 ResNet 精度。
主要创新
| 技术 | 作用 |
|---|---|
| Anti-Aliasing Downsampling | 平滑降采样,保留低频信息 |
| SpaceToDepth | 增加通道、减少分辨率,加速计算 |
| Inplace-ABN | 结合 BN 和激活,减少 50% 显存 |
| Optimized Stem | 小卷积核替代 7×7 大核 |
| SE 模块 | 通道注意力增强特征表达 |
性能对比
| 模型 | 参数 (M) | FLOPs (B) | Top-1 (%) |
|---|---|---|---|
| ResNet-50 | 25.6 | 4.1 | 76.0 |
| TResNet-M | 30.9 | 4.3 | 80.7 |
| TResNet-L | 65.1 | 8.4 | 82.1 |
| TResNet-XL | 88.0 | 12.4 | 83.2 |
TResNet-M 参数增 20%,精度提升 4.7%。
PyTorch 实现
import torch.nn as nn
class TResNetBlock(nn.Module):
def __init__(self, in_ch, out_ch, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
self.bn1 = nn.BatchNorm2d(out_ch)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch)
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(out_ch, out_ch // 16, 1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch // 16, out_ch, 1),
nn.Sigmoid()
)
self.shortcut = nn.Sequential()
if stride != 1 or in_ch != out_ch:
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride, bias=False)
def forward(self, x):
out = self.conv1(x).relu()
out = self.conv2(out)
out = out * self.se(out)
return (out + self.shortcut(x)).relu()适用场景
- 高效推理:GPU/TPU 优化
- 大规模图像分类:ImageNet、细粒度分类
- 实时应用:移动端/云端部署