张芷铭的个人博客

TResNet 通过反混叠降采样、SpaceToDepth、Inplace-ABN 等优化,在相同 FLOPs 下显著提升 ResNet 精度。

主要创新

技术作用
Anti-Aliasing Downsampling平滑降采样,保留低频信息
SpaceToDepth增加通道、减少分辨率,加速计算
Inplace-ABN结合 BN 和激活,减少 50% 显存
Optimized Stem小卷积核替代 7×7 大核
SE 模块通道注意力增强特征表达

性能对比

模型参数 (M)FLOPs (B)Top-1 (%)
ResNet-5025.64.176.0
TResNet-M30.94.380.7
TResNet-L65.18.482.1
TResNet-XL88.012.483.2

TResNet-M 参数增 20%,精度提升 4.7%。

PyTorch 实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.nn as nn

class TResNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_ch, out_ch // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch // 16, out_ch, 1),
            nn.Sigmoid()
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride, bias=False)

    def forward(self, x):
        out = self.conv1(x).relu()
        out = self.conv2(out)
        out = out * self.se(out)
        return (out + self.shortcut(x)).relu()

适用场景

  • 高效推理:GPU/TPU 优化
  • 大规模图像分类:ImageNet、细粒度分类
  • 实时应用:移动端/云端部署

Comments