基于pytorch实现ResNet

基于 PyTorch 框架实现 ResNet 网络(参考 PyTorch 官方文档)。

from torch import nn
import torch
import torchvision.models
from torch.utils.tensorboard import SummaryWriter

# Apart from the 7x7 stem convolution, ResNet's residual blocks use only 3x3 and 1x1 convolutions; helpers for both are defined below.
def con3x3(in_channel, out_channel, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With padding 1, stride 1 keeps H and W unchanged; stride 2 halves them
    (rounding up for odd sizes). The bias is omitted because every use in
    this file is followed by a BatchNorm layer.
    """
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def con1x1(in_channel, out_channel, stride=1):
    """Return a bias-free 1x1 convolution (no padding needed for a 1x1 kernel).

    Used for channel projection inside Bottleneck and on the downsample
    shortcut built by ResNet._make_layer.
    """
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class BasicBlock(nn.Module):  # shared by ResNet-18 and ResNet-34
    """Two-convolution residual block.

    Residual branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN. The branch is
    added to the shortcut, and ReLU is applied to the sum.

    Args:
        in_channel: number of channels of the block input.
        out_channel: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution; values > 1 shrink H and W.
        downsample: optional module applied to the input on the shortcut path
            so it matches the residual branch's shape; ``None`` means identity.
    """

    # Ratio of the block's output channels to ``out_channel``.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution carries the stride; the second keeps size.
        self.conv1 = con3x3(in_channel, out_channel, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = con3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        # Add the shortcut first, then apply the final ReLU.
        return self.relu(residual + shortcut)

class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    Each convolution is followed by BN. ReLU follows the first two. The
    residual branch is added to the shortcut, and ReLU is applied to the sum.

    Args:
        in_channel: number of channels of the block input.
        out_channel: width of the inner 1x1 and 3x3 convolutions; the block
            outputs ``out_channel * expansion`` channels.
        stride: stride of the middle 3x3 convolution.
        downsample: optional module applied to the input on the shortcut path;
            ``None`` means identity.
    """

    # Ratio of the block's output channels to ``out_channel``.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        width = out_channel
        widened = out_channel * self.expansion

        self.conv1 = con1x1(in_channel, width)
        self.bn1 = nn.BatchNorm2d(width)

        # The spatial stride is applied by the 3x3 convolution.
        self.conv2 = con3x3(width, width, stride=stride)
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = con1x1(width, widened)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)

class ResNet(nn.Module):
    """ResNet backbone followed by global average pooling and a linear head.

    Args:
        block: residual block class to stack (``BasicBlock`` or ``Bottleneck``);
            it must expose an ``expansion`` class attribute.
        block_num: block counts for the four stages; entries 0..3 are read.
        num_classes: output size of the final fully connected layer.
    """

    def __init__(self, block, block_num, num_classes=1000):
        super().__init__()

        # Channels entering the next stage. Starts as the stem's output width
        # and is advanced by _make_layer after each stage is built.
        self.in_channel = 64

        # Stem: 7x7/2 conv, BN, ReLU, then 3x3/2 max-pool. For a 224x224x3
        # input this yields 112x112x64 after the conv and 56x56x64 after the pool.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages layer1..layer4 (conv2_x .. conv5_x in the paper). All but the
        # first use stride 2 to halve H and W.
        stage_widths = (64, 128, 256, 512)
        stage_strides = (1, 2, 2, 2)
        for idx, (width, step) in enumerate(zip(stage_widths, stage_strides)):
            stage = self._make_layer(block=block, channel=width, block_num=block_num[idx], stride=step)
            setattr(self, "layer%d" % (idx + 1), stage)

        # Head: global average pool to 1x1, then the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` stacked blocks.

        Args:
            block: residual block class to stack.
            channel: inner width of each block (64, 128, 256, 512 for ResNet);
                the stage outputs ``channel * block.expansion`` channels.
            block_num: number of blocks in this stage.
            stride: stride of the first block; later blocks use stride 1.

        Side effect: sets ``self.in_channel`` to this stage's output channels.
        """
        out_channels = channel * block.expansion

        # A projection shortcut is needed when the first block changes either
        # the spatial size or the channel count.
        shortcut = None
        if stride != 1 or self.in_channel != out_channels:
            shortcut = nn.Sequential(
                con1x1(self.in_channel, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        head = block(in_channel=self.in_channel, out_channel=channel, downsample=shortcut, stride=stride)
        self.in_channel = out_channels
        tail = [block(in_channel=self.in_channel, out_channel=channel) for _ in range(1, block_num)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

def resnet18(num_classes=10):
    """Build ResNet-18: BasicBlock stages with 2, 2, 2, 2 blocks.

    Args:
        num_classes: output size of the classifier (this file defaults to 10).
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)

def resnet34(num_classes=10):
    """Build ResNet-34: BasicBlock stages with 3, 4, 6, 3 blocks.

    Args:
        num_classes: output size of the classifier (this file defaults to 10).
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)

def resnet50(num_classes=10):
    """Build ResNet-50: Bottleneck stages with 3, 4, 6, 3 blocks.

    Args:
        num_classes: output size of the classifier (this file defaults to 10).
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)

def resnet101(num_classes=10):
    """Build ResNet-101: Bottleneck stages with 3, 4, 23, 3 blocks.

    Args:
        num_classes: output size of the classifier (this file defaults to 10).
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)

def resnet152(num_classes=10):
    """Build ResNet-152: Bottleneck stages with 3, 8, 36, 3 blocks.

    Args:
        num_classes: output size of the classifier (this file defaults to 10).
    """
    return ResNet(block=Bottleneck, block_num=[3, 8, 36, 3], num_classes=num_classes)


def resnet52(num_classes=10):
    """Backward-compatible alias for :func:`resnet152`.

    This factory was originally misnamed: the stage layout [3, 8, 36, 3]
    is ResNet-152 and there is no standard "ResNet-52". New code should call
    ``resnet152`` instead.
    """
    return resnet152(num_classes=num_classes)

# Quick sanity check for this module; uncomment to run.
# if __name__ == '__main__':
#     x = torch.ones(1, 3, 224, 224)  # (batch, C, H, W)
#     print(x.shape)
#     net = resnet50(10)
#     output = net(x)
#     print(output.shape)
#     print(net)
#
#     writer = SummaryWriter("log_net50")
#     writer.add_graph(net, x)
#     writer.close()

你可能感兴趣的:(pytorch,深度学习,神经网络)