NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST

目录

5.4 基于残差网络的手写体数字识别实验

5.4.1 模型构建

5.4.1.1 残差单元

 5.4.1.2 残差网络的整体结构

5.4.2 没有残差连接的ResNet18

5.4.2.1 模型训练

 5.4.2.2 模型评价

5.4.3 带残差连接的ResNet18

5.4.3.1 模型训练

 5.4.3.2 模型评价

 5.4.4 与高层API实现版本的对比实验


5.4 基于残差网络的手写体数字识别实验

残差网络(Residual Network,ResNet)是在神经网络模型中给非线性层增加直连边的方式来缓解梯度消失问题,从而使训练深度神经网络变得更加容易。

在残差网络中,最基本的单位为残差单元

假设 f(x;\theta )为一个或多个神经层,残差单元在f()的输入和输出之间加上一个直连边

不同于传统网络结构中让网络f(x;\theta )去逼近一个目标函数h(),在残差网络中,将目标函数h()拆为了两个部分:恒等函数x和残差函数h()-x

ResBlock_{f}(x)=f(x;\theta )+x,(5.22)   

其中\theta为可学习的参数。

一个典型的残差单元如图5.14所示,由多个级联的卷积层和一个跨层的直连边组成。

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第1张图片

 一个残差网络通常有很多个残差单元堆叠而成。下面我们来构建一个在计算机视觉中非常典型的残差网络:ResNet18,并重复上一节中的手写体数字识别任务。

5.4.1 模型构建

在本节中,我们先构建ResNet18的残差单元,然后在组建完整的网络。

5.4.1.1 残差单元

这里,我们实现一个算子ResBlock来构建残差单元,其中定义了use_residual参数,用于在后续实验中控制是否使用残差连接。

残差单元包裹的非线性层的输入和输出形状大小应该一致。如果一个卷积层的输入特征图和输出特征图的通道数不一致,则其输出与输入特征图无法直接相加。为了解决上述问题,我们可以使用1×1大小的卷积将输入特征图的通道数映射为与级联卷积输出特征图的一致通道数。

1×1卷积:与标准卷积完全一样,唯一的特殊点在于卷积核的尺寸是1×1,也就是不去考虑输入数据局部信息之间的关系,而把关注点放在不同通道间。通过使用1×1卷积,可以起到如下作用:

  • 实现信息的跨通道交互与整合。考虑到卷积运算的输入输出都是3个维度(宽、高、多通道),所以1×1卷积实际上就是对每个像素点,在不同的通道上进行线性组合,从而整合不同通道的信息;
  • 对卷积核通道数进行降维和升维,减少参数量。经过1×1卷积后的输出保留了输入数据的原有平面结构,通过调控通道数,从而完成升维或降维的作用;
  • 利用1×1卷积后的非线性激活函数,在保持特征图尺寸不变的前提下,大幅增加非线性。
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, use_residual=True):
        """
        残差单元
        输入:
            - in_channels:输入通道数
            - out_channels:输出通道数
            - stride:残差单元的步长,通过调整残差单元中第一个卷积层的步长来控制
            - use_residual:用于控制是否使用残差连接
        """
        super(ResBlock, self).__init__()
        self.stride = stride
        self.use_residual = use_residual
        # 第一个卷积层,卷积核大小为3×3,可以设置不同输出通道数以及步长
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=self.stride, bias=False)
        # 第二个卷积层,卷积核大小为3×3,不改变输入特征图的形状,步长为1
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
 
        # 如果conv2的输出和此残差块的输入数据形状不一致,则use_1x1conv = True
        # 当use_1x1conv = True,添加1个1x1的卷积作用在输入数据上,使其形状变成跟conv2一致
        if in_channels != out_channels or stride != 1:
            self.use_1x1conv = True
        else:
            self.use_1x1conv = False
        # 当残差单元包裹的非线性层输入和输出通道数不一致时,需要用1×1卷积调整通道数后再进行相加运算
        if self.use_1x1conv:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=self.stride, bias=False)
 
        # 每个卷积层后会接一个批量规范化层,批量规范化的内容在7.5.1中会进行详细介绍
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if self.use_1x1conv:
            self.bn3 = nn.BatchNorm2d(out_channels)
 
    def forward(self, inputs):
        y = F.relu(self.bn1(self.conv1(inputs)))
        y = self.bn2(self.conv2(y))
        if self.use_residual:
            if self.use_1x1conv:  # 如果为真,对inputs进行1×1卷积,将形状调整成跟conv2的输出y一致
                shortcut = self.shortcut(inputs)
                shortcut = self.bn3(shortcut)
            else:  # 否则直接将inputs和conv2的输出y相加
                shortcut = inputs
            y = torch.add(shortcut, y)
        out = F.relu(y)
        return out

 5.4.1.2 残差网络的整体结构

