PyTorch in Practice: A ResNet-Based Neural Network

1. Dataset

This post uses the CIFAR-10 dataset. For details on preparing the dataset, see my other post on the LeNet architecture: link.
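For completeness, here is a minimal sketch of loading CIFAR-10 with torchvision (the full version in main.py below also adds resizing and normalization):

from torchvision import datasets, transforms

# downloads CIFAR-10 into ./cifar on the first run
cifar_train = datasets.CIFAR10('cifar', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
print(len(cifar_train))  # 50000 training images across 10 classes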

2. Residual Networks

The ResNet authors observed that networks suffer from degradation as depth grows: the training loss first decreases, then saturates, and if you keep adding layers the training loss actually rises again. Note that this is not overfitting, because with overfitting the training loss keeps decreasing.

When a network degrades, a shallow network achieves better training performance than a deeper one. If we could carry the low-level features up to the higher layers unchanged, the deeper network should perform at least as well as the shallow one.

For example: if a VGG-100 used at its 98th layer exactly the same features that VGG-16 uses at its 14th layer, then VGG-100 should perform the same as VGG-16. We could achieve this by adding an identity mapping between layer 14 and layer 98 of VGG-100.

Residual networks were born out of this idea of using identity mappings to connect different layers of a network.
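Formally, instead of forcing a stack of layers to fit a desired mapping H(x) directly, a residual block fits the residual F(x) = H(x) - x, and the shortcut adds the input back (notation as in the original ResNet paper):

y = \mathcal{F}(x, \{W_i\}) + x

When the identity mapping is already close to optimal, the layers only need to push the residual F(x) toward zero, which is easier to optimize than learning an identity from scratch.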

2.1 Residual Blocks

The structure of a residual block is shown in Figure 1.
(Figure 1: residual block structure)

2.2 Residual Block Implementation

# Residual block
class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        # stride controls downsampling, so the block can halve h and w
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 1. For the elementwise + at the merge point, conv2 keeps ch_out
        #    channels so the main branch matches the shortcut branch
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # 2. The shortcut must match not only the channel count but also the
        #    feature map's h and w, so it is also needed when stride != 1
        if ch_in != ch_out or stride != 1:
            self.extra = nn.Sequential(
                # the stride here changes the spatial size to match the main branch
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(ch_out)
            )

The corresponding forward implementation:

    # forward pass of the residual block
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Merge with the shortcut: to add x [b, ch_in, h, w] to the output of
        # the two convolutions, x goes through extra so the shapes match
        out = out + self.extra(x)
        # print('out:', out.shape)  # debug
        # print('x:', x.shape)      # debug

        out = F.relu(out)
        return out
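A quick shape check of the block (a minimal sketch, assuming the imports from Resnet.py below): a 64-to-128 block with stride=2 should double the channels and halve the spatial size:

blk = ResBlk(64, 128, stride=2)
x = torch.randn(2, 64, 32, 32)
print(blk(x).shape)  # torch.Size([2, 128, 16, 16])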

3. Complete Implementation

3.1 ResNet18 Network Structure

(Figure 2: ResNet18 network structure)
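As a rough sanity check of this structure (assuming a 32x32 CIFAR-10 input and the strides used in the code below; conv_out is a helper defined here just for illustration), the feature-map size at each stage can be traced with the standard convolution output formula:

# out = floor((in + 2*padding - kernel) / stride) + 1
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(32, 3, 3, 0)     # conv1: 32 -> 10
for _ in range(4):            # blk1..blk4 each downsample with stride 2
    s = conv_out(s, 3, 2, 1)  # 10 -> 5 -> 3 -> 2 -> 1
print(s)                      # 1, so the final feature map is [b, 512, 1, 1]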

3.2 Complete Code Walkthrough

(1) Resnet.py

# Define the residual network structure

import torch

from torch import nn
from torch.nn import functional as F

# Residual block
class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        # stride controls downsampling, so the block can halve h and w
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 1. For the elementwise + at the merge point, conv2 keeps ch_out
        #    channels so the main branch matches the shortcut branch
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # 2. The shortcut must match not only the channel count but also the
        #    feature map's h and w, so it is also needed when stride != 1
        if ch_in != ch_out or stride != 1:
            self.extra = nn.Sequential(
                # the stride here changes the spatial size to match the main branch
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(ch_out)
            )

    # forward pass of the residual block
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Merge with the shortcut: to add x [b, ch_in, h, w] to the output of
        # the two convolutions, x goes through extra so the shapes match
        out = out + self.extra(x)
        # print('out:', out.shape)  # debug
        # print('x:', x.shape)      # debug

        out = F.relu(out)
        return out

class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 residual blocks, each halving the feature map
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)

        self.outlayer = nn.Linear(512*1*1, 10)

    def forward(self, x):
        """
        x: [b, 3, 32, 32] -> logits: [b, 10]
        """
        x = F.relu(self.conv1(x))

        # [b, 64, 10, 10] => [b, 512, 1, 1] through the four blocks
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # print('after conv:', x.shape)  # [b, 512, 1, 1] for a 32x32 input
        # pool any remaining h, w down to 1 x 1
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x


def main():
    # quick sanity check of a single residual block
    blk = ResBlk(64, 128, stride=2)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)


if __name__ == '__main__':
    main()
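Running python Resnet.py directly executes the sanity check in main() and should print block: torch.Size([2, 128, 16, 16]).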

(2) main.py

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from Resnet import ResNet18


# Load the dataset, then train and evaluate
def main():
    batchsz = 128
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # ImageNet normalization statistics, commonly reused for CIFAR-10
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    # the test set does not need shuffling
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # fall back to CPU when no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)
    # multi-class cross entropy; expects raw logits, so the model has no softmax
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(10):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)

            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # loss of the last batch in this epoch
        print(epoch, 'loss:', loss.item())

    # switch BatchNorm to evaluation mode before testing
    model.eval()
    with torch.no_grad():
        # test
        total_correct = 0
        total_num = 0
        for x, label in cifar_test:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # [b, 10]
            logits = model(x)
            # [b]
            pred = logits.argmax(dim=1)
            # [b] vs [b] => number of correct predictions in the batch
            correct = torch.eq(pred, label).float().sum().item()
            total_correct += correct
            total_num += x.size(0)

        acc = total_correct / total_num
        print('test acc:', acc)

if __name__=='__main__':
    main()
