1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class ConvNeXtBlock(nn.Module):
    """ConvNeXt-style residual block operating on ``(N, C, H, W)`` tensors.

    Computes ``x + f(x)``, where ``f`` is applied in this order:
    1. a depthwise ``kernel_size`` x ``kernel_size`` convolution;
    2. LayerNorm over the channel dimension;
    3. a pointwise expansion to ``expansion * C`` channels, then GELU;
    4. a pointwise projection back to ``C`` channels;
    5. an optional learnable per-channel scale.

    The pointwise "convolutions" are ``nn.Linear`` layers applied in
    channels-last layout. That is why the tensor is permuted to
    ``(N, H, W, C)`` around the norm/MLP and permuted back before the
    residual add.

    Args:
        in_channels: Number of channels ``C``. Input and output have this
            many channels, as required by the residual addition.
        kernel_size: Side length of the depthwise kernel. Must be a positive
            odd integer so that ``padding = kernel_size // 2`` keeps the
            spatial size unchanged. Defaults to 7.
        expansion: Hidden-width multiplier of the pointwise MLP; the hidden
            width is ``int(expansion * in_channels)``. Defaults to 4.
        layer_scale_init_value: If ``> 0``, creates a learnable per-channel
            parameter ``gamma`` initialised to this value. It multiplies the
            branch output before the residual add. The default ``0.0``
            disables it; then ``gamma`` is ``None`` and no extra parameter
            is registered.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or the
            resulting hidden width is smaller than 1.
    """

    def __init__(self, in_channels, kernel_size=7, expansion=4,
                 layer_scale_init_value=0.0):
        super().__init__()
        # An even kernel with padding k // 2 would grow H and W by one,
        # making the residual addition in forward() fail on shape mismatch.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        hidden_channels = int(expansion * in_channels)
        if hidden_channels < 1:
            raise ValueError(
                f"expansion * in_channels must be >= 1, got {hidden_channels}"
            )

        # groups=in_channels makes this a depthwise convolution (one filter
        # per channel); padding keeps the spatial size unchanged.
        self.dw_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            padding=kernel_size // 2, groups=in_channels,
        )
        # Normalizes over the trailing dimension, which is C after the
        # channels-last permute in forward().
        self.norm = nn.LayerNorm(in_channels, eps=1e-6)
        self.pw_conv1 = nn.Linear(in_channels, hidden_channels)
        self.gelu = nn.GELU()
        self.pw_conv2 = nn.Linear(hidden_channels, in_channels)

        if layer_scale_init_value > 0:
            # Local import: only this optional path needs torch directly.
            import torch
            self.gamma = nn.Parameter(
                torch.full((in_channels,), float(layer_scale_init_value))
            )
        else:
            self.gamma = None

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, C, H, W)``; returns the same shape."""
        shortcut = x
        x = self.dw_conv(x)
        # (N, C, H, W) -> (N, H, W, C) so LayerNorm / Linear act on channels.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pw_conv2(self.gelu(self.pw_conv1(x)))
        if self.gamma is not None:
            # gamma has shape (C,) and broadcasts over the trailing channel dim.
            x = self.gamma * x
        # Back to (N, C, H, W) to match the shortcut.
        x = x.permute(0, 3, 1, 2)
        return shortcut + x
|
Comments