Table of Contents
I. Introduction
II. Network Structure
(1) The hard-swish activation function
(2) The bneck block
(3) Network structure
III. Parameter Count
IV. Code
V. Training Results
VI. Full Code
MobileNet_v3 improves on both MobileNet_v2 and MnasNet.
First, a brief introduction to MnasNet. MnasNet was proposed between MobileNet_v2 and MobileNet_v3 and relies mainly on NAS (neural architecture search) to find different blocks. Most earlier NAS work searched for a single cell and then stacked that cell repeatedly, which tends to limit the resulting network (mainly in accuracy and latency). MnasNet increases network diversity by searching for different blocks, and, to better reflect the latency of a model deployed on a phone, it uses multi-objective optimization that balances on-device latency against model accuracy.
MobileNet_v3 starts from a NAS-searched model, then changes the first layer (the first conv layer originally had 32 kernels, reduced to 16) and modifies the last (classification) stage. The paper also proposes the hard-swish nonlinearity; the authors show experimentally that it only helps in the deeper layers, which is why the first few layers in the network structure still use ReLU.
MobileNet_v3 keeps the inverted residual structure from MobileNet_v2 and adds the attention (SE) mechanism from MnasNet; together these form the paper's bneck block. The paper proposes two models, MobileNetV3-Small and MobileNetV3-Large; the code in this post implements MobileNetV3-Large.
hard-swish is an optimization of the swish activation that greatly reduces its computational cost.
The swish activation is defined as swish(x) = x · sigmoid(x), and hard-swish as hard-swish(x) = x · ReLU6(x + 3) / 6. Clearly ReLU6 is much cheaper to compute than sigmoid, and as the figure below (from the paper) shows, hard-swish behaves almost identically to swish.
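To check this numerically, here is a minimal sketch of my own (assuming a PyTorch version recent enough to ship nn.SiLU and nn.Hardswish) comparing the two activations:

import torch
import torch.nn as nn

x = torch.linspace(-6, 6, steps=13)

# manual formulas: swish(x) = x * sigmoid(x), hard-swish(x) = x * ReLU6(x + 3) / 6
swish = x * torch.sigmoid(x)
hard_swish = x * torch.clamp(x + 3, min=0, max=6) / 6

# PyTorch's built-in modules match the manual formulas
print(torch.allclose(nn.SiLU()(x), swish))            # True (SiLU is swish)
print(torch.allclose(nn.Hardswish()(x), hard_swish))  # True
print((swish - hard_swish).abs().max())               # small: the two curves nearly coincide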
The block first raises the channel count with a 1×1 convolution, applies a depthwise convolution to the expanded feature map, then passes it through an SE module (which in effect computes a scaling factor for each channel of the expanded feature map), and finally projects back down with a 1×1 convolution. This is essentially MobileNet_v2's block with an SE module inserted in the middle. The shortcut is the same as in MobileNet_v2: a residual connection exists only when the input channel count equals the output channel count and the stride is 1.
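As a minimal sketch of that data flow (the layer objects are left abstract here; the full bneckModule implementation appears in the code section below):

def bneck_forward(x, expand_1x1, depthwise, se, project_1x1, use_shortcut):
    out = expand_1x1(x)     # 1x1 conv: raise channels to exp_size
    out = depthwise(out)    # depthwise conv on the expanded feature map
    out = se(out)           # per-channel rescaling (identity when SE is off)
    out = project_1x1(out)  # 1x1 conv: project channels back down
    # use_shortcut: in_channels == out_channels and stride == 1
    return out + x if use_shortcut else out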
This is the network structure given in the paper. Note that the first conv layer has 16 kernels and uses the HS activation. In the table, exp_size is the channel count after the expansion inside a bneck, SE indicates whether an SE module is used, NL is the type of nonlinearity (HS = hard-swish, RE = ReLU), and s is the stride.
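For instance, the table row with a 16-channel input, exp_size 64, 24 output channels, no SE, NL = RE, and s = 2 maps directly onto one entry of the net_config list in the code below:

# [inchannel, expand_channels, outchannels, kernel_size, stride, SE, activate]
[16, 64, 24, 3, 2, False, 'RE']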
The last stage before and after the modification (figure from the paper): the modified version clearly needs far less computation, and the paper reports a latency reduction of about 7 ms. Also note that NBN on the final classification layers means no BatchNormalization is used there.
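A rough sketch of the modified last stage (my own reconstruction from the paper's figure; the implementation in this post instead flattens after the global pooling and uses Linear layers):

import torch.nn as nn

# Efficient last stage: pool before the final expansion so the expensive
# 1x1 convs run on a 1x1 feature map; the last two convs have no BN ("NBN").
last_stage = nn.Sequential(
    nn.Conv2d(160, 960, 1), nn.BatchNorm2d(960), nn.Hardswish(),  # last bneck output -> 960
    nn.AdaptiveAvgPool2d(1),                                      # 7x7 -> 1x1
    nn.Conv2d(960, 1280, 1), nn.Hardswish(),                      # NBN
    nn.Conv2d(1280, 1000, 1),                                     # NBN, 1000-way classifier
)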
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 112, 112] 432
BatchNorm2d-2 [-1, 16, 112, 112] 32
Hardswish-3 [-1, 16, 112, 112] 0
baseConv-4 [-1, 16, 112, 112] 0
Conv2d-5 [-1, 16, 112, 112] 144
BatchNorm2d-6 [-1, 16, 112, 112] 32
Hardswish-7 [-1, 16, 112, 112] 0
baseConv-8 [-1, 16, 112, 112] 0
Conv2d-9 [-1, 16, 112, 112] 256
BatchNorm2d-10 [-1, 16, 112, 112] 32
Identity-11 [-1, 16, 112, 112] 0
baseConv-12 [-1, 16, 112, 112] 0
bneckModule-13 [-1, 16, 112, 112] 0
Conv2d-14 [-1, 64, 112, 112] 1,024
BatchNorm2d-15 [-1, 64, 112, 112] 128
ReLU6-16 [-1, 64, 112, 112] 0
baseConv-17 [-1, 64, 112, 112] 0
Conv2d-18 [-1, 64, 56, 56] 576
BatchNorm2d-19 [-1, 64, 56, 56] 128
ReLU6-20 [-1, 64, 56, 56] 0
baseConv-21 [-1, 64, 56, 56] 0
Conv2d-22 [-1, 24, 56, 56] 1,536
BatchNorm2d-23 [-1, 24, 56, 56] 48
Identity-24 [-1, 24, 56, 56] 0
baseConv-25 [-1, 24, 56, 56] 0
bneckModule-26 [-1, 24, 56, 56] 0
Conv2d-27 [-1, 72, 56, 56] 1,728
BatchNorm2d-28 [-1, 72, 56, 56] 144
ReLU6-29 [-1, 72, 56, 56] 0
baseConv-30 [-1, 72, 56, 56] 0
Conv2d-31 [-1, 72, 56, 56] 648
BatchNorm2d-32 [-1, 72, 56, 56] 144
ReLU6-33 [-1, 72, 56, 56] 0
baseConv-34 [-1, 72, 56, 56] 0
Conv2d-35 [-1, 24, 56, 56] 1,728
BatchNorm2d-36 [-1, 24, 56, 56] 48
Identity-37 [-1, 24, 56, 56] 0
baseConv-38 [-1, 24, 56, 56] 0
bneckModule-39 [-1, 24, 56, 56] 0
Conv2d-40 [-1, 72, 56, 56] 1,728
BatchNorm2d-41 [-1, 72, 56, 56] 144
ReLU6-42 [-1, 72, 56, 56] 0
baseConv-43 [-1, 72, 56, 56] 0
Conv2d-44 [-1, 72, 28, 28] 1,800
BatchNorm2d-45 [-1, 72, 28, 28] 144
ReLU6-46 [-1, 72, 28, 28] 0
baseConv-47 [-1, 72, 28, 28] 0
AdaptiveAvgPool2d-48 [-1, 72, 1, 1] 0
Conv2d-49 [-1, 18, 1, 1] 1,314
ReLU6-50 [-1, 18, 1, 1] 0
Conv2d-51 [-1, 72, 1, 1] 1,368
Hardswish-52 [-1, 72, 1, 1] 0
SEModule-53 [-1, 72, 28, 28] 0
Conv2d-54 [-1, 40, 28, 28] 2,880
BatchNorm2d-55 [-1, 40, 28, 28] 80
Identity-56 [-1, 40, 28, 28] 0
baseConv-57 [-1, 40, 28, 28] 0
bneckModule-58 [-1, 40, 28, 28] 0
Conv2d-59 [-1, 120, 28, 28] 4,800
BatchNorm2d-60 [-1, 120, 28, 28] 240
ReLU6-61 [-1, 120, 28, 28] 0
baseConv-62 [-1, 120, 28, 28] 0
Conv2d-63 [-1, 120, 28, 28] 3,000
BatchNorm2d-64 [-1, 120, 28, 28] 240
ReLU6-65 [-1, 120, 28, 28] 0
baseConv-66 [-1, 120, 28, 28] 0
AdaptiveAvgPool2d-67 [-1, 120, 1, 1] 0
Conv2d-68 [-1, 30, 1, 1] 3,630
ReLU6-69 [-1, 30, 1, 1] 0
Conv2d-70 [-1, 120, 1, 1] 3,720
Hardswish-71 [-1, 120, 1, 1] 0
SEModule-72 [-1, 120, 28, 28] 0
Conv2d-73 [-1, 40, 28, 28] 4,800
BatchNorm2d-74 [-1, 40, 28, 28] 80
Identity-75 [-1, 40, 28, 28] 0
baseConv-76 [-1, 40, 28, 28] 0
bneckModule-77 [-1, 40, 28, 28] 0
Conv2d-78 [-1, 120, 28, 28] 4,800
BatchNorm2d-79 [-1, 120, 28, 28] 240
ReLU6-80 [-1, 120, 28, 28] 0
baseConv-81 [-1, 120, 28, 28] 0
Conv2d-82 [-1, 120, 28, 28] 3,000
BatchNorm2d-83 [-1, 120, 28, 28] 240
ReLU6-84 [-1, 120, 28, 28] 0
baseConv-85 [-1, 120, 28, 28] 0
AdaptiveAvgPool2d-86 [-1, 120, 1, 1] 0
Conv2d-87 [-1, 30, 1, 1] 3,630
ReLU6-88 [-1, 30, 1, 1] 0
Conv2d-89 [-1, 120, 1, 1] 3,720
Hardswish-90 [-1, 120, 1, 1] 0
SEModule-91 [-1, 120, 28, 28] 0
Conv2d-92 [-1, 40, 28, 28] 4,800
BatchNorm2d-93 [-1, 40, 28, 28] 80
Identity-94 [-1, 40, 28, 28] 0
baseConv-95 [-1, 40, 28, 28] 0
bneckModule-96 [-1, 40, 28, 28] 0
Conv2d-97 [-1, 240, 28, 28] 9,600
BatchNorm2d-98 [-1, 240, 28, 28] 480
Hardswish-99 [-1, 240, 28, 28] 0
baseConv-100 [-1, 240, 28, 28] 0
Conv2d-101 [-1, 240, 14, 14] 2,160
BatchNorm2d-102 [-1, 240, 14, 14] 480
Hardswish-103 [-1, 240, 14, 14] 0
baseConv-104 [-1, 240, 14, 14] 0
Conv2d-105 [-1, 80, 14, 14] 19,200
BatchNorm2d-106 [-1, 80, 14, 14] 160
Identity-107 [-1, 80, 14, 14] 0
baseConv-108 [-1, 80, 14, 14] 0
bneckModule-109 [-1, 80, 14, 14] 0
Conv2d-110 [-1, 200, 14, 14] 16,000
BatchNorm2d-111 [-1, 200, 14, 14] 400
Hardswish-112 [-1, 200, 14, 14] 0
baseConv-113 [-1, 200, 14, 14] 0
Conv2d-114 [-1, 200, 14, 14] 1,800
BatchNorm2d-115 [-1, 200, 14, 14] 400
Hardswish-116 [-1, 200, 14, 14] 0
baseConv-117 [-1, 200, 14, 14] 0
Conv2d-118 [-1, 80, 14, 14] 16,000
BatchNorm2d-119 [-1, 80, 14, 14] 160
Identity-120 [-1, 80, 14, 14] 0
baseConv-121 [-1, 80, 14, 14] 0
bneckModule-122 [-1, 80, 14, 14] 0
Conv2d-123 [-1, 184, 14, 14] 14,720
BatchNorm2d-124 [-1, 184, 14, 14] 368
Hardswish-125 [-1, 184, 14, 14] 0
baseConv-126 [-1, 184, 14, 14] 0
Conv2d-127 [-1, 184, 14, 14] 1,656
BatchNorm2d-128 [-1, 184, 14, 14] 368
Hardswish-129 [-1, 184, 14, 14] 0
baseConv-130 [-1, 184, 14, 14] 0
Conv2d-131 [-1, 80, 14, 14] 14,720
BatchNorm2d-132 [-1, 80, 14, 14] 160
Identity-133 [-1, 80, 14, 14] 0
baseConv-134 [-1, 80, 14, 14] 0
bneckModule-135 [-1, 80, 14, 14] 0
Conv2d-136 [-1, 184, 14, 14] 14,720
BatchNorm2d-137 [-1, 184, 14, 14] 368
Hardswish-138 [-1, 184, 14, 14] 0
baseConv-139 [-1, 184, 14, 14] 0
Conv2d-140 [-1, 184, 14, 14] 1,656
BatchNorm2d-141 [-1, 184, 14, 14] 368
Hardswish-142 [-1, 184, 14, 14] 0
baseConv-143 [-1, 184, 14, 14] 0
Conv2d-144 [-1, 80, 14, 14] 14,720
BatchNorm2d-145 [-1, 80, 14, 14] 160
Identity-146 [-1, 80, 14, 14] 0
baseConv-147 [-1, 80, 14, 14] 0
bneckModule-148 [-1, 80, 14, 14] 0
Conv2d-149 [-1, 480, 14, 14] 38,400
BatchNorm2d-150 [-1, 480, 14, 14] 960
Hardswish-151 [-1, 480, 14, 14] 0
baseConv-152 [-1, 480, 14, 14] 0
Conv2d-153 [-1, 480, 14, 14] 4,320
BatchNorm2d-154 [-1, 480, 14, 14] 960
Hardswish-155 [-1, 480, 14, 14] 0
baseConv-156 [-1, 480, 14, 14] 0
AdaptiveAvgPool2d-157 [-1, 480, 1, 1] 0
Conv2d-158 [-1, 120, 1, 1] 57,720
ReLU6-159 [-1, 120, 1, 1] 0
Conv2d-160 [-1, 480, 1, 1] 58,080
Hardswish-161 [-1, 480, 1, 1] 0
SEModule-162 [-1, 480, 14, 14] 0
Conv2d-163 [-1, 112, 14, 14] 53,760
BatchNorm2d-164 [-1, 112, 14, 14] 224
Identity-165 [-1, 112, 14, 14] 0
baseConv-166 [-1, 112, 14, 14] 0
bneckModule-167 [-1, 112, 14, 14] 0
Conv2d-168 [-1, 672, 14, 14] 75,264
BatchNorm2d-169 [-1, 672, 14, 14] 1,344
Hardswish-170 [-1, 672, 14, 14] 0
baseConv-171 [-1, 672, 14, 14] 0
Conv2d-172 [-1, 672, 14, 14] 6,048
BatchNorm2d-173 [-1, 672, 14, 14] 1,344
Hardswish-174 [-1, 672, 14, 14] 0
baseConv-175 [-1, 672, 14, 14] 0
AdaptiveAvgPool2d-176 [-1, 672, 1, 1] 0
Conv2d-177 [-1, 168, 1, 1] 113,064
ReLU6-178 [-1, 168, 1, 1] 0
Conv2d-179 [-1, 672, 1, 1] 113,568
Hardswish-180 [-1, 672, 1, 1] 0
SEModule-181 [-1, 672, 14, 14] 0
Conv2d-182 [-1, 112, 14, 14] 75,264
BatchNorm2d-183 [-1, 112, 14, 14] 224
Identity-184 [-1, 112, 14, 14] 0
baseConv-185 [-1, 112, 14, 14] 0
bneckModule-186 [-1, 112, 14, 14] 0
Conv2d-187 [-1, 672, 14, 14] 75,264
BatchNorm2d-188 [-1, 672, 14, 14] 1,344
Hardswish-189 [-1, 672, 14, 14] 0
baseConv-190 [-1, 672, 14, 14] 0
Conv2d-191 [-1, 672, 7, 7] 16,800
BatchNorm2d-192 [-1, 672, 7, 7] 1,344
Hardswish-193 [-1, 672, 7, 7] 0
baseConv-194 [-1, 672, 7, 7] 0
AdaptiveAvgPool2d-195 [-1, 672, 1, 1] 0
Conv2d-196 [-1, 168, 1, 1] 113,064
ReLU6-197 [-1, 168, 1, 1] 0
Conv2d-198 [-1, 672, 1, 1] 113,568
Hardswish-199 [-1, 672, 1, 1] 0
SEModule-200 [-1, 672, 7, 7] 0
Conv2d-201 [-1, 160, 7, 7] 107,520
BatchNorm2d-202 [-1, 160, 7, 7] 320
Identity-203 [-1, 160, 7, 7] 0
baseConv-204 [-1, 160, 7, 7] 0
bneckModule-205 [-1, 160, 7, 7] 0
Conv2d-206 [-1, 960, 7, 7] 153,600
BatchNorm2d-207 [-1, 960, 7, 7] 1,920
Hardswish-208 [-1, 960, 7, 7] 0
baseConv-209 [-1, 960, 7, 7] 0
Conv2d-210 [-1, 960, 7, 7] 24,000
BatchNorm2d-211 [-1, 960, 7, 7] 1,920
Hardswish-212 [-1, 960, 7, 7] 0
baseConv-213 [-1, 960, 7, 7] 0
AdaptiveAvgPool2d-214 [-1, 960, 1, 1] 0
Conv2d-215 [-1, 240, 1, 1] 230,640
ReLU6-216 [-1, 240, 1, 1] 0
Conv2d-217 [-1, 960, 1, 1] 231,360
Hardswish-218 [-1, 960, 1, 1] 0
SEModule-219 [-1, 960, 7, 7] 0
Conv2d-220 [-1, 160, 7, 7] 153,600
BatchNorm2d-221 [-1, 160, 7, 7] 320
Identity-222 [-1, 160, 7, 7] 0
baseConv-223 [-1, 160, 7, 7] 0
bneckModule-224 [-1, 160, 7, 7] 0
Conv2d-225 [-1, 960, 7, 7] 153,600
BatchNorm2d-226 [-1, 960, 7, 7] 1,920
Hardswish-227 [-1, 960, 7, 7] 0
baseConv-228 [-1, 960, 7, 7] 0
Conv2d-229 [-1, 960, 7, 7] 24,000
BatchNorm2d-230 [-1, 960, 7, 7] 1,920
Hardswish-231 [-1, 960, 7, 7] 0
baseConv-232 [-1, 960, 7, 7] 0
AdaptiveAvgPool2d-233 [-1, 960, 1, 1] 0
Conv2d-234 [-1, 240, 1, 1] 230,640
ReLU6-235 [-1, 240, 1, 1] 0
Conv2d-236 [-1, 960, 1, 1] 231,360
Hardswish-237 [-1, 960, 1, 1] 0
SEModule-238 [-1, 960, 7, 7] 0
Conv2d-239 [-1, 160, 7, 7] 153,600
BatchNorm2d-240 [-1, 160, 7, 7] 320
Identity-241 [-1, 160, 7, 7] 0
baseConv-242 [-1, 160, 7, 7] 0
bneckModule-243 [-1, 160, 7, 7] 0
Conv2d-244 [-1, 960, 7, 7] 153,600
BatchNorm2d-245 [-1, 960, 7, 7] 1,920
Hardswish-246 [-1, 960, 7, 7] 0
baseConv-247 [-1, 960, 7, 7] 0
AdaptiveAvgPool2d-248 [-1, 960, 1, 1] 0
Linear-249 [-1, 1280] 1,230,080
Hardswish-250 [-1, 1280] 0
Dropout-251 [-1, 1280] 0
Linear-252 [-1, 10] 12,810
================================================================
Total params: 4,213,008
Trainable params: 4,213,008
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 143.36
Params size (MB): 16.07
Estimated Total Size (MB): 160.01
----------------------------------------------------------------
import torch
import torch.nn as nn
from collections import OrderedDict
from torchsummary import summary


# Basic Conv-BN-activation block
class baseConv(nn.Module):
    def __init__(self, inchannel, outchannel, kernel_size, stride, groups=1, active=None, bias=False):
        super(baseConv, self).__init__()
        # choose the activation: 'HS' -> Hardswish, 'RE' -> ReLU6, anything else -> no activation
        if active == 'HS':
            ac = nn.Hardswish
        elif active == 'RE':
            ac = nn.ReLU6
        else:
            ac = nn.Identity
        pad = kernel_size // 2
        self.base = nn.Sequential(
            nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=kernel_size,
                      stride=stride, padding=pad, groups=groups, bias=bias),
            nn.BatchNorm2d(outchannel),
            ac()
        )

    def forward(self, x):
        x = self.base(x)
        return x


# SE (squeeze-and-excitation) module
class SEModule(nn.Module):
    def __init__(self, inchannels):
        super(SEModule, self).__init__()
        hidden_channel = int(inchannels / 4)  # squeeze to a quarter of the channels
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Sequential(
            nn.Conv2d(inchannels, hidden_channel, 1),
            nn.ReLU6()
        )
        self.linear2 = nn.Sequential(
            nn.Conv2d(hidden_channel, inchannels, 1),
            nn.Hardswish()
        )

    def forward(self, x):
        out = self.pool(x)
        out = self.linear1(out)
        out = self.linear2(out)
        return out * x  # rescale each channel of the input


# bneck block
class bneckModule(nn.Module):
    def __init__(self, inchannels, expand_channels, outchannels, kernel_size, stride, SE, activate):
        super(bneckModule, self).__init__()
        self.module = []  # holds this block's layers
        if inchannels != expand_channels:  # the 1x1 expansion layer is only needed when the channel counts differ
            self.module.append(baseConv(inchannels, expand_channels, kernel_size=1, stride=1, active=activate))
        self.module.append(baseConv(expand_channels, expand_channels, kernel_size=kernel_size,
                                    stride=stride, active=activate, groups=expand_channels))
        # optional SE module
        if SE:
            self.module.append(SEModule(expand_channels))
        self.module.append(baseConv(expand_channels, outchannels, 1, 1))
        self.module = nn.Sequential(*self.module)
        # the residual shortcut exists only when input and output shapes match
        self.residual = (inchannels == outchannels and stride == 1)

    def forward(self, x):
        out1 = self.module(x)
        if self.residual:
            return out1 + x
        else:
            return out1


# MobileNet_v3 (Large)
class mobilenet_v3(nn.Module):
    def __init__(self, num_classes, init_weight=True):
        super(mobilenet_v3, self).__init__()
        # [inchannel, expand_channels, outchannels, kernel_size, stride, SE, activate]
        net_config = [[16, 16, 16, 3, 1, False, 'HS'],
                      [16, 64, 24, 3, 2, False, 'RE'],
                      [24, 72, 24, 3, 1, False, 'RE'],
                      [24, 72, 40, 5, 2, True, 'RE'],
                      [40, 120, 40, 5, 1, True, 'RE'],
                      [40, 120, 40, 5, 1, True, 'RE'],
                      [40, 240, 80, 3, 2, False, 'HS'],
                      [80, 200, 80, 3, 1, False, 'HS'],
                      [80, 184, 80, 3, 1, False, 'HS'],
                      [80, 184, 80, 3, 1, False, 'HS'],
                      [80, 480, 112, 3, 1, True, 'HS'],
                      [112, 672, 112, 3, 1, True, 'HS'],
                      [112, 672, 160, 5, 2, True, 'HS'],
                      [160, 960, 160, 5, 1, True, 'HS'],
                      [160, 960, 160, 5, 1, True, 'HS']]
        # an ordered dict holding the feature-extraction layers
        modules = OrderedDict()
        modules.update({'layer1': baseConv(inchannel=3, kernel_size=3, outchannel=16, stride=2, active='HS')})
        # build the bneck stack from the config
        for idx, layer in enumerate(net_config):
            modules.update({'bneck_{}'.format(idx): bneckModule(layer[0], layer[1], layer[2], layer[3],
                                                                layer[4], layer[5], layer[6])})
        # net_config[-1][2] is the last bneck's output channel count (160)
        modules.update({'conv_1*1': baseConv(net_config[-1][2], 960, 1, stride=1, active='HS')})
        modules.update({'pool': nn.AdaptiveAvgPool2d((1, 1))})
        self.module = nn.Sequential(modules)

        self.classifier = nn.Sequential(
            nn.Linear(960, 1280),
            nn.Hardswish(),
            nn.Dropout(p=0.2),
            nn.Linear(1280, num_classes)
        )

        if init_weight:
            self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)
            elif isinstance(w, nn.Linear):
                nn.init.normal_(w.weight, 0, 0.01)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        out = self.module(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out


if __name__ == '__main__':
    net = mobilenet_v3(10).to('cuda')
    summary(net, (3, 224, 224))
The training and validation code is similar to that for MobileNet_v2.
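For reference, a minimal training-step sketch (train_loader is a placeholder for your own DataLoader; the optimizer choice and learning rate are my assumptions, not taken from this post):

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = mobilenet_v3(num_classes=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(9):
    net.train()
    for images, labels in train_loader:  # placeholder DataLoader
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(net(images), labels)
        loss.backward()
        optimizer.step()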
I trained for 9 epochs here and reached an accuracy of about 79% (quite a good result).
Visualizing the training loss and the validation accuracy (run tensorboard --logdir=runs in a terminal):
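The logging side is the standard SummaryWriter pattern (a sketch; the tag names and the variables loss, val_acc, and step are placeholders):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs')  # matches tensorboard --logdir=runs
writer.add_scalar('train/loss', loss.item(), step)  # per-step training loss
writer.add_scalar('val/accuracy', val_acc, epoch)   # per-epoch validation accuracy
writer.close()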
Code: https://pan.baidu.com/s/1dgLLRdco_kYPxEkNfFs3yw (extraction code: brsm)
Pretrained weights: https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth