初始PyTorch(六+):ResNet18的网络结构

这里使用ResNet18的网络结构。为了更好地适配CIFAR-10数据集(图像尺寸h*w=32*32),参数并没有完全按照标准ResNet18来设置。

下图是ResNet18中残差块(residual block)的内部结构图。

初始PyTorch(六+):ResNet18的网络结构_第1张图片

先写残差块的内部结构:它包含两个weight layer(即两层卷积)。

class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    Unlike the textbook block, the first convolution takes a ``stride`` so the
    block can also downsample the spatial dimensions.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 convolution (and of the 1x1 shortcut
            projection, when one is needed).
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        # Main branch; only conv1 carries the stride.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut branch. An empty Sequential acts as the identity. A 1x1
        # projection is required whenever the main branch changes the tensor
        # shape: the channel count (ch_in != ch_out) OR the spatial size
        # (stride != 1). Checking only the channels let a same-channel,
        # stride-2 block add [b, c, h, w] to [b, c, ceil(h/2), ceil(w/2)],
        # which either fails or silently broadcasts (when the output is 1x1).
        # Note: such blocks now own shortcut parameters, so checkpoints saved
        # with the old channel-only condition will not load with strict=True.
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w'], same h', w' as conv1
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: input tensor of shape [b, ch_in, h, w].

        Returns:
            Tensor of shape [b, ch_out, h', w'], where h', w' are h, w reduced
            by ``stride``.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # self.extra(x) matches out's shape: it is the identity when nothing
        # changes, otherwise the 1x1 projection built in __init__.
        out = self.extra(x) + out
        return F.relu(out)

接下来再写ResNet18的整体结构:除了第一层卷积conv1之外,还有四个残差block。

初始PyTorch(六+):ResNet18的网络结构_第2张图片

class ResNet18(nn.Module):
    """Compact ResNet-style classifier adapted for 32x32 inputs (e.g. CIFAR-10).

    Structure: a conv1 stem (3x3 conv + BatchNorm, followed by ReLU in
    forward), four ``ResBlk`` stages with stride 2, global average pooling
    and a linear classifier. Each stage is a single residual block, so this
    is a simplified variant rather than a layer-for-layer ResNet-18.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        # Stem: [b, 3, h, w] => [b, 64, h', w'].
        # NOTE(review): kernel_size=3 with stride=3, padding=0 tiles the image
        # into non-overlapping 3x3 patches (32x32 -> 10x10). Common CIFAR stems
        # use stride=1, padding=1; kept as-is to preserve this architecture.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64),
        )
        # Four residual stages, each with stride 2 (spatial size -> ceil(h/2)).
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h', w'] (channel count unchanged)
        self.blk4 = ResBlk(512, 512, stride=2)
        # After global pooling the feature vector is 512 * 1 * 1.
        self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: input images of shape [b, 3, h, w].

        Returns:
            Logits of shape [b, num_classes].
        """
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Global average pooling: [b, 512, h', w'] => [b, 512, 1, 1]; this makes
        # the classifier independent of whatever spatial size remains.
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # Flatten: [b, 512, 1, 1] => [b, 512]
        x = x.view(x.size(0), -1)
        return self.outlayer(x)

 

你可能感兴趣的:(#,初始PyTorch)