残差网络就是将很多个残差单元串联起来构成的一个非常深的网络。ResNet18 的网络结构如图5.16所示。

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第2张图片

其中为了便于理解,可以将ResNet18网络划分为6个模块:

  • 第一模块:包含了一个步长为2,大小为7×7的卷积层,卷积层的输出通道数为64,卷积层的输出经过批量归一化、ReLU激活函数的处理后,接了一个步长为2的3×3的最大汇聚层;
  • 第二模块:包含了两个残差单元,经过运算后,输出通道数为64,特征图的尺寸保持不变;
  • 第三模块:包含了两个残差单元,经过运算后,输出通道数为128,特征图的尺寸缩小一半;
  • 第四模块:包含了两个残差单元,经过运算后,输出通道数为256,特征图的尺寸缩小一半;
  • 第五模块:包含了两个残差单元,经过运算后,输出通道数为512,特征图的尺寸缩小一半;
  • 第六模块:包含了一个全局平均汇聚层,将特征图变为1×1的大小,最终经过全连接层计算出最后的输出。

ResNet18模型的代码实现如下:

定义模块一。

def make_first_module(in_channels):
    # 模块一:7*7卷积、批量规范化、汇聚
    m1 = nn.Sequential(nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
                    nn.BatchNorm2d(64), nn.ReLU(),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    return m1

  定义模块二到模块五。

def resnet_module(input_channels, out_channels, num_res_blocks, stride=1, use_residual=True):
    blk = []
    # 根据num_res_blocks,循环生成残差单元
    for i in range(num_res_blocks):
        if i == 0: # 创建模块中的第一个残差单元
            blk.append(ResBlock(input_channels, out_channels,
                                stride=stride, use_residual=use_residual))
        else:      # 创建模块中的其他残差单元
            blk.append(ResBlock(out_channels, out_channels, use_residual=use_residual))
    return blk

封装模块二到模块五。

def make_modules(use_residual):
    # 模块二:包含两个残差单元,输入通道数为64,输出通道数为64,步长为1,特征图大小保持不变
    m2 = nn.Sequential(*resnet_module(64, 64, 2, stride=1, use_residual=use_residual))
    # 模块三:包含两个残差单元,输入通道数为64,输出通道数为128,步长为2,特征图大小缩小一半。
    m3 = nn.Sequential(*resnet_module(64, 128, 2, stride=2, use_residual=use_residual))
    # 模块四:包含两个残差单元,输入通道数为128,输出通道数为256,步长为2,特征图大小缩小一半。
    m4 = nn.Sequential(*resnet_module(128, 256, 2, stride=2, use_residual=use_residual))
    # 模块五:包含两个残差单元,输入通道数为256,输出通道数为512,步长为2,特征图大小缩小一半。
    m5 = nn.Sequential(*resnet_module(256, 512, 2, stride=2, use_residual=use_residual))
    return m2, m3, m4, m5

定义完整网络。

class Model_ResNet18(nn.Module):
    def __init__(self, in_channels=3, num_classes=10, use_residual=True):
        super(Model_ResNet18,self).__init__()
        m1 = make_first_module(in_channels)
        m2, m3, m4, m5 = make_modules(use_residual)
        # 封装模块一到模块6
        self.net = nn.Sequential(m1, m2, m3, m4, m5,
                        # 模块六:汇聚层、全连接层
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, num_classes) )
 
    def forward(self, x):
        return self.net(x)

这里同样可以使用paddle.summary统计模型的参数量。

