import torch
from torch import nn
class idendity_mapping_block(nn.Module):
    """Residual block: two 3x3 conv-BN-SiLU layers added to a shortcut path.

    The (misspelled) class name is kept for backward compatibility.

    Args:
        input_channels: Number of channels of the input feature map.
        output_channels: Number of channels produced by the block.
        use_1x1_conv: If True, the shortcut is projected with a 1x1 conv.
            This is required whenever ``input_channels != output_channels``
            or ``strides != 1``. Otherwise the residual addition in
            ``forward`` fails with a shape mismatch.
        strides: Stride of the first 3x3 conv and of the 1x1 shortcut
            projection. Use 2 to halve the spatial resolution. The default
            of 1 reproduces the previous (non-downsampling) behaviour.
    """

    def __init__(self, input_channels, output_channels, use_1x1_conv=False, strides=1):
        super(idendity_mapping_block, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3,
                               padding=1, stride=strides)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.act = nn.SiLU(inplace=True)
        if use_1x1_conv:
            # Projection shortcut: matches the main path in both channel
            # count and stride, so the two can be summed.
            self.conv3 = nn.Conv2d(input_channels, output_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.bn2 = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        """Return ``act(bn2(conv2(act(bn1(conv1(x)))))) + shortcut(x)``."""
        y = self.act(self.bn1(self.conv1(x)))
        y = self.act(self.bn2(self.conv2(y)))
        # If a 1x1 projection exists, use it to adapt the shortcut. Otherwise
        # the input already has the output's channel count and resolution,
        # so no change is needed.
        if self.conv3 is not None:
            x = self.conv3(x)
        # NOTE(review): the activation is applied before the addition and not
        # after it, unlike the original ResNet, which applies ReLU to the sum.
        # Kept as-is; confirm this variant is intended.
        y += x
        return y


class ResNet50(nn.Module):
    """ResNet-style classifier built from ``idendity_mapping_block``.

    Despite the name, the stages use basic two-conv blocks with depths
    [3, 4, 6, 3] (the ResNet-34 layout), not bottleneck blocks. The stem
    reduces the resolution with a stride-2 conv and a stride-2 max-pool.
    Every stage after the first then halves the resolution in its first
    block.

    Args:
        num_classes: Size of the output logit vector. Defaults to 3.
        in_channels: Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, num_classes=3, in_channels=3):
        super(ResNet50, self).__init__()
        # Registration order matters: it is the order named_children() yields.
        self.layer1 = self.head(in_channels)
        self.layer2 = self.resnet_block(64, 64, 3, first_block=True)
        self.layer3 = self.resnet_block(64, 128, 4)
        self.layer4 = self.resnet_block(128, 256, 6)
        self.layer5 = self.resnet_block(256, 512, 3)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch ``(N, in_channels, H, W)`` to logits ``(N, num_classes)``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.avg(x)
        x = self.flatten(x)
        return self.linear(x)

    def resnet_block(self, in_channel, out_channel, num_block, first_block=False):
        """Build one stage of ``num_block`` residual blocks as an ``nn.Sequential``.

        Unless ``first_block`` is True, the stage's first block projects the
        shortcut with a 1x1 conv and downsamples by 2. The stage right after
        the max-pooling stem passes ``first_block=True`` and keeps the
        resolution.
        """
        layers = []
        for i in range(num_block):
            downsample = i == 0 and not first_block
            layers.append(idendity_mapping_block(
                in_channel,
                out_channel,
                use_1x1_conv=downsample,
                strides=2 if downsample else 1,
            ))
            # Every block after the first sees the stage's output width.
            in_channel = out_channel
        return nn.Sequential(*layers)

    def head(self, in_channels=3):
        """Stem: 7x7 stride-2 conv, BN, SiLU, then 3x3 stride-2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64), nn.SiLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
if __name__ == '__main__':
    # Shape smoke test: push a random batch through each top-level child of
    # the network, in registration order (the same order ResNet50.forward
    # applies them), and print the resulting tensor shape.
    net = ResNet50()
    x = torch.rand(4, 3, 224, 224)
    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        for name, layer in net.named_children():
            x = layer(x)
            print(name, x.shape)