CBAM 通过通道注意力和空间注意力串联,轻量高效地增强 CNN 特征表达能力。
整体结构
输入特征图 依次经过通道注意力和空间注意力:
通道注意力(CAM)
捕捉全局通道关系,调整各通道重要性:
| 步骤 | 操作 |
|---|---|
| 1 | 全局平均池化 + 全局最大池化 |
| 2 | 共享 MLP 处理两个池化结果 |
| 3 | 相加后 Sigmoid 激活 |
| 4 | 与原特征相乘 |
空间注意力(SAM)
关注每个空间位置的重要性:
| 步骤 | 操作 |
|---|---|
| 1 | 通道维度平均池化 + 最大池化 |
| 2 | 拼接后经 7×7 卷积 |
| 3 | Sigmoid 激活 |
| 4 | 与输入特征相乘 |
PyTorch 实现
class ChannelAttention(nn.Module):
def __init__(self, in_channels, ratio=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.mlp = nn.Sequential(
nn.Linear(in_channels, in_channels // ratio, bias=False),
nn.ReLU(),
nn.Linear(in_channels // ratio, in_channels, bias=False)
)
def forward(self, x):
avg_out = self.mlp(self.avg_pool(x).view(x.shape[0], -1))
max_out = self.mlp(self.max_pool(x).view(x.shape[0], -1))
return torch.sigmoid(avg_out + max_out).view(x.shape[0], x.shape[1], 1, 1) * x
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
return torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1))) * x
class CBAM(nn.Module):
def __init__(self, in_channels, ratio=16, kernel_size=7):
super().__init__()
self.channel_attention = ChannelAttention(in_channels, ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
return self.spatial_attention(self.channel_attention(x))核心优势
| 优势 | 说明 |
|---|---|
| 轻量级 | 计算量小于 SE 模块 |
| 双维增强 | 同时利用通道和空间注意力 |
| 易集成 | 可无缝嵌入 ResNet、VGG 等主干网络 |