PyTorch 学习笔记(七):卷积神经网络案例分析——AlexNet、VGGNet、GoogLeNet、ResNet

我写这篇的目的主要是想熟悉一下PyTorch搭建模型的方法。

一. AlexNet

PyTorch 学习笔记(七):卷积神经网络案例分析——AlexNet、VGGNet、GoogLeNet、ResNet_第1张图片
五个卷积层加3个全连接层,话不多说,直接上代码:

import torch
from torch import nn
from torchstat import stat
class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()      # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),   # b, 64, 55, 55
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),    # b, 64, 27, 27

            nn.Conv2d(64, 192, kernel_size=5, padding=2),      # b, 192, 27, 27
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),   # b, 192, 13, 13

            nn.Conv2d(192, 384, kernel_size=3, padding=1),   # b, 384, 13, 13
            nn.ReLU(True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),   # b, 256, 13, 13
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),   # b, 256, 13, 13
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2))    # b, 256, 6, 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes))
    def forward(self, x):
        x = self.features(x)
        print(x.size())
        x = x.view(x.size(0), 256*6*6)
        x = self.classifier(x)
        return x

model = AlexNet(10)
stat(model, (3, 224, 224))

使用stat模块对模型参数量和计算量进行估计,顺便也验证了模型是否正确,运行结果:

torch.Size([1, 256, 6, 6])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
        module name  input shape output shape      params memory(MB)             MAdd          Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64  55  55     23296.0       0.74    140,553,600.0   70,470,400.0     695296.0     774400.0      55.56%    1469696.0
1        features.1   64  55  55   64  55  55         0.0       0.74        193,600.0      193,600.0     774400.0     774400.0       0.00%    1548800.0
2        features.2   64  55  55   64  27  27         0.0       0.18        373,248.0      193,600.0     774400.0     186624.0       5.57%     961024.0
3        features.3   64  27  27  192  27  27    307392.0       0.53    447,897,600.0  224,088,768.0    1416192.0     559872.0      22.21%    1976064.0
4        features.4  192  27  27  192  27  27         0.0       0.53        139,968.0      139,968.0     559872.0     559872.0       0.00%    1119744.0
5        features.5  192  27  27  192  13  13         0.0       0.12        259,584.0      139,968.0     559872.0     129792.0       0.00%     689664.0
6        features.6  192  13  13  384  13  13    663936.0       0.25    224,280,576.0  112,205,184.0    2785536.0     259584.0       0.00%    3045120.0
7        features.7  384  13  13  384  13  13         0.0       0.25         64,896.0       64,896.0     259584.0     259584.0       0.00%     519168.0
8        features.8  384  13  13  256  13  13    884992.0       0.17    299,040,768.0  149,563,648.0    3799552.0     173056.0       5.56%    3972608.0
9        features.9  256  13  13  256  13  13         0.0       0.17         43,264.0       43,264.0     173056.0     173056.0       0.00%     346112.0
10      features.10  256  13  13  256  13  13    590080.0       0.17    199,360,512.0   99,723,520.0    2533376.0     173056.0       0.00%    2706432.0
11      features.11  256  13  13  256  13  13         0.0       0.17         43,264.0       43,264.0     173056.0     173056.0       0.00%     346112.0
12      features.12  256  13  13  256   6   6         0.0       0.04         73,728.0       43,264.0     173056.0      36864.0       0.00%     209920.0
13     classifier.0         9216         9216         0.0       0.04              0.0            0.0          0.0          0.0       0.00%          0.0
14     classifier.1         9216         4096  37752832.0       0.02     75,493,376.0   37,748,736.0  151048192.0      16384.0       5.56%  151064576.0
15     classifier.2         4096         4096         0.0       0.02          4,096.0        4,096.0      16384.0      16384.0       0.00%      32768.0
16     classifier.3         4096         4096         0.0       0.02              0.0            0.0          0.0          0.0       0.00%          0.0
17     classifier.4         4096         4096  16781312.0       0.02     33,550,336.0   16,777,216.0   67141632.0      16384.0       5.55%   67158016.0
18     classifier.5         4096         4096         0.0       0.02          4,096.0        4,096.0      16384.0      16384.0       0.00%      32768.0
19     classifier.6         4096           10     40970.0       0.00         81,910.0       40,960.0     180264.0         40.0       0.00%     180304.0
total                                          57044810.0       4.15  1,421,458,422.0  711,488,448.0     180264.0         40.0     100.00%  237378896.0
=======================================================================================================================================================
Total params: 57,044,810
-------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 4.15MB
Total MAdd: 1.42GMAdd
Total Flops: 711.49MFlops
Total MemR+W: 226.38MB

