我写这篇的目的主要是想熟悉一下PyTorch搭建模型的方法。
import torch
from torch import nn
from torchstat import stat
class AlexNet(nn.Module):
    """AlexNet image classifier: five conv layers followed by three fully-connected layers.

    Args:
        num_classes: number of logits produced by the final linear layer.

    The shape comments trace a (b, 3, 224, 224) input. An adaptive average
    pool before the classifier maps the feature map to the fixed 6x6 grid
    the first linear layer expects. At 224x224 the feature map is already 6x6,
    so that pool is an exact identity. Other input sizes work too, as long as
    the feature extractor does not shrink them to nothing.
    """

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()                                # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),     # b, 64, 55, 55
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                     # b, 64, 27, 27
            nn.Conv2d(64, 192, kernel_size=5, padding=2),              # b, 192, 27, 27
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                     # b, 192, 13, 13
            nn.Conv2d(192, 384, kernel_size=3, padding=1),             # b, 384, 13, 13
            nn.ReLU(True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),             # b, 256, 13, 13
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),             # b, 256, 13, 13
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2))                     # b, 256, 6, 6
        # Removes the hard dependency on a 224x224 input: any spatial size is
        # averaged down to 6x6. It has no parameters, so state_dict is unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))                    # b, 256, 6, 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        """Return class logits of shape (b, num_classes) for an image batch ``x``."""
        x = self.features(x)
        # Teaching aid: print the feature-map size to check the layer arithmetic.
        print(x.size())
        x = self.avgpool(x)
        x = x.flatten(1)                                               # b, 256*6*6
        x = self.classifier(x)
        return x
# Script entry: estimate AlexNet's parameter count and compute cost with torchstat.
# Guarded so that importing this module does not build a 57M-parameter model.
if __name__ == "__main__":
    model = AlexNet(10)
    stat(model, (3, 224, 224))
使用stat模块对模型参数量和计算量进行估计,顺便也验证了模型是否正确,运行结果:
torch.Size([1, 256, 6, 6])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 features.0 3 224 224 64 55 55 23296.0 0.74 140,553,600.0 70,470,400.0 695296.0 774400.0 55.56% 1469696.0
1 features.1 64 55 55 64 55 55 0.0 0.74 193,600.0 193,600.0 774400.0 774400.0 0.00% 1548800.0
2 features.2 64 55 55 64 27 27 0.0 0.18 373,248.0 193,600.0 774400.0 186624.0 5.57% 961024.0
3 features.3 64 27 27 192 27 27 307392.0 0.53 447,897,600.0 224,088,768.0 1416192.0 559872.0 22.21% 1976064.0
4 features.4 192 27 27 192 27 27 0.0 0.53 139,968.0 139,968.0 559872.0 559872.0 0.00% 1119744.0
5 features.5 192 27 27 192 13 13 0.0 0.12 259,584.0 139,968.0 559872.0 129792.0 0.00% 689664.0
6 features.6 192 13 13 384 13 13 663936.0 0.25 224,280,576.0 112,205,184.0 2785536.0 259584.0 0.00% 3045120.0
7 features.7 384 13 13 384 13 13 0.0 0.25 64,896.0 64,896.0 259584.0 259584.0 0.00% 519168.0
8 features.8 384 13 13 256 13 13 884992.0 0.17 299,040,768.0 149,563,648.0 3799552.0 173056.0 5.56% 3972608.0
9 features.9 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 0.00% 346112.0
10 features.10 256 13 13 256 13 13 590080.0 0.17 199,360,512.0 99,723,520.0 2533376.0 173056.0 0.00% 2706432.0
11 features.11 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 0.00% 346112.0
12 features.12 256 13 13 256 6 6 0.0 0.04 73,728.0 43,264.0 173056.0 36864.0 0.00% 209920.0
13 classifier.0 9216 9216 0.0 0.04 0.0 0.0 0.0 0.0 0.00% 0.0
14 classifier.1 9216 4096 37752832.0 0.02 75,493,376.0 37,748,736.0 151048192.0 16384.0 5.56% 151064576.0
15 classifier.2 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
16 classifier.3 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0
17 classifier.4 4096 4096 16781312.0 0.02 33,550,336.0 16,777,216.0 67141632.0 16384.0 5.55% 67158016.0
18 classifier.5 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
19 classifier.6 4096 10 40970.0 0.00 81,910.0 40,960.0 180264.0 40.0 0.00% 180304.0
total 57044810.0 4.15 1,421,458,422.0 711,488,448.0 180264.0 40.0 100.00% 237378896.0
=======================================================================================================================================================
Total params: 57,044,810
-------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 4.15MB
Total MAdd: 1.42GMAdd
Total Flops: 711.49MFlops
Total MemR+W: 226.38MB
VGGNet是ImageNet 2014(ILSVRC-2014)分类任务的亚军。总结起来,它使用了更小的滤波器,同时使用了更深的网络结构:AlexNet只有8层,而VGGNet有16-19层;它也不像AlexNet那样使用 11x11 这么大的卷积核,而是只使用 3x3 的卷积滤波器和 2x2 的最大池化层。
# VGG-16模型
from torch import nn
from torchstat import stat
class VGG(nn.Module):
    """VGG-16 classifier.

    The network has thirteen 3x3 conv layers in five stages, each stage
    ending in 2x2 max pooling, followed by three fully-connected layers.

    Args:
        num_classes: number of logits produced by the final linear layer.

    The shape comments trace a (b, 3, 224, 224) input. An adaptive average
    pool maps the last feature map to the 7x7 grid the classifier expects.
    At 224x224 that pool is an exact identity, and other input sizes are
    accepted as well.
    """

    def __init__(self, num_classes):
        super(VGG, self).__init__()                                    # b, 3, 224, 224
        self.features = nn.Sequential(
            # stage 1
            nn.Conv2d(3, 64, kernel_size=3, padding=1),                # b, 64, 224, 224
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),               # b, 64, 224, 224
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),                     # b, 64, 112, 112
            # stage 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),              # b, 128, 112, 112
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),             # b, 128, 112, 112
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),                     # b, 128, 56, 56
            # stage 3
            nn.Conv2d(128, 256, kernel_size=3, padding=1),             # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),             # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),             # b, 256, 56, 56
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),                     # b, 256, 28, 28
            # stage 4
            nn.Conv2d(256, 512, kernel_size=3, padding=1),             # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),             # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),             # b, 512, 28, 28
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),                     # b, 512, 14, 14
            # stage 5
            nn.Conv2d(512, 512, kernel_size=3, padding=1),             # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),             # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),             # b, 512, 14, 14
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))                     # b, 512, 7, 7
        # Parameter-free; lets the classifier accept inputs other than 224x224.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))                    # b, 512, 7, 7
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        """Return class logits of shape (b, num_classes) for an image batch ``x``."""
        x = self.features(x)
        # Teaching aid: print the feature-map size to check the layer arithmetic.
        print(x.size())
        x = self.avgpool(x)
        x = x.flatten(1)                                               # b, 512*7*7
        x = self.classifier(x)
        return x
# Script entry: estimate VGG-16's parameter count and compute cost with torchstat.
# Guarded so that importing this module does not build a 138M-parameter model.
if __name__ == "__main__":
    model = VGG(1000)
    stat(model, (3, 224, 224))
使用stat模块对模型参数量和计算量进行估计,顺便也验证了模型是否正确,运行结果:
torch.Size([1, 512, 7, 7])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 features.0 3 224 224 64 224 224 1792.0 12.25 173,408,256.0 89,915,392.0 609280.0 12845056.0 1.67% 13454336.0
1 features.1 64 224 224 64 224 224 0.0 12.25 3,211,264.0 3,211,264.0 12845056.0 12845056.0 0.00% 25690112.0
2 features.2 64 224 224 64 224 224 36928.0 12.25 3,699,376,128.0 1,852,899,328.0 12992768.0 12845056.0 18.33% 25837824.0
3 features.3 64 224 224 64 224 224 0.0 12.25 3,211,264.0 3,211,264.0 12845056.0 12845056.0 0.00% 25690112.0
4 features.4 64 224 224 64 112 112 0.0 3.06 2,408,448.0 3,211,264.0 12845056.0 3211264.0 3.33% 16056320.0
5 features.5 64 112 112 128 112 112 73856.0 6.12 1,849,688,064.0 926,449,664.0 3506688.0 6422528.0 6.67% 9929216.0
6 features.6 128 112 112 128 112 112 0.0 6.12 1,605,632.0 1,605,632.0 6422528.0 6422528.0 1.67% 12845056.0
7 features.7 128 112 112 128 112 112 147584.0 6.12 3,699,376,128.0 1,851,293,696.0 7012864.0 6422528.0 11.67% 13435392.0
8 features.8 128 112 112 128 112 112 0.0 6.12 1,605,632.0 1,605,632.0 6422528.0 6422528.0 0.00% 12845056.0
9 features.9 128 112 112 128 56 56 0.0 1.53 1,204,224.0 1,605,632.0 6422528.0 1605632.0 1.67% 8028160.0
10 features.10 128 56 56 256 56 56 295168.0 3.06 1,849,688,064.0 925,646,848.0 2786304.0 3211264.0 5.00% 5997568.0
11 features.11 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 1.67% 6422528.0
12 features.12 256 56 56 256 56 56 590080.0 3.06 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 8.33% 8782848.0
13 features.13 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.00% 6422528.0
14 features.14 256 56 56 256 56 56 590080.0 3.06 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 10.00% 8782848.0
15 features.15 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.00% 6422528.0
16 features.16 256 56 56 256 28 28 0.0 0.77 602,112.0 802,816.0 3211264.0 802816.0 0.00% 4014080.0
17 features.17 256 28 28 512 28 28 1180160.0 1.53 1,849,688,064.0 925,245,440.0 5523456.0 1605632.0 3.33% 7129088.0
18 features.18 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 1.67% 3211264.0
19 features.19 512 28 28 512 28 28 2359808.0 1.53 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 6.67% 12650496.0
20 features.20 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.00% 3211264.0
21 features.21 512 28 28 512 28 28 2359808.0 1.53 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 6.67% 12650496.0
22 features.22 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.00% 3211264.0
23 features.23 512 28 28 512 14 14 0.0 0.38 301,056.0 401,408.0 1605632.0 401408.0 0.00% 2007040.0
24 features.24 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 1.67% 10242048.0
25 features.25 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
26 features.26 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 3.33% 10242048.0
27 features.27 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
28 features.28 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 1.67% 10242048.0
29 features.29 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
30 features.30 512 14 14 512 7 7 0.0 0.10 75,264.0 100,352.0 401408.0 100352.0 0.00% 501760.0
31 classifier.0 25088 4096 102764544.0 0.02 205,516,800.0 102,760,448.0 411158528.0 16384.0 3.33% 411174912.0
32 classifier.1 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
33 classifier.2 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0
34 classifier.3 4096 4096 16781312.0 0.02 33,550,336.0 16,777,216.0 67141632.0 16384.0 1.67% 67158016.0
35 classifier.4 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
36 classifier.5 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0
37 classifier.6 4096 1000 4097000.0 0.00 8,191,000.0 4,096,000.0 16404384.0 4000.0 0.00% 16408384.0
total 138357544.0 109.29 30,958,666,264.0 15,503,489,024.0 16404384.0 4000.0 100.00% 783170624.0
============================================================================================================================================================
Total params: 138,357,544
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 109.29MB
Total MAdd: 30.96GMAdd
Total Flops: 15.5GFlops
Total MemR+W: 746.89MB
GoogLeNet也叫InceptionNet,是在2014年被提出来的,如今已经进化到了v4版本。它采用了比VGGNet更深的结构,一共有22层,但参数量却少了很多。
GoogleNet家族:
(1) V1版本提出了inception的理念,大胆使用了 1 x 1 的卷积核来压缩通道数;
(2) V2版本借鉴了VGG的理念(定制Inception时,在其内部采用标准化卷积核);
(3) V3(2015)版本将VGG的理念发扬广大,将“标准化”推广到一般情况,并加入了BN;
(4) V4 (2016) 版本在V3的基础上选定了合适的超参,没有引入残差的情况下,网络层数仍旧达到了76层。
其最为创新的地方就在于Inception模块,它是一个局部的网络拓扑结构,然后将这些模块堆叠在一起形成一个抽象层网络结构。具体来说就是运行几个并行的滤波器对输入进行卷积和池化,这些滤波器有不同的感受野,最后将输出的结果按深度拼接在一起形成输出层。
下面来实现GoogLeNet中的Inception模块,整个GoogLeNet的主体都是由这些Inception模块堆叠而成的。
from torch import nn
import torch
# Conv + BN + ReLU
class BasicConv2d(nn.Module):
    """Bias-free convolution, then batch norm, then an in-place ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also BN's feature count).
        **kwargs: forwarded to ``nn.Conv2d`` (e.g. ``kernel_size``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # The conv bias is redundant: BN applies its own learnable per-channel shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        # Without a non-linearity, chained BasicConv2d layers (such as the
        # 1x1 -> 3x3 -> 3x3 Inception branch) would compose into one affine map
        # at inference time.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along the channel axis.

    Branches (output channels):
      * 1x1 conv (64)
      * 1x1 conv -> 5x5 conv (64)
      * 1x1 conv -> 3x3 conv -> 3x3 conv (96)
      * 3x3 average pool -> 1x1 conv (64)

    That gives 64 + 64 + 96 + 64 = 288 output channels. Every branch keeps the
    input's height and width, which ``torch.cat`` requires.

    Args:
        in_channels: channel count of the input feature map.
    """

    def __init__(self, in_channels):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        # The size-preserving padding belongs on the pooling layer. You could
        # also pool without padding and then zero-pad the 1x1 conv to restore
        # HxW, but then the border ring would be computed from padding alone
        # rather than from the input.
        self.branchpool_1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = BasicConv2d(in_channels, 64, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1 (channels)."""
        branch1x1 = self.branch1x1(x)
        print('branch1x1_size: ', branch1x1.size())

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        print('branch5x5_size: ', branch5x5.size())

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        print('branch3x3_size: ', branch3x3.size())

        branch_pool = self.branchpool_1(x)
        branch_pool = self.branch_pool(branch_pool)
        print('branch_pool_size: ', branch_pool.size())

        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, 1)
from torchstat import stat

# Script entry: run torchstat on a single Inception block over a 3x224x224 input.
# Guarded so that merely importing this module has no side effects.
if __name__ == "__main__":
    model = Inception(3)
    stat(model, (3, 224, 224))
ResNet是2015年ImageNet竞赛的冠军,由微软研究院提出,通过残差模块成功训练出了高达152层的深度神经网络。
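ResNet最初的设计灵感来源于这样一个问题:在不断加深深度神经网络的时候,会出现退化(Degradation)问题,即准确率会先上升然后达到饱和,再继续增加深度则会导致模型准确率下降。这并不是过拟合的问题,因为不仅在测试集上误差增加,在训练集上误差也增加。假设一个比较浅的网络达到了饱和的准确率,那么在后面加上几个恒等映射层,误差不会增加,也就是说更深的模型起码不会使得模型效果下降。ResNet正是利用了这一点:与其让堆叠的层直接拟合目标映射 H(x),不如让它们拟合残差 F(x) = H(x) - x,再通过跳跃连接把输入 x 加回去,得到 F(x) + x;如果恒等映射已经是最优的,网络只需把 F(x) 推向 0 即可,这比直接学出恒等映射容易得多。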
from torch import nn
from torchstat import stat
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and no bias term.

    Padding of 1 keeps height and width unchanged when ``stride`` is 1. The
    bias is omitted because, as in BasicBlock, each such conv feeds a
    BatchNorm2d layer.
    """
    kernel = 3
    return nn.Conv2d(in_planes, out_planes, kernel, stride,
                     padding=kernel // 2, bias=False)
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs plus a shortcut connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        inplanes: input channel count.
        planes: output channel count of both convolutions.
        stride: stride of the first conv; values > 1 downsample spatially.
        downsample: optional module applied to ``x`` to build the shortcut.
            ``None`` means an identity shortcut. In that case ``x`` must
            already match the main path's output shape for the in-place add.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Projection applied to the shortcut when the main path changes shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        # No ReLU here: the residual branch must be able to contribute negative
        # corrections. The activation is applied only after the shortcut add.

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out