from torchsummary import summary
 
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True)
params_info = summary(model,input_size=(1,64,32))
print(params_info)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 16, 16]           3,200
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,928
       BatchNorm2d-6             [-1, 64, 8, 8]             128
            Conv2d-7             [-1, 64, 8, 8]          36,928
       BatchNorm2d-8             [-1, 64, 8, 8]             128
          ResBlock-9             [-1, 64, 8, 8]               0
           Conv2d-10             [-1, 64, 8, 8]          36,928
      BatchNorm2d-11             [-1, 64, 8, 8]             128
           Conv2d-12             [-1, 64, 8, 8]          36,928
      BatchNorm2d-13             [-1, 64, 8, 8]             128
         ResBlock-14             [-1, 64, 8, 8]               0
           Conv2d-15            [-1, 128, 4, 4]          73,856
      BatchNorm2d-16            [-1, 128, 4, 4]             256
           Conv2d-17            [-1, 128, 4, 4]         147,584
      BatchNorm2d-18            [-1, 128, 4, 4]             256
           Conv2d-19            [-1, 128, 4, 4]           8,320
      BatchNorm2d-20            [-1, 128, 4, 4]             256
         ResBlock-21            [-1, 128, 4, 4]               0
           Conv2d-22            [-1, 128, 4, 4]         147,584
      BatchNorm2d-23            [-1, 128, 4, 4]             256
           Conv2d-24            [-1, 128, 4, 4]         147,584
      BatchNorm2d-25            [-1, 128, 4, 4]             256
         ResBlock-26            [-1, 128, 4, 4]               0
           Conv2d-27            [-1, 256, 2, 2]         295,168
      BatchNorm2d-28            [-1, 256, 2, 2]             512
           Conv2d-29            [-1, 256, 2, 2]         590,080
      BatchNorm2d-30            [-1, 256, 2, 2]             512
           Conv2d-31            [-1, 256, 2, 2]          33,024
      BatchNorm2d-32            [-1, 256, 2, 2]             512
         ResBlock-33            [-1, 256, 2, 2]               0
           Conv2d-34            [-1, 256, 2, 2]         590,080
      BatchNorm2d-35            [-1, 256, 2, 2]             512
           Conv2d-36            [-1, 256, 2, 2]         590,080
      BatchNorm2d-37            [-1, 256, 2, 2]             512
         ResBlock-38            [-1, 256, 2, 2]               0
           Conv2d-39            [-1, 512, 1, 1]       1,180,160
      BatchNorm2d-40            [-1, 512, 1, 1]           1,024
           Conv2d-41            [-1, 512, 1, 1]       2,359,808
      BatchNorm2d-42            [-1, 512, 1, 1]           1,024
           Conv2d-43            [-1, 512, 1, 1]         131,584
      BatchNorm2d-44            [-1, 512, 1, 1]           1,024
         ResBlock-45            [-1, 512, 1, 1]               0
           Conv2d-46            [-1, 512, 1, 1]       2,359,808
      BatchNorm2d-47            [-1, 512, 1, 1]           1,024
           Conv2d-48            [-1, 512, 1, 1]       2,359,808
      BatchNorm2d-49            [-1, 512, 1, 1]           1,024
         ResBlock-50            [-1, 512, 1, 1]               0
AdaptiveAvgPool2d-51            [-1, 512, 1, 1]               0
          Flatten-52                  [-1, 512]               0
           Linear-53                   [-1, 10]           5,130
================================================================
Total params: 11,180,170
Trainable params: 11,180,170
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.05
Params size (MB): 42.65
Estimated Total Size (MB): 43.71
----------------------------------------------------------------

Process finished with exit code 0

使用torchstat统计模型的计算量。 

