张芷铭的个人博客

TResNet

TResNet(Tightly-coupled Residual Network)TResNet

TResNet 是 Alibaba DAMO Academy 提出的一个高效的图像分类网络,旨在 提高ResNet的计算效率,同时保持高精度。相比于标准 ResNet,TResNet 采用了一系列优化策略,使其在相同 FLOPs(计算量)下取得更好的性能。

1. 主要创新点

TResNet 主要通过以下技术优化标准 ResNet:

1.1 Anti-Aliasing Downsampling(反混叠降采样)

• 传统 CNN 网络的池化层容易丢失信息,导致特征图变得粗糙,影响分类性能。

• TResNet 采用了 Anti-Aliasing Filters 进行平滑降采样,保留更多低频信息,使得特征图更加平滑,提高模型的泛化能力。

1.2 SpaceToDepth 模块

• SpaceToDepth 通过 在模型的输入端增加通道数,并减少空间分辨率,以此加速模型计算,同时减少内存占用。

• 这个操作相当于在输入时进行等效的降采样,提高计算效率。

1.3 In-Place Activated BatchNorm (Inplace-ABN)

• Inplace-ABN 结合了 Batch Normalization(BN)和 Activation Function(ReLU or Swish),减少了额外的内存消耗,提升计算效率。

• 在训练过程中 减少了近 50% 的显存占用,特别适用于大规模训练任务。

1.4 Optimized Stem Block

• ResNet 的 Stem 部分(最开始的卷积层)包含 7x7 大卷积核 + MaxPooling,计算开销较大。

• TResNet 用 多个小卷积核(3×3)堆叠 来代替大卷积核,并使用更高效的降采样策略,提高计算效率。

1.5 SE-ResNet 结构

• TResNet 采用 SE(Squeeze-and-Excitation)模块 来增强重要特征的表达能力。

• 通过全局池化的方式学习通道注意力权重,使得模型在计算开销较低的情况下提高准确率。

2. 网络结构

TResNet 主要基于 ResNet-50 及其变种进行优化,形成 TResNet-M / L / XL 这几个不同规模的版本,结构如下:

ModelParams (M)FLOPs (B)Top-1 Accuracy (%)
ResNet-5025.64.176.0
TResNet-M30.94.380.7
TResNet-L65.18.482.1
TResNet-XL88.012.483.2

TResNet-M: 轻量级版本,在 FLOPs 只增加 5% 的情况下,提升 4.7% 的 Top-1 精度。

TResNet-L / XL: 适用于更高精度任务,在 ImageNet 上的 Top-1 精度超越 EfficientNet 和 ResNeXt。

3. 代码实现

TResNet 的 PyTorch 代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import SpaceToDepth, InplaceAbn

class TResNetBlock(nn.Module):
    """基本残差块,带有SE模块"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(TResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = InplaceAbn(out_channels)  # 采用 Inplace-ABN
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = InplaceAbn(out_channels)

        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, out_channels // 16, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // 16, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                InplaceAbn(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        se_weight = self.se(out)
        out = out * se_weight
        out += self.shortcut(x)
        return F.relu(out)

class TResNet(nn.Module):
    """TResNet 结构"""
    def __init__(self, num_classes=1000):
        super(TResNet, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
            InplaceAbn(64),
            SpaceToDepth(2)  # SpaceToDepth 处理
        )
        self.layer1 = self._make_layer(64, 128, stride=2)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.layer4 = self._make_layer(512, 1024, stride=2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, stride):
        return nn.Sequential(
            TResNetBlock(in_channels, out_channels, stride),
            TResNetBlock(out_channels, out_channels)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)

# 测试模型
model = TResNet(num_classes=1000)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # 计算参数量

4. 性能比较

TResNet 在 ImageNet 上的表现优于 ResNet,并且在高效计算下仍保持良好的精度:

ModelParams (M)FLOPs (B)Top-1 Acc (%)Top-5 Acc (%)
ResNet-5025.64.176.092.9
ResNet-10144.57.877.493.5
TResNet-M30.94.380.795.2
TResNet-L65.18.482.195.9
TResNet-XL88.012.483.296.4

TResNet-M 在参数量比 ResNet-50 多 20% 的情况下,Top-1 提升 4.7%

TResNet-XL 进一步优化,精度可达 83.2%,接近 EfficientNet 但计算量更小。

5. 适用场景

高效推理:适用于 GPU/TPU 计算优化,提升吞吐量。

大规模图像分类:适用于 ImageNet、Fine-grained 分类任务。

实时应用:适合移动端或云端部署,计算量小但性能强。

6. 结论

TResNet 通过一系列优化(反混叠降采样、SpaceToDepth、Inplace-ABN等),在保持高效计算的同时提升了 ResNet 的精度,是一个极具应用价值的 CNN 结构。

💬 评论