



# Code:
import torch
import torch.nn as nn
import torch.nn.functional as F
class h_swish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    The MobileNetV3 approximation of swish that avoids computing a sigmoid.
    Has no parameters; operates element-wise on any tensor shape.
    """

    def forward(self, x):
        gate = F.relu6(x + 3)
        # Same evaluation order as (x * gate) / 6.
        return x * gate / 6
class swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``.

    Element-wise, parameter-free. Uses ``torch.sigmoid`` because
    ``torch.nn.functional.sigmoid`` is deprecated in favor of it.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)
class h_sigmoid(nn.Module):
    """Hard-sigmoid gate: ``relu6(x + 3) / 6``, with outputs in [0, 1].

    Parameter-free, element-wise; used as the gating function of the
    squeeze-and-excitation module in this file.
    """

    def forward(self, x):
        shifted = x + 3
        return F.relu6(shifted) / 6
def _make_divisor(ch, divisor, min_ch = None):
    """Round a (possibly fractional) channel count to a multiple of ``divisor``.

    Args:
        ch: channel count to round, e.g. a base width times a multiplier.
        divisor: the multiple the result must be a multiple of.
        min_ch: lower bound for the result. Any falsy value (``None`` or 0)
            falls back to ``divisor``.

    Returns:
        The nearest multiple of ``divisor`` that is at least the lower bound.
        One extra ``divisor`` is added if rounding went below 90% of ``ch``.
    """
    lower_bound = min_ch or divisor
    nearest = int(ch + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, nearest)
    # Guard against rounding down by more than 10%.
    return rounded + divisor if rounded < 0.9 * ch else rounded
class SE_module(nn.Module):
    """Squeeze-and-excitation: rescale each channel of the input by a gate.

    The gate is computed in four steps:
    - global average pooling;
    - a 1x1 conv to ``inchannel // 4`` channels, then ReLU;
    - a 1x1 conv back to ``inchannel`` channels;
    - ``h_sigmoid``.
    The resulting per-channel weights multiply the input.

    Args:
        inchannel: number of channels in the incoming feature map.
    """

    def __init__(self, inchannel):
        super().__init__()
        squeezed = inchannel // 4
        # Layer order (and thus Sequential indices / init order) matters for
        # state_dict compatibility; keep it fixed.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inchannel, squeezed, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, inchannel, kernel_size=1),
            h_sigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)
class bneck(nn.Module):
    """MobileNetV3 inverted-residual bottleneck block.

    The block is built as follows:
    - an optional 1x1 expansion, skipped when ``inchannel == hidden_channel``;
    - a depthwise ``kernel_size`` conv that carries the stride;
    - an optional squeeze-and-excitation;
    - a 1x1 linear projection.
    An identity shortcut is added when the block preserves both spatial size
    and channel count.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        hidden_channel: expanded width used by the depthwise conv.
        nonlinear: activation *class* (e.g. ``nn.ReLU`` or ``h_swish``); a new
            instance is created for each place it is used.
        stride: stride of the depthwise conv.
        SE: insert an ``SE_module`` after the depthwise conv when truthy.
        kernel_size: depthwise kernel size (odd). Defaults to 3 so existing
            callers that don't pass it keep a 3x3 kernel.
    """

    def __init__(self, inchannel, outchannel, hidden_channel, nonlinear, stride, SE=False, kernel_size=3):
        super(bneck, self).__init__()
        # A residual is only shape-valid when nothing is downsampled or re-widened.
        self.shortcut = stride == 1 and inchannel == outchannel
        layers = []
        if inchannel != hidden_channel:
            layers.extend([
                nn.Conv2d(inchannel, hidden_channel, 1),
                nn.BatchNorm2d(hidden_channel),
                nonlinear(),
            ])
        # "Same" padding for odd kernels, so only the stride changes the size.
        padding = (kernel_size - 1) // 2
        layers.extend([
            nn.Conv2d(hidden_channel, hidden_channel, kernel_size, stride, padding,
                      groups=hidden_channel),
            nn.BatchNorm2d(hidden_channel),
            nonlinear(),
        ])
        self.conv1 = nn.Sequential(*layers)
        # Identity keeps the attribute present and forward() uniform when SE is off.
        self.se = SE_module(hidden_channel) if SE else nn.Identity()
        # Linear projection: no activation after the final BN.
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channel, outchannel, 1, 1),
            nn.BatchNorm2d(outchannel),
        )

    def forward(self, x):
        out = self.conv2(self.se(self.conv1(x)))
        if self.shortcut:
            out = out + x
        return out


class MobileNet_V3(nn.Module):
    """MobileNetV3 network built from a per-block ``setting`` table.

    Each row of ``setting`` is
    ``[in_channels, kernel_size, expanded_channels, out_channels, use_se, 'RE' or 'HS', stride]``.
    'RE' selects ``nn.ReLU``; anything else selects ``h_swish``. The leading
    ``in_channels`` column is ignored: each block's input width is the
    previous block's (scaled) output width. Channel counts are multiplied by
    ``alpha`` and rounded with ``_make_divisor`` to multiples of
    ``round_nearest``.

    Args:
        setting: non-empty list of block rows as described above.
        inchannel: number of channels of the input image tensor.
        classes: number of output classes.
        alpha: width multiplier applied to every channel count.
        round_nearest: channel counts are rounded to a multiple of this.

    Raises:
        ValueError: if ``setting`` is empty.

    ``forward`` returns logits shaped ``(N, classes, 1, 1)``. The head is made of
    1x1 convs applied after global pooling, and the result is not flattened.
    """

    def __init__(self, setting, inchannel, classes, alpha = 0.2, round_nearest = 8):
        super(MobileNet_V3, self).__init__()
        if not setting:
            raise ValueError("setting must contain at least one block row")
        # NOTE(review): the usual MobileNetV3 width multiplier is 1.0; 0.2 gives
        # a very narrow network. Confirm this default is intended.
        input_channel = _make_divisor(16 * alpha, round_nearest)
        self.HS = h_swish
        self.RE = nn.ReLU
        # Stem: 3x3 stride-2 conv -> BN -> hard-swish.
        self.conv1 = nn.Sequential(
            nn.Conv2d(inchannel, input_channel, 3, 2, 1),
            nn.BatchNorm2d(input_channel),
            self.HS(),
        )
        self.block = bneck
        self.blocks = nn.ModuleList([])
        self.nonlin = self.HS
        for _, kernel_size, hidden, out_channels, SE, nonlinear, stride in setting:
            self.nonlin = self.RE if nonlinear == 'RE' else self.HS
            self.hidden = _make_divisor(hidden * alpha, round_nearest)
            out_channels = _make_divisor(out_channels * alpha, round_nearest)
            # kernel_size was previously unpacked but never forwarded.
            self.blocks.append(
                self.block(input_channel, out_channels, self.hidden,
                           self.nonlin, stride, SE, kernel_size=kernel_size)
            )
            input_channel = out_channels
        # Final 1x1 conv widens to the last block's expanded width (self.hidden).
        self.conv2 = nn.Sequential(
            nn.Conv2d(input_channel, self.hidden, 1, 1),
            nn.BatchNorm2d(self.hidden),
            SE_module(self.hidden),
            self.HS(),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        # NOTE(review): head width is hard-coded to 1024 and not scaled by alpha;
        # confirm this is intended for both the Large and Small configs.
        self.conv3 = nn.Sequential(
            nn.Conv2d(self.hidden, 1024, 1, 1),
            self.HS(),
            nn.Dropout(0.2),
            nn.Conv2d(1024, classes, 1, 1),
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv1(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.conv3(x)
        return x
def MobileNet_V3_large(inchannel,classes):
    """Build the MobileNetV3-Large layer configuration.

    Args:
        inchannel: number of channels of the input image tensor.
        classes: number of output classes.

    Returns:
        A ``MobileNet_V3`` instance using the constructor's default ``alpha``
        and ``round_nearest``.
    """
    act_re, act_hs = 'RE', 'HS'
    rows = (
        # in,  k,  exp, out,  SE,    act,    stride
        (16,  3,  16,  16,  False, act_re, 1),
        (16,  3,  64,  24,  False, act_re, 2),
        (24,  3,  72,  24,  False, act_re, 1),
        (24,  5,  72,  40,  True,  act_re, 2),
        (40,  5,  120, 40,  True,  act_re, 1),
        (40,  5,  120, 40,  True,  act_re, 1),
        (40,  3,  240, 80,  False, act_hs, 2),
        (80,  3,  200, 80,  False, act_hs, 1),
        (80,  3,  184, 80,  False, act_hs, 1),
        (80,  3,  184, 80,  False, act_hs, 1),
        (80,  3,  480, 112, True,  act_hs, 1),
        (112, 3,  672, 112, True,  act_hs, 1),
        (112, 5,  672, 160, True,  act_hs, 2),
        (160, 5,  960, 160, True,  act_hs, 1),
        (160, 5,  960, 160, True,  act_hs, 1),
    )
    setting = [list(row) for row in rows]
    return MobileNet_V3(setting, inchannel, classes)
def MobileNet_V3_small(inchannel,classes):
    """Build the MobileNetV3-Small layer configuration.

    Args:
        inchannel: number of channels of the input image tensor.
        classes: number of output classes.

    Returns:
        A ``MobileNet_V3`` instance using the constructor's default ``alpha``
        and ``round_nearest``.
    """
    act_re, act_hs = 'RE', 'HS'
    rows = (
        # in, k,  exp, out, SE,    act,    stride
        (16, 3,  16,  16,  True,  act_re, 2),
        (16, 3,  72,  24,  False, act_re, 2),
        (24, 3,  88,  24,  False, act_re, 1),
        (24, 5,  96,  40,  True,  act_hs, 2),
        (40, 5,  240, 40,  True,  act_hs, 1),
        (40, 5,  240, 40,  True,  act_hs, 1),
        (40, 5,  120, 48,  True,  act_hs, 1),
        (48, 5,  144, 48,  True,  act_hs, 1),
        (48, 5,  288, 96,  True,  act_hs, 2),
        (96, 5,  576, 96,  True,  act_hs, 1),
        (96, 5,  576, 96,  True,  act_hs, 1),
    )
    setting = [list(row) for row in rows]
    return MobileNet_V3(setting, inchannel, classes)
if __name__ == '__main__':
    # Smoke test: push one random 224x224, 3-channel image through the small
    # model. torch.randn (instead of torch.empty) gives initialized input, so
    # the output is not computed from uninitialized memory (possible NaN/Inf).
    torch.manual_seed(0)
    model = MobileNet_V3_small(3, 10)
    model.eval()  # use BatchNorm running stats and disable Dropout
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)
    print(out)