UNet is a well-known semantic segmentation model. Since it was proposed, it has been widely used for medical image segmentation and autonomous-driving scene parsing, and it shows up regularly in global AI competitions such as those hosted on Kaggle. Semantic segmentation itself is a classic deep-learning computer vision task: every pixel of an image is assigned a class, and each class is rendered in a different color. In the figure, categories such as people, trains, cars, and road signs are each shown in their own color.
MobileNet is a lightweight neural-network backbone open-sourced by Google. Many engineering projects require real-time inference, which puts a hard limit on the size of the feature-extraction network, and MobileNet offers an excellent accuracy-to-cost trade-off. The standard UNet backbone, by contrast, is a VGG16 feature extractor with a very large parameter count; many embedded devices cannot run it at all, let alone deliver real-time segmentation. This article therefore replaces VGG16 with MobileNet to make UNet lightweight, cutting the parameter count and speeding up inference. Using the PyTorch framework, I swap in the new backbone and fuse its multi-scale feature maps into the decoder so that feature extraction stays accurate.
The core idea behind MobileNet is the depthwise separable convolution.
Suppose we have a 3×3 convolutional layer with 16 input channels and 32 output channels. With a standard convolution, each of the 32 3×3 kernels spans all 16 input channels, so producing the 32 output channels takes 16×32×3×3 = 4608 parameters.
With a depthwise separable convolution, 16 3×3 kernels first traverse the 16 input channels one channel apiece, yielding 16 feature maps. These 16 feature maps are then fused by 32 1×1 kernels, for a total of 16×3×3 + 16×32×1×1 = 656 parameters.
Clearly, the depthwise separable convolution uses far fewer parameters than a standard convolution.
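We can sanity-check this arithmetic directly in PyTorch. The minimal snippet below is not part of the network code; the layer sizes are chosen only to mirror the example above:

import torch.nn as nn

# standard 3x3 convolution: 16 input channels -> 32 output channels
standard = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)

# depthwise separable convolution: 3x3 depthwise (groups=16) followed by a 1x1 pointwise conv
separable = nn.Sequential(
    nn.Conv2d(16, 16, kernel_size=3, padding=1, groups=16, bias=False),
    nn.Conv2d(16, 32, kernel_size=1, bias=False),
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 4608
print(count(separable))  # 656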
The figure below shows the structure of a depthwise separable convolution:
The following figure shows the MobileNetV1 architecture. The first layer is an ordinary convolution block; since its stride is 2, it halves the height and width of the image once. After that, the network alternates conv dw blocks (depthwise separable convolutions) with ordinary 1×1 convolutions: the depthwise convolution extracts features, and the 1×1 convolution adjusts the number of channels. Batch normalization (BN) and the ReLU6 activation are not drawn in the figure, but they do appear in the code below. By stacking conv dw and 1×1 conv blocks and finishing with average pooling, a fully connected layer, and a softmax, MobileNet produces its output. That is the whole MobileNet architecture; next, let's build it in PyTorch.
import torch
import torch.nn as nn
# conv_bn: standard 3x3 convolution + BN + ReLU6, used as the network's first block
def conv_bn(inp, oup, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

# conv_dw: depthwise separable convolution block
def conv_dw(inp, oup, stride=1):
    return nn.Sequential(
        # 3x3 depthwise convolution (groups=inp): one filter per input channel, extracts spatial features
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU6(inplace=True),
        # 1x1 pointwise convolution, stride 1: fuses channels and adjusts the channel count
        nn.Conv2d(inp, oup, 1, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    )
class MobileNet(nn.Module):
    def __init__(self, n_channels):
        super(MobileNet, self).__init__()
        self.layer1 = nn.Sequential(
            # first conv block; called with stride 1 here (the original MobileNet uses
            # stride 2), so with a 416x416 input: 416,416,3 -> 416,416,32
            conv_bn(n_channels, 32, 1),
            # first depthwise separable block, stride 1: 416,416,32 -> 416,416,64
            conv_dw(32, 64, 1),
            # two depthwise separable blocks: 416,416,64 -> 208,208,128
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            # 208,208,128 -> 104,104,256
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
        )
        # 104,104,256 -> 52,52,512
        self.layer2 = nn.Sequential(
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
        )
        # 52,52,512 -> 26,26,1024
        self.layer3 = nn.Sequential(
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
        )
        # classification head; not used when MobileNet serves as the UNet backbone
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, 1000)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x
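A quick shape check of the backbone we just defined (a minimal sketch: the 416×416 input size is only an example, and the printed shapes assume the stride-1 first block used above):

if __name__ == "__main__":
    model = MobileNet(n_channels=3)
    x = torch.randn(1, 3, 416, 416)
    print(model(x).shape)         # torch.Size([1, 1000]) - classification head
    print(model.layer1(x).shape)  # torch.Size([1, 256, 104, 104]) - first feature stage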
We have just built the MobileNet architecture in PyTorch. Next, we need to plug MobileNet into UNet as a replacement for the original backbone. First, let's look at the UNet architecture:
The idea is to find the feature maps in the UNet encoder whose downsampling factor matches that of the MobileNet feature maps, substitute the latter for the former, and concatenate (torch.cat) them with the upsampled feature maps in the decoder. A small illustration of the concatenation step comes first, followed by the full implementation:
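The tensor shapes in this tiny stand-alone snippet are hypothetical and only chosen so the spatial sizes match:

import torch

decoder_feat = torch.randn(1, 512, 52, 52)  # feature map after upsampling in the decoder
encoder_feat = torch.randn(1, 512, 52, 52)  # backbone feature map at the same resolution
fused = torch.cat([encoder_feat, decoder_feat], dim=1)
print(fused.shape)  # torch.Size([1, 1024, 52, 52]) - channels are stacked, spatial size unchanged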
from collections import OrderedDict

import torch
import torch.nn as nn

from mobilenet.mobile import MobileNet
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


def conv2d(filter_in, filter_out, kernel_size, groups=1, stride=1):
    pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, groups=groups, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.ReLU6(inplace=True)),
    ]))


class mobilenet(nn.Module):
    """Wraps MobileNet and exposes its three feature stages as backbone outputs."""

    def __init__(self, n_channels):
        super(mobilenet, self).__init__()
        self.model = MobileNet(n_channels)

    def forward(self, x):
        out3 = self.model.layer1(x)
        out4 = self.model.layer2(out3)
        out5 = self.model.layer3(out4)
        return out3, out4, out5


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    def __init__(self, n_channels, num_classes):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.num_classes = num_classes
        # ---------------------------------------------------#
        #   backbone feature maps for a 416x416 input:
        #   104,104,256 ; 52,52,512 ; 26,26,1024
        # ---------------------------------------------------#
        self.backbone = mobilenet(n_channels)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = DoubleConv(1024, 512)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = DoubleConv(1024, 256)
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = DoubleConv(512, 128)
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        # nn.Upsample(scale_factor=2, mode='bilinear') is an alternative upsampling choice
        self.conv4 = DoubleConv(128, 64)
        self.oup = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # backbone: three feature maps at increasing depth and decreasing resolution
        x2, x1, x0 = self.backbone(x)
        P5 = self.up1(x0)                # 26,26,1024 -> 52,52,1024
        P5 = self.conv1(P5)              # 52,52,1024 -> 52,52,512
        P4 = x1                          # 52,52,512
        P4 = torch.cat([P4, P5], dim=1)  # after concatenation: 52,52,1024
        P4 = self.up2(P4)                # 52,52,1024 -> 104,104,1024
        P4 = self.conv2(P4)              # 104,104,1024 -> 104,104,256
        P3 = x2                          # 104,104,256
        P3 = torch.cat([P4, P3], dim=1)  # after concatenation: 104,104,512
        P3 = self.up3(P3)                # 104,104,512 -> 208,208,512
        P3 = self.conv3(P3)              # 208,208,512 -> 208,208,128
        P3 = self.up4(P3)                # 208,208,128 -> 416,416,128
        P3 = self.conv4(P3)              # 416,416,128 -> 416,416,64
        out = self.oup(P3)               # 416,416,64 -> 416,416,num_classes
        return out
As you can see in the UNet class, the backbone (MobileNet) returns three feature maps; in the forward pass, each is upsampled and then concatenated with the corresponding feature map, which completes the decoder.
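As a final sanity check, the sketch below pushes a random tensor through the assembled model; the input size and class count are arbitrary, and with the stride-1 first block the output resolution matches the input:

if __name__ == "__main__":
    net = UNet(n_channels=3, num_classes=21)
    x = torch.randn(1, 3, 416, 416)
    out = net(x)
    print(out.shape)  # torch.Size([1, 21, 416, 416])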
See:
https://github.com/YZY-stack/UNet-MobileNet-Pytorch
Feel free to star or fork the project. I will keep refining and improving the code, and questions, comments, and corrections are all welcome.