二. VGGNet

VGGNet是ImageNet 2014的亚军,总结起来就是它使用了更小的滤波器,同时使用了更深的网络结构,AlexNet只有8层,而VGGNet有16-19层网络,也不像AlexNet那样使用 11x11 这么大的卷积核,它只使用 3x3 的 卷积滤波器和 2x2 的大池化层。

# VGG-16模型

from torch import nn
from torchstat import stat
class VGG(nn.Module):
    def __init__(self, num_classes):
        super(VGG, self).__init__()     # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),   # b, 64, 224, 224
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),    # b, 64, 224, 224
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),     # b, 64, 112, 112

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # b, 128, 112, 112
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),   # b, 128, 112, 112
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # b, 128, 56, 56

            nn.Conv2d(128, 256, kernel_size=3, padding=1),    # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),    # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # b, 256, 56, 56
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),    # b, 256, 28, 28

            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # b, 512, 14, 14

            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # b, 512, 7, 7
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        print(x.size())
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
model = VGG(1000)
stat(model, (3, 224, 224))

使用stat模块对模型参数量和计算量进行估计,顺便也验证了模型是否正确,运行结果:

torch.Size([1, 512, 7, 7])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.25     173,408,256.0      89,915,392.0     609280.0   12845056.0       1.67%   13454336.0
1        features.1   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.00%   25690112.0
2        features.2   64 224 224   64 224 224      36928.0      12.25   3,699,376,128.0   1,852,899,328.0   12992768.0   12845056.0      18.33%   25837824.0
3        features.3   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.00%   25690112.0
4        features.4   64 224 224   64 112 112          0.0       3.06       2,408,448.0       3,211,264.0   12845056.0    3211264.0       3.33%   16056320.0
5        features.5   64 112 112  128 112 112      73856.0       6.12   1,849,688,064.0     926,449,664.0    3506688.0    6422528.0       6.67%    9929216.0
6        features.6  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       1.67%   12845056.0
7        features.7  128 112 112  128 112 112     147584.0       6.12   3,699,376,128.0   1,851,293,696.0    7012864.0    6422528.0      11.67%   13435392.0
8        features.8  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.00%   12845056.0
9        features.9  128 112 112  128  56  56          0.0       1.53       1,204,224.0       1,605,632.0    6422528.0    1605632.0       1.67%    8028160.0
10      features.10  128  56  56  256  56  56     295168.0       3.06   1,849,688,064.0     925,646,848.0    2786304.0    3211264.0       5.00%    5997568.0
11      features.11  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       1.67%    6422528.0
12      features.12  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       8.33%    8782848.0
13      features.13  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.00%    6422528.0
14      features.14  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0      10.00%    8782848.0
15      features.15  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.00%    6422528.0
16      features.16  256  56  56  256  28  28          0.0       0.77         602,112.0         802,816.0    3211264.0     802816.0       0.00%    4014080.0
17      features.17  256  28  28  512  28  28    1180160.0       1.53   1,849,688,064.0     925,245,440.0    5523456.0    1605632.0       3.33%    7129088.0
18      features.18  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       1.67%    3211264.0
19      features.19  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       6.67%   12650496.0
20      features.20  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.00%    3211264.0
21      features.21  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       6.67%   12650496.0
22      features.22  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.00%    3211264.0
23      features.23  512  28  28  512  14  14          0.0       0.38         301,056.0         401,408.0    1605632.0     401408.0       0.00%    2007040.0
24      features.24  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       1.67%   10242048.0
25      features.25  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
26      features.26  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       3.33%   10242048.0
27      features.27  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
28      features.28  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       1.67%   10242048.0
29      features.29  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
30      features.30  512  14  14  512   7   7          0.0       0.10          75,264.0         100,352.0     401408.0     100352.0       0.00%     501760.0
31     classifier.0        25088         4096  102764544.0       0.02     205,516,800.0     102,760,448.0  411158528.0      16384.0       3.33%  411174912.0
32     classifier.1         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
33     classifier.2         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.00%          0.0
34     classifier.3         4096         4096   16781312.0       0.02      33,550,336.0      16,777,216.0   67141632.0      16384.0       1.67%   67158016.0
35     classifier.4         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
36     classifier.5         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.00%          0.0
37     classifier.6         4096         1000    4097000.0       0.00       8,191,000.0       4,096,000.0   16404384.0       4000.0       0.00%   16408384.0
total                                          138357544.0     109.29  30,958,666,264.0  15,503,489,024.0   16404384.0       4000.0     100.00%  783170624.0
============================================================================================================================================================
Total params: 138,357,544
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 109.29MB
Total MAdd: 30.96GMAdd
Total Flops: 15.5GFlops
Total MemR+W: 746.89MB

