1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class InceptionV1Module(nn.Module):
    """Inception v1 (GoogLeNet) block: four parallel branches joined on channels.

    Branches, each run on the same input ``x``:

    1. 1x1 conv                              -> ``ch1x1`` channels
    2. 1x1 reduction -> 3x3 conv (pad 1)     -> ``ch3x3`` channels
    3. 1x1 reduction -> 5x5 conv (pad 2)     -> ``ch5x5`` channels
    4. 3x3 max-pool (stride 1, pad 1) -> 1x1 -> ``pool_proj`` channels

    Every convolution, including the 1x1 reductions, is followed by a ReLU,
    as in the Inception v1 design. Without the ReLU after a reduction, the
    1x1 -> kxk pair would collapse into a composition of two linear maps.

    All branches use stride 1 with "same" padding, so the output keeps the
    input's spatial size and has ``ch1x1 + ch3x3 + ch5x5 + pool_proj``
    channels.

    Args:
        in_ch: Channels of the input tensor.
        ch1x1: Output channels of branch 1.
        ch3x3red: Channels of the 1x1 reduction before the 3x3 conv.
        ch3x3: Output channels of branch 2.
        ch5x5red: Channels of the 1x1 reduction before the 5x5 conv.
        ch5x5: Output channels of branch 3.
        pool_proj: Output channels of the 1x1 projection after max-pooling.
    """

    def __init__(self, in_ch, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()
        # Submodule names and Sequential indices are kept unchanged so existing
        # state_dict keys (branch2.0.weight, branch2.1.weight, ...) still load;
        # the ReLUs are applied in forward() rather than inserted here.
        self.branch1 = nn.Conv2d(in_ch, ch1x1, 1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_ch, ch3x3red, 1),
            nn.Conv2d(ch3x3red, ch3x3, 3, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_ch, ch5x5red, 1),
            nn.Conv2d(ch5x5red, ch5x5, 5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_ch, pool_proj, 1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate along channels.

        Args:
            x: Batched ``(N, C, H, W)`` or unbatched ``(C, H, W)`` input with
                ``C == in_ch``.

        Returns:
            Tensor with ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels and the
            same spatial size as ``x``.
        """
        relu = torch.relu
        reduce3, conv3 = self.branch2
        reduce5, conv5 = self.branch3
        pool, proj = self.branch4
        out1 = relu(self.branch1(x))
        out2 = relu(conv3(relu(reduce3(x))))
        out3 = relu(conv5(relu(reduce5(x))))
        out4 = relu(proj(pool(x)))
        # dim=-3 is the channel axis for both (N, C, H, W) and (C, H, W);
        # for batched input it is the same axis as dim=1.
        return torch.cat([out1, out2, out3, out4], dim=-3)
|
Comments