from torchstat import stat
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True)
stat(model, (1, 32, 32))
[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
[MAdd]: Flatten is not supported!
[Flops]: Flatten is not supported!
[Memory]: Flatten is not supported!
            module name  input shape output shape      params memory(MB)          MAdd         Flops  MemRead(B)  MemWrite(B) duration[%]   MemR+W(B)
0               net.0.0    1  32  32   64  16  16      3200.0       0.06   1,605,632.0     819,200.0     16896.0      65536.0      47.93%     82432.0
1               net.0.1   64  16  16   64  16  16       128.0       0.06      65,536.0      32,768.0     66048.0      65536.0       6.51%    131584.0
2               net.0.2   64  16  16   64  16  16         0.0       0.06      16,384.0      16,384.0     65536.0      65536.0       1.30%    131072.0
3               net.0.3   64  16  16   64   8   8         0.0       0.02      32,768.0      16,384.0     65536.0      16384.0       1.30%     81920.0
4         net.1.0.conv1   64   8   8   64   8   8     36928.0       0.02   4,718,592.0   2,363,392.0    164096.0      16384.0      18.27%    180480.0
5         net.1.0.conv2   64   8   8   64   8   8     36928.0       0.02   4,718,592.0   2,363,392.0    164096.0      16384.0       0.00%    180480.0
6           net.1.0.bn1   64   8   8   64   8   8       128.0       0.02      16,384.0       8,192.0     16896.0      16384.0       0.00%     33280.0
7           net.1.0.bn2   64   8   8   64   8   8       128.0       0.02      16,384.0       8,192.0     16896.0      16384.0       1.31%     33280.0
8         net.1.1.conv1   64   8   8   64   8   8     36928.0       0.02   4,718,592.0   2,363,392.0    164096.0      16384.0       0.00%    180480.0
9         net.1.1.conv2   64   8   8   64   8   8     36928.0       0.02   4,718,592.0   2,363,392.0    164096.0      16384.0       0.00%    180480.0
10          net.1.1.bn1   64   8   8   64   8   8       128.0       0.02      16,384.0       8,192.0     16896.0      16384.0       0.00%     33280.0
11          net.1.1.bn2   64   8   8   64   8   8       128.0       0.02      16,384.0       8,192.0     16896.0      16384.0       0.00%     33280.0
12        net.2.0.conv1   64   8   8  128   4   4     73856.0       0.01   2,359,296.0   1,181,696.0    311808.0       8192.0       0.00%    320000.0
13        net.2.0.conv2  128   4   4  128   4   4    147584.0       0.01   4,718,592.0   2,361,344.0    598528.0       8192.0       2.60%    606720.0
14     net.2.0.shortcut   64   8   8  128   4   4      8320.0       0.01     262,144.0     133,120.0     49664.0       8192.0       0.00%     57856.0
15          net.2.0.bn1  128   4   4  128   4   4       256.0       0.01       8,192.0       4,096.0      9216.0       8192.0       0.00%     17408.0
16          net.2.0.bn2  128   4   4  128   4   4       256.0       0.01       8,192.0       4,096.0      9216.0       8192.0       0.00%     17408.0
17          net.2.0.bn3  128   4   4  128   4   4       256.0       0.01       8,192.0       4,096.0      9216.0       8192.0       0.00%     17408.0
18        net.2.1.conv1  128   4   4  128   4   4    147584.0       0.01   4,718,592.0   2,361,344.0    598528.0       8192.0       0.00%    606720.0
19        net.2.1.conv2  128   4   4  128   4   4    147584.0       0.01   4,718,592.0   2,361,344.0    598528.0       8192.0       0.00%    606720.0
20          net.2.1.bn1  128   4   4  128   4   4       256.0       0.01       8,192.0       4,096.0      9216.0       8192.0       0.00%     17408.0
21          net.2.1.bn2  128   4   4  128   4   4       256.0       0.01       8,192.0       4,096.0      9216.0       8192.0       0.00%     17408.0
22        net.3.0.conv1  128   4   4  256   2   2    295168.0       0.00   2,359,296.0   1,180,672.0   1188864.0       4096.0       1.30%   1192960.0
23        net.3.0.conv2  256   2   2  256   2   2    590080.0       0.00   4,718,592.0   2,360,320.0   2364416.0       4096.0       4.47%   2368512.0
24     net.3.0.shortcut  128   4   4  256   2   2     33024.0       0.00     262,144.0     132,096.0    140288.0       4096.0       0.00%    144384.0
25          net.3.0.bn1  256   2   2  256   2   2       512.0       0.00       4,096.0       2,048.0      6144.0       4096.0       0.00%     10240.0
26          net.3.0.bn2  256   2   2  256   2   2       512.0       0.00       4,096.0       2,048.0      6144.0       4096.0       0.00%     10240.0
27          net.3.0.bn3  256   2   2  256   2   2       512.0       0.00       4,096.0       2,048.0      6144.0       4096.0       0.00%     10240.0
28        net.3.1.conv1  256   2   2  256   2   2    590080.0       0.00   4,718,592.0   2,360,320.0   2364416.0       4096.0       1.31%   2368512.0
29        net.3.1.conv2  256   2   2  256   2   2    590080.0       0.00   4,718,592.0   2,360,320.0   2364416.0       4096.0       0.00%   2368512.0
30          net.3.1.bn1  256   2   2  256   2   2       512.0       0.00       4,096.0       2,048.0      6144.0       4096.0       0.00%     10240.0
31          net.3.1.bn2  256   2   2  256   2   2       512.0       0.00       4,096.0       2,048.0      6144.0       4096.0       0.00%     10240.0
32        net.4.0.conv1  256   2   2  512   1   1   1180160.0       0.00   2,359,296.0   1,180,160.0   4724736.0       2048.0       3.29%   4726784.0
33        net.4.0.conv2  512   1   1  512   1   1   2359808.0       0.00   4,718,592.0   2,359,808.0   9441280.0       2048.0       1.31%   9443328.0
34     net.4.0.shortcut  256   2   2  512   1   1    131584.0       0.00     262,144.0     131,584.0    530432.0       2048.0       0.00%    532480.0
35          net.4.0.bn1  512   1   1  512   1   1      1024.0       0.00       2,048.0       1,024.0      6144.0       2048.0       0.00%      8192.0
36          net.4.0.bn2  512   1   1  512   1   1      1024.0       0.00       2,048.0       1,024.0      6144.0       2048.0       0.00%      8192.0
37          net.4.0.bn3  512   1   1  512   1   1      1024.0       0.00       2,048.0       1,024.0      6144.0       2048.0       0.00%      8192.0
38        net.4.1.conv1  512   1   1  512   1   1   2359808.0       0.00   4,718,592.0   2,359,808.0   9441280.0       2048.0       1.30%   9443328.0
39        net.4.1.conv2  512   1   1  512   1   1   2359808.0       0.00   4,718,592.0   2,359,808.0   9441280.0       2048.0       1.30%   9443328.0
40          net.4.1.bn1  512   1   1  512   1   1      1024.0       0.00       2,048.0       1,024.0      6144.0       2048.0       0.00%      8192.0
41          net.4.1.bn2  512   1   1  512   1   1      1024.0       0.00       2,048.0       1,024.0      6144.0       2048.0       0.00%      8192.0
42                net.5  512   1   1  512   1   1         0.0       0.00           0.0           0.0         0.0          0.0       5.20%         0.0
43                net.6  512   1   1          512         0.0       0.00           0.0           0.0         0.0          0.0       0.00%         0.0
44                net.7          512           10      5130.0       0.00      10,230.0       5,120.0     22568.0         40.0       1.30%     22608.0
total                                              11180170.0       0.47  71,073,782.0  35,595,776.0     22568.0         40.0     100.00%  45714000.0
=====================================================================================================================================================
Total params: 11,180,170
-----------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 0.47MB
Total MAdd: 71.07MMAdd
Total Flops: 35.6MFlops
Total MemR+W: 43.6MB


Process finished with exit code 0

 为了验证残差连接对深层卷积神经网络的训练可以起到促进作用,接下来先使用ResNet18(use_residual设置为False)进行手写数字识别实验,再添加残差连接(use_residual设置为True),观察实验对比效果。

5.4.2 没有残差连接的ResNet18

为了验证残差连接的效果,先使用没有残差连接的ResNet18进行实验。

5.4.2.1 模型训练

使用训练集和验证集进行模型训练,共训练5个epoch。在实验中,保存准确率最高的模型作为最佳模型。代码实现如下

from nndl import plot
 
# 学习率大小
lr = 0.005
# 批次大小
batch_size = 64
# 加载数据
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# 定义网络,不使用残差结构的深层网络
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=False)
# 定义优化器
optimizer = opt.SGD(model.parameters(), lr)
loss_fn = F.cross_entropy
# 定义评价指标
metric = Accuracy()
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 15
eval_steps = 15
runner.train(train_loader, dev_loader, num_epochs=5, log_steps=log_steps,
             eval_steps=eval_steps, save_path="best_model.pdparams")
