Pytorch入门教程(十):ResNet图片分类实战

1. 基本ResNet单元:

Pytorch入门教程(十):ResNet图片分类实战_第1张图片

import torch
from torch import nn
from torch.nn import functional as F


class Resnet(nn.Module):

    def __init__(self, ch_in, ch_out):
        super(Resnet, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # 如果输出、输出维度不同,需转化后才能相加
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu((self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        return out

2. 18层残差网络

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # 4 block [b, 64, h, w] => [b, 1024, h, w]
        self.blk1 = Resnet(64, 128)
        self.blk2 = Resnet(128, 256)
        self.blk3 = Resnet(256, 512)
        self.blk4 = Resnet(512, 1024)
        # 注意最后全连接层维度,进去之前需要先打平
        self.outlayer = nn.Linear(1024*32*32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))

        #
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = x.view(x.size(0), -1)  # 先打平,再进全连接

        x = self.outlayer(x)
        return x

3.  main()函数调用

将 CIFAR图片分类实战 中主程序27行,Lenet5()改为ResNet18()即可。

你可能感兴趣的:(Pytorch)