ResNet50 PyTorch Implementation

I previously reimplemented ResNet18; today I reimplement ResNet50, whose network structure is slightly different.

[Figure 1]

The basic building block of ResNet50 is 1x1 conv -> 3x3 conv -> 1x1 conv, where each convolution is followed by BN and ReLU (for the final 1x1 conv, the ReLU is applied after the residual addition).
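
That repeating conv -> BN -> ReLU unit can be sketched as a small nn.Sequential helper. This is just an illustration with a hypothetical conv_bn_relu function; the Block class below writes the layers out individually.

import torch.nn as nn

def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0):
    # The repeating conv -> BN -> ReLU unit described above
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )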

[Figure 2: ResNet18/34 residual block (left) vs. ResNet50/101/152 bottleneck block (right)]

As shown in the figure, the left side is the residual block used by ResNet18/34, and the right side is the one used by ResNet50/101/152. In ResNet50's 1x1 conv -> 3x3 conv -> 1x1 conv structure, the first 1x1 convolution reduces the channel dimension, the 3x3 convolution then extracts features, and the final 1x1 convolution expands the channels again. Unlike UNet, the skip connection here is an add operation: two feature maps of identical shape are summed element-wise, whereas UNet's concatenate stacks feature maps with matching height and width along the channel dimension.
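
A quick sketch of the difference between the two skip styles (the tensor shapes below are arbitrary examples):

import torch

a = torch.randn(1, 64, 56, 56)
b = torch.randn(1, 64, 56, 56)

added = a + b                       # residual add: shapes must match exactly, channels stay 64
stacked = torch.cat([a, b], dim=1)  # UNet-style concatenate: H/W must match, channels become 128

print(added.shape)    # torch.Size([1, 64, 56, 56])
print(stacked.shape)  # torch.Size([1, 128, 56, 56])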

The following code implements the basic block (the bottleneck):

import torch
import torch.nn as nn
import torch.onnx


class Block(nn.Module):
    def __init__(self, in_channels, channels, stride, downsample=None):
        super(Block, self).__init__()
        # 1x1 convolution: reduce the channel dimension
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels, kernel_size=(1, 1),
                               bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        # 3x3 convolution: extract features (carries the stride)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3),
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # 1x1 convolution: expand channels back by a factor of 4
        self.conv3 = nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=(1, 1),
                               bias=False)
        self.bn3 = nn.BatchNorm2d(channels * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # The first block of each stage projects the shortcut with a 1x1 conv + BN so the shapes match
        if self.downsample is not None:
            self.dconv = nn.Conv2d(in_channels, channels * 4, stride=stride, kernel_size=(1, 1), bias=False)
            self.dbn = nn.BatchNorm2d(channels * 4)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.dconv(identity)
            identity = self.dbn(identity)

        out += identity
        out = self.relu(out)

        return out
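
As a quick sanity check of the block above (the 56x56 input size is just an example, roughly the feature-map size after the stem for a 224x224 input):

block = Block(64, 64, stride=1, downsample=True)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # expected: torch.Size([1, 256, 56, 56])

block2 = Block(256, 128, stride=2, downsample=True)
print(block2(block(x)).shape)  # expected: torch.Size([1, 512, 28, 28])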

The main network is as follows:

class Resnet50(nn.Module):
    def __init__(self, num_classes):
        super(Resnet50, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Stage 1: 3 bottleneck blocks
        self.conv64_1 = Block(64, 64, stride=1, downsample=True)
        self.conv64_2 = Block(256, 64, stride=1)
        self.conv64_3 = Block(256, 64, stride=1)
        # Stage 2: 4 bottleneck blocks
        self.conv128_1 = Block(256, 128, stride=2, downsample=True)
        self.conv128_2 = Block(128 * 4, 128, stride=1)
        self.conv128_3 = Block(128 * 4, 128, stride=1)
        self.conv128_4 = Block(128 * 4, 128, stride=1)
        # Stage 3: 6 bottleneck blocks
        self.conv256_1 = Block(512, 256, stride=2, downsample=True)
        self.conv256_2 = Block(256 * 4, 256, stride=1)
        self.conv256_3 = Block(256 * 4, 256, stride=1)
        self.conv256_4 = Block(256 * 4, 256, stride=1)
        self.conv256_5 = Block(256 * 4, 256, stride=1)
        self.conv256_6 = Block(256 * 4, 256, stride=1)
        # Stage 4: 3 bottleneck blocks
        self.conv512_1 = Block(1024, 512, stride=2, downsample=True)
        self.conv512_2 = Block(512 * 4, 512, stride=1)
        self.conv512_3 = Block(512 * 4, 512, stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv64_1(x)
        x = self.conv64_2(x)
        x = self.conv64_3(x)
        x = self.conv128_1(x)
        x = self.conv128_2(x)
        x = self.conv128_3(x)
        x = self.conv128_4(x)
        x = self.conv256_1(x)
        x = self.conv256_2(x)
        x = self.conv256_3(x)
        x = self.conv256_4(x)
        x = self.conv256_5(x)
        x = self.conv256_6(x)
        x = self.conv512_1(x)
        x = self.conv512_2(x)
        x = self.conv512_3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


if __name__ == '__main__':
    model = Resnet50(1000)

    from torchsummary import summary

    summary(model, (3, 224, 224))  # torchsummary prints the summary itself
    torch.onnx.export(model, torch.randn(1, 3, 224, 224), r"Y:\code\python\202203\Resnet50.onnx")
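
As an optional sanity check (assuming torchvision is available), the parameter count can be compared against torchvision's resnet50, which follows the same layer layout:

import torchvision

ref = torchvision.models.resnet50()
mine = Resnet50(num_classes=1000)
print(sum(p.numel() for p in ref.parameters()))   # torchvision reference
print(sum(p.numel() for p in mine.parameters()))  # this implementation; the counts should agree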

The figure below shows the full ResNet50 architecture:

[Figure 3: ResNet50 network architecture]
