We use the ResNet18 architecture, but to better fit the CIFAR-10 dataset (images are h×w = 32×32), the parameters are not an exact copy of the original ResNet18.
The figure below shows the internal structure of ResNet18.
First, the basic residual block, which contains two weight layers:
import torch
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        # Unlike the textbook block, a stride parameter is added so the
        # block itself can downsample.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        # extra projects the shortcut whenever the residual path changes the
        # tensor shape: ch_in != ch_out, or stride != 1 (without the stride
        # check, a strided block with ch_in == ch_out would silently
        # broadcast a mismatched shortcut instead of downsampling it).
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """x: [b, ch, h, w]"""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # self.extra(x) reshapes the shortcut to match out,
        # [b, ch_in, h, w] => [b, ch_out, h', w'], then element-wise add.
        out = self.extra(x) + out
        out = F.relu(out)
        return out
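
A quick shape check of the block (a minimal sketch; the batch size and 32×32 input are illustrative):

blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 32, 32)
out = blk(tmp)
print('block out:', out.shape)  # torch.Size([2, 128, 16, 16])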
Next, the overall ResNet18 structure: besides the single-layer conv1, there are four residual blocks.
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks, each with stride=2 halving the spatial size
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)
        self.outlayer = nn.Linear(512*1*1, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 512, h', w'] through the four blocks
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # print('after conv:', x.shape)  # [b, 512, 1, 1] for 32x32 input
        # global average pooling: [b, 512, h, w] => [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
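
A quick end-to-end smoke test on a fake CIFAR-10-sized batch (a minimal sketch; the output shape follows from the layer arithmetic above, and the parameter count is printed only for reference):

model = ResNet18()
tmp = torch.randn(2, 3, 32, 32)  # a fake CIFAR-10 batch
out = model(tmp)
print('resnet out:', out.shape)  # torch.Size([2, 10])
print('total parameters:', sum(p.numel() for p in model.parameters()))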
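
For completeness, a minimal training sketch, assuming torchvision is available. The normalization statistics, batch size, learning rate, and epoch count are illustrative choices, not values from the original post:

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def main():
    tf = transforms.Compose([
        transforms.ToTensor(),
        # commonly used CIFAR-10 channel statistics (assumed, not from the post)
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_db = datasets.CIFAR10('cifar', train=True, transform=tf, download=True)
    train_loader = DataLoader(train_db, batch_size=128, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for epoch in range(10):
        for x, label in train_loader:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criterion(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # loss of the last batch in this epoch, as a rough progress signal
        print(epoch, 'loss:', loss.item())

if __name__ == '__main__':
    main()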