# 可视化观察训练集与验证集的Loss变化情况
plot(runner, 'cnn-loss2.pdf')
[Train] epoch: 0/5, step: 0/80, loss: 2.36940
[Train] epoch: 0/5, step: 15/80, loss: 1.13062
[Evaluate]  dev score: 0.11000, dev loss: 2.30928
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.11000
[Train] epoch: 1/5, step: 30/80, loss: 0.35934
[Evaluate]  dev score: 0.13000, dev loss: 2.31319
[Evaluate] best accuracy performence has been updated: 0.11000 --> 0.13000
[Train] epoch: 2/5, step: 45/80, loss: 0.27394
[Evaluate]  dev score: 0.71000, dev loss: 1.33839
[Evaluate] best accuracy performence has been updated: 0.13000 --> 0.71000
[Train] epoch: 3/5, step: 60/80, loss: 0.11306
[Evaluate]  dev score: 0.90000, dev loss: 0.42154
[Evaluate] best accuracy performence has been updated: 0.71000 --> 0.90000
[Train] epoch: 4/5, step: 75/80, loss: 0.06916
[Evaluate]  dev score: 0.90000, dev loss: 0.33076
[Evaluate]  dev score: 0.91000, dev loss: 0.31727
[Evaluate] best accuracy performence has been updated: 0.90000 --> 0.91000
[Train] Training done!

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第3张图片

 5.4.2.2 模型评价

