张芷铭的个人博客

CBAM 通过通道注意力和空间注意力串联,轻量高效地增强 CNN 特征表达能力。

论文 | 代码

整体结构

输入特征图 $\mathbf{F} \in \mathbb{R}^{C \times H \times W}$ 依次经过通道注意力和空间注意力:

$$\mathbf{F’} = \text{CAM}(\mathbf{F}) \cdot \mathbf{F}$$

$$\mathbf{F’’} = \text{SAM}(\mathbf{F’}) \cdot \mathbf{F’}$$

通道注意力(CAM)

捕捉全局通道关系,调整各通道重要性:

$$\mathbf{M}c = \sigma(\text{MLP}(\mathbf{F}{\text{avg}}) + \text{MLP}(\mathbf{F}_{\text{max}}))$$

步骤操作
1全局平均池化 + 全局最大池化
2共享 MLP 处理两个池化结果
3相加后 Sigmoid 激活
4与原特征相乘

空间注意力(SAM)

关注每个空间位置的重要性:

$$\mathbf{M}s = \sigma(\text{Conv}{7 \times 7}([\mathbf{F}{\text{avg}}^s; \mathbf{F}{\text{max}}^s]))$$

步骤操作
1通道维度平均池化 + 最大池化
2拼接后经 7×7 卷积
3Sigmoid 激活
4与输入特征相乘

PyTorch 实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False)
        )

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x).view(x.shape[0], -1))
        max_out = self.mlp(self.max_pool(x).view(x.shape[0], -1))
        return torch.sigmoid(avg_out + max_out).view(x.shape[0], x.shape[1], 1, 1) * x

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1))) * x

class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))

核心优势

优势说明
轻量级计算量小于 SE 模块
双维增强同时利用通道和空间注意力
易集成可无缝嵌入 ResNet、VGG 等主干网络

Comments