三. GoogLeNet

GoogLeNet也叫InceptionNet,是在2014年被提出来的,如今已经进化到了v4版本。采用了比VGGNet更深的结构,一共有22层,但参数量缺少了很多。

GoogleNet家族:
(1) V1版本提出了inception的理念,大胆使用了 1 x 1 的卷积核来压缩通道数;
(2) V2版本借鉴了VGG的理念(定制Inception时,在其内部采用标准化卷积核);
(3) V3(2015)版本将VGG的理念发扬广大,将“标准化”推广到一般情况,并加入了BN;
(4) V4 (2016) 版本在V3的基础上选定了合适的超参,没有引入残差的情况下,网络层数仍旧达到了76层。

其最为创新的地方就在于Inception模块,它是一个局部的网络拓扑结构,然后将这些模块堆叠在一起形成一个抽象层网络结构。具体来说就是运行几个并行的滤波器对输入进行卷积和池化,这些滤波器有不同的感受野,最后将输出的结果按深度拼接在一起形成输出层。

下面来实现GoogLeNet中的Iception模块,整个GoogLeNet都是由这些Inception模块组成的。

from torch import nn
import torch
# 卷积 + BN
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv= nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class Inception(nn.Module):
    def __init__(self, in_channels):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branchpool_1 = nn.AvgPool2d(kernel_size=3, stride=1)
        self.branch_pool = BasicConv2d(in_channels, 64, kernel_size=1, padding=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        print('branch1x1_size: ', branch1x1.size())
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        print('branch5x5_size: ', branch5x5.size())
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        print('branch3x3_size: ', branch3x3.size())
        branch_pool = self.branchpool_1(x)
        branch_pool = self.branch_pool(branch_pool)
        print('branch_pool_size: ', branch_pool.size())
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]

        return torch.cat(outputs, 1)

model = Inception(3)
from torchstat import stat
stat(model, (3, 224, 224))

四. ResNet

ResNet是2015年ImageNet竞赛的冠军,由微软研究院提出,通过残差模块能够成功训练出高达152层 的深的神经网络。

ResNet最初设计灵感来源于这样一个问题:在不断加深深度神经网络的时候,会出现一个Degradation,即准确率会先上升然后达到饱和,再继续增加深度则会导致模型准确率下降。这并不是过拟合的问题,因为不仅在测试集合上误差增加,在训练集上误差也增加。假设一个比较浅的网络达到了饱和的准确率,那么在后面加上几个恒等映射层,误差不会增加,也就是说更深的模型起码不会使得模型效果下降。

from torch import nn
from torchstat import stat

def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # downsample对应着一个下采样函数
        self.downsample = downsample
        self.stride = stride
        
    def forward(self, x):
        residual = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

你可能感兴趣的:(pytorch框架)