使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在测试集上的准确率以及损失情况。代码实现如下

# 加载最优模型
runner.load_model('best_model.pdparams')
# 模型评价
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第4张图片

 从输出结果看,对比LeNet-5模型评价实验结果,网络层级加深后,训练效果不升反降。

5.4.3 带残差连接的ResNet18

5.4.3.1 模型训练

使用带残差连接的ResNet18重复上面的实验,代码实现如下:

# 学习率大小
lr = 0.01
# 批次大小
batch_size = 64
# 加载数据
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# 定义网络,不使用残差结构的深层网络
model = Model_ResNet18(in_channels=1, num_classes=10, use_residual=True)
# 定义优化器
optimizer = opt.SGD(model.parameters(), lr)
loss_fn = F.cross_entropy
# 定义评价指标
metric = Accuracy()
# 实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
# 启动训练
log_steps = 15
eval_steps = 15
runner.train(train_loader, dev_loader, num_epochs=5, log_steps=log_steps,
             eval_steps=eval_steps, save_path="best_model.pdparams")
# 可视化观察训练集与验证集的Loss变化情况
plot(runner, 'cnn-loss2.pdf')
[Train] epoch: 0/5, step: 0/80, loss: 2.55280
[Train] epoch: 0/5, step: 15/80, loss: 0.41595
[Evaluate]  dev score: 0.23000, dev loss: 2.22856
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.23000
[Train] epoch: 1/5, step: 30/80, loss: 0.13791
[Evaluate]  dev score: 0.56000, dev loss: 1.56380
[Evaluate] best accuracy performence has been updated: 0.23000 --> 0.56000
[Train] epoch: 2/5, step: 45/80, loss: 0.02838
[Evaluate]  dev score: 0.91000, dev loss: 0.34748
[Evaluate] best accuracy performence has been updated: 0.56000 --> 0.91000
[Train] epoch: 3/5, step: 60/80, loss: 0.01100
[Evaluate]  dev score: 0.92500, dev loss: 0.19308
[Evaluate] best accuracy performence has been updated: 0.91000 --> 0.92500
[Train] epoch: 4/5, step: 75/80, loss: 0.00627
[Evaluate]  dev score: 0.94500, dev loss: 0.16854
[Evaluate] best accuracy performence has been updated: 0.92500 --> 0.94500
[Evaluate]  dev score: 0.94000, dev loss: 0.15962
[Train] Training done!

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第5张图片

 5.4.3.2 模型评价

使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在测试集上的准确率以及损失情况

# 加载最优模型
runner.load_model('best_model.pdparams')
# 模型评价
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))

