1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
class TResNetBlock(nn.Module):
    """Residual block: two 3x3 conv+BN stages, squeeze-and-excitation gating, skip path.

    The residual branch is conv1 -> bn1 -> ReLU -> conv2 -> bn2, then scaled
    channel-wise by a squeeze-and-excitation gate (global average pool ->
    1x1 conv -> ReLU -> 1x1 conv -> sigmoid). It is added to the shortcut and
    passed through a final ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
        se_reduction: Channel reduction factor of the SE bottleneck. The hidden
            width is ``max(1, out_ch // se_reduction)``.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int = 1,
        *,
        se_reduction: int = 16,
    ) -> None:
        super().__init__()
        # Convs are bias-free because the following BatchNorm supplies the
        # per-channel shift.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Clamp to 1 so that out_ch < se_reduction does not build a
        # zero-channel bottleneck.
        se_hidden = max(1, out_ch // se_reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_ch, se_hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(se_hidden, out_ch, 1),
            nn.Sigmoid(),
        )

        if stride != 1 or in_ch != out_ch:
            # 1x1 projection to match spatial size and channel count. It stays
            # a single Conv2d (no BN) so its parameter name remains
            # ``shortcut.weight``.
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` (expected layout N, C=in_ch, H, W).

        Returns a tensor with ``out_ch`` channels, spatially downsampled by
        ``stride``. The final ReLU makes every element non-negative.
        """
        out = self.bn1(self.conv1(x)).relu()
        out = self.bn2(self.conv2(out))
        # Channel-wise SE gate: self.se(out) has shape (N, out_ch, 1, 1) and
        # broadcasts over H and W.
        out = out * self.se(out)
        return (out + self.shortcut(x)).relu()
|
Comments