添加了残差连接后,模型收敛曲线更平滑。
从输出结果看,和不使用残差连接的ResNet相比,添加了残差连接后,模型效果有了一定的提升。

 5.4.4 与高层API实现版本的对比实验

飞桨高层 API是对飞桨API的进一步封装与升级,提供了更加简洁易用的API,进一步提升了飞桨的易学易用性。

其中,飞桨高层API封装了以下模块:

  1. Model类,支持仅用几行代码完成模型的训练;
  2. 图像预处理模块,包含数十种数据处理函数,基本涵盖了常用的数据处理、数据增强方法;
  3. 计算机视觉领域和自然语言处理领域的常用模型,包括但不限于mobilenet、resnet、yolov3、cyclegan、bert、transformer、seq2seq等等,同时发布了对应模型的预训练模型,可以直接使用这些模型或者在此基础上完成二次开发。

飞桨高层 API主要包含在paddle.vision和paddle.text目录中。
对于Reset18这种比较经典的图像分类网络,飞桨高层API中都为大家提供了实现好的版本,大家可以不再从头开始实现。

这里为高层API版本的resnet18模型和自定义的resnet18模型赋予相同的权重,并使用相同的输入数据,观察输出结果是否一致。

from collections import OrderedDict
import warnings
 
warnings.filterwarnings("ignore")
 
# 使用飞桨HAPI中实现的resnet18模型,该模型默认输入通道数为3,输出类别数1000
hapi_model = resnet18()
# 自定义的resnet18模型
model = Model_ResNet18(in_channels=3, num_classes=1000, use_residual=True)
 
# 获取网络的权重
params = hapi_model.state_dict()
 
# 用来保存参数名映射后的网络权重
new_params = {}
# 将参数名进行映射
for key in params:
    if 'layer' in key:
        if 'downsample.0' in key:
            new_params['net.' + key[5:8] + '.shortcut' + key[-7:]] = params[key]
        elif 'downsample.1' in key:
            new_params['net.' + key[5:8] + '.bn3.' + key[22:]] = params[key]
        else:
            new_params['net.' + key[5:]] = params[key]
    elif 'conv1.weight' == key:
        new_params['net.0.0.weight'] = params[key]
    elif 'conv1.bias' == key:
        new_params['net.0.0.bias'] = params[key]
    elif 'bn1' in key:
        new_params['net.0.1' + key[3:]] = params[key]
    elif 'fc' in key:
        new_params['net.7' + key[2:]] = params[key]
    new_params['net.0.0.bias'] = torch.zeros([64])
# 将飞桨HAPI中实现的resnet18模型的权重参数赋予自定义的resnet18模型,保持两者一致
model.load_state_dict(OrderedDict(new_params))
 
# 这里用np.random创建一个随机数组作为测试数据
inputs = np.random.randn(*[3, 3, 32, 32])
inputs = inputs.astype('float32')
x = torch.tensor(inputs)
 
output = model(x)
hapi_out = hapi_model(x)
 
# 计算两个模型输出的差异
diff = output - hapi_out
# 取差异最大的值
max_diff = torch.max(diff)
print(max_diff)
tensor(0., grad_fn=)

可以看到,高层API版本的resnet18模型和自定义的resnet18模型输出结果是一致的,也就说明两个模型的实现完全一样。 

NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第6张图片

 NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第7张图片

 NNDL 实验六 卷积神经网络(4)ResNet18实现MNIST_第8张图片

ref:

https://blog.csdn.net/Chenzhinan1219/article/details/122042407

https://blog.csdn.net/qq_38975453/article/details/127167583

resnet18 — Torchvision 0.13 documentation (pytorch.org)

vision/resnet.py at main · pytorch/vision (github.com)

NNDL 实验5(上) - HBU_DAVID - 博客园 (cnblogs.com)

个人总结:

通过本次实验,我学习到了用ResNet18实现解决一些问题,通过实验对比清晰直观的感受到没有残差连接的ResNet18和带残差连接的ResNet18模型的区别,带有残差连接的ResNet18模型的模型效果有明显提示。还通过老师的链接和网络上的博客学习到了torchvision.models.resnet18()并在与高层API实现版本的对比实验进行使用,受益匪浅。
 

你可能感兴趣的:(cnn,深度学习)