一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)

目录

一、概述

二、MobileNetV1原理和实现

2.1 原理

2.1.1 深度卷积

2.1.2 逐像素卷积

2.2 Pytorch实现

三、MobileNetV2原理和实现

3.1 原理

3.1.1 ReLU6激活函数

3.1.2 Inverted Residual

3.2 Pytorch实现

四、应用(基于人像快速分割的背景虚化处理 )

4.1 概述

4.2 算法原理

4.3 训练

4.4 背景虚化

五、小结

参考文献


一、概述

现阶段深度神经网络在GPU上运行其速度已经可以达到实时性要求,但是如果将训练好的模型直接移植到手机端或者在CPU上运行,这时候速度和内存消耗就是非常致命的问题,只有对模型进行优化才能满足这种资源受限场景中的深度神经网络的使用。模型优化加速主要包含3种类型:1.设计轻量级的网络;2.网络模型压缩剪枝;3.其他的一些量化加速。本文主要探讨轻量级网络的运用。

在轻量级网络中,考虑到通用性和实用性,典型算法就是MobileNet系列,从其名字也可以看出来,该系列算法旨在为移动(Mobile)设备进行智慧赋能,具体包括MobileNetV1、MobileNetV2和MobileNetV3,三大算法按照时间先后顺序依次被研究学者提出。通过使用该系列算法,可以在原有模型的基础上大幅减少模型参数,从而提高模型处理速度和并且内存消耗,在实际工业级产品方案中该优势显得异常重要,因为工业场景中往往是资源受限的(可能在没有GPU的工控机或者在嵌入式上进行开发)。

如果从产品部署角度考虑,那么目前深度学习的热潮已经逐步从Web服务器端转向终端硬件,即转向所谓的“边缘计算”需求。众多大厂纷纷在此发力,力求能够推出自家的带GPU的终端硬件产品,其中以英伟达推出的Jetson系列最为成功。Jetson系列不仅体积小巧,而且自带GPU,因此已经推出收到广泛关注。但是,在这种嵌入式开发板上跑重量级的深度学习模型依然是一个难题,即使攻克显存的瓶颈,在速度上面依然很慢。为此,对原模型进行优化,使得模型参数大幅减少从而能够利用低廉的终端设备实现智能算法应用成为了一个AI工程师必经之路。

本文将详细阐述MobileNetV1和MobileNetV2,从原理切入,然后给出对应的Pytorch实现方法,最后结合MobileNetV2算法,给出一个具体应用的实例。

二、MobileNetV1原理和实现

2.1 原理

MobileNetV1的提出是为了解决移动端设备深度学习推理速度受限问题产生的。传统的卷积神经网络在移动设备上运行速度极慢并且会消耗移动设备大量内存资源。因此,MobileNetV1最大的贡献就是改进传统CNN结构,使得整个模型仅仅降低少量的精度但是却可以极大的提高速度,模型参数量可以减少8倍以上。具体的,MobileNetV1提出了深度可分离卷积(depthwise separable convolution)来代替传统的CNN。

2.1.1 深度卷积

首先来看一下传统卷积的实现方式,假设输入图像是3通道图,长宽均为256,采用5x5卷积核进行卷积操作,输出通道数为16,padding为0,stride为1,那么卷积后输出为252X252X16,如下图所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第1张图片

                                                                                            图1 传统卷积操作

传统的卷积网络是跨通道的,对于上图3通道的输入特征,我们要得到通道数为16的输出特征。普通卷积使用16个不同的 5x5x3 以滑窗的形式遍历输入特征,因此对于一个尺寸为5x5的卷积的参数个数为 5x5x3x16 。实际的计算量为5x5x3x16x256x256。可以看到,单层卷积计算量还是非常大的。MobileNetV1的提出就是为了解决这个问题。

在MobileNetV1中采用深度可分离卷积(depthwise separable convolution)来代替传统的CNN。depthwise separable convolution可以分为两部分:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第2张图片

                                                                                              图2 深度可分离卷积

其中Depthwise卷积是指不跨通道的卷积,也就是说特征图的每个通道有一个独立的卷积核,并且这个卷积核作用且仅作用在这个通道之上。对于图1所示的3通道输入特征,如果采用deepwise卷积进行操作,那么就变成了使用3个滤波器,每个滤波器单独的作用于一个通道上,每个卷积得到一个通道特征,最后合并产生3通道特征。也就是说使用deepwise卷积后不会改变原始输入通道数。

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第3张图片

                                                                                              图3 深度卷积

通过上述深度卷积示意图可以知道,该操作的参数量为5x5x3,计算量为5x5x3x256x256。

2.1.2 逐像素卷积

采用深度卷积对每层的通道数均进行了卷积特征提取操作,但是这些特征是在单一的特征通道上完成的,各个通道特征之间的信息是独立的,那么如何对各通道特征进行融合,使得最终的输入通道数为16呢?这里就可以是使用逐点卷积来完成。本质上来说,逐点卷积的作用就是来对特征通道进行升维和降维。

实际操作时使用1x1卷积来完成逐点卷积这个功能。该操作参数量为1x1x3x16, 计算量为1x1x3x252x252x16。

因此综合一下,采用深度可分离卷积总的参数量为5x5x3+1x1x3x16,相比于普通卷积的5x5x3x16,占(1/16+1/25)。而计算量上来看,采用深度可分离卷积总的计算量为5x5x3x256x256+1x1x3x252x252x16,相比于普通卷积的5x5x3x16x256x256,同样占(1/16+1/25)左右。如果我们采用的不是5x5卷积,而是常用的3x3卷积,那么一般来说使用深度可分离卷积仅仅只需要普通卷积的1/9左右计算量。

MobileNetV1正是基于这个原理,实现了模型参数和计算量的大幅减少。最为重要的是,尽管对模型进行了高度压缩,但是采用该算法精度上并没有下降很多,具体指标可以参考相关论文。

2.2 Pytorch实现

使用Pytorch实现深度可分离卷积比较简单,只需要设置好torch.nn.Conv2d()命令中的groups参数即可。根据官方定义该参数控制输入和输出之间的连接:group=1时输出是所有的输入的卷积;group=2,此时相当于有并排的两个卷积层,每个卷积层计算输入通道的一半,并且产生的输出是输出通道的一半,随后将这两个输出连接起来。当group=输入通道数时,此时就是我们需要的深度卷积。在深度卷积之后再跟一个1x1的卷积即可实现完整的深度可分离卷积。

具体代码如下:

def conv_dw(inp, oup, stride):
    '''
    深度可分离卷积
    inp:输入通道数
    oup:输出通道数
    stride:步长
    '''
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )

 

三、MobileNetV2原理和实现

3.1 原理

MobileNetV2主要创新点就是在MobileNetV1中加入了残差网络,同时提出了一个新的I激活函数ReLU6。

3.1.1 ReLU6激活函数

在MobileNetV2论文中指出,当输出通道数较少的时候使用ReLU会导致信息损耗严重,因此需要将ReLU替换成线性激活函数。为此,MobileNetV2提出了ReLU6激活函数,它是对ReLU在整数6上的截断,数学形式为:

                                                                       ReLU6=min(max(0,x),6)

示意图如下图所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第4张图片

也就是说输出值如果在0到6之间,那么输出值不变,当超过6时输出统一截断为6。该论文作者通过实验验证了上述激活函数的有效性。实际使用Pytorch时可以直接使用现成的nn.ReLU6()函数实现。

3.1.2 Inverted Residual

MobileNetV2使用了残差网路结构,并且在设计该结构时与以往的不同。深度卷积本身没有改变通道的能力,输入多少通道输出就是多少通道。如果输入通道很少的话,深度卷积(DW)只能在低维度上工作,这样效果并不会很好,所以MobileNetV2首先会“扩张”通道。通过前面可以知道逐点卷积(PW)也就是1×1卷积可以用来升维和降维,那就可以在深度卷积DW之前使用逐点卷积PW行升维(升维倍数为t,t=6),再在一个更高维的空间中进行卷积操作来提取特征,最后再采用PW将通道数下降还原回来,如下图所示:

最后像Resnet一样复用输入特征,引入shortcut连接,这样V2的单个block就是如下图形式:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第5张图片

可以发现,MobileNetV2采用了1×1 -> 3 ×3 -> 1 × 1 的卷积模式,并且采用了Shortcut结构。但是整体与Resnet结构有不同:

  • ResNet 先降维 (0.25倍)、卷积、再升维。
  • MobileNetV2 则是 先升维 (6倍)、卷积、再降维。

MobileV2的block刚好与Resnet的block相反,因此将其命名为Inverted residuals(反向残差)。

3.2 Pytorch实现

MobileNetV2中最关键的就是反向残差模型,代码如下:

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.use_res_connect = self.stride == 1 and inp == oup

        self.conv = nn.Sequential(
            # pw
            nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # dw
            nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

其中expand_ratio参数为了与论文一致一般取6。另外,上述代码中当stride == 1 且输入通道数 inp == 输出通道数oup时,采用resnet残差网络,在前向推理时会链接输入端。

至此,本文已讲述完MobileV1和MobileV2系列,下面将重点使用MobileV2算法,结合Unet网络模型进行人像和背景分割,最终实现类似单反的背景虚化效果。

四、应用(基于人像快速分割的背景虚化处理 )

4.1 概述

单反相机经常会被用来进行背景虚化拍摄以获取一些很漂亮的照片,通过镜头操作凸显照片中的主体内容,而其余背景部分呈现模糊效果。如下图所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第6张图片

但是,单反相机本身成本较高,因此,出现了很多软件算法来实现类似的背景虚化效果。算法实现时一般分为下面三个步骤:

  • 对前景物体进行抠图,得到前景抠图掩码;
  • 前景之外的背景进行模糊操作;
  • 模糊背景图和原始高清图按照抠图掩码进行融合

可以看到,整个背景虚化算法的核心在于准确的提取出前景的抠图掩码,为了能够高效率且准确的完成上述任务 ,我们采用基于深度学习的语义分割算法来实现,同时结合MobileNetV2算法,进一步加快算法的执行速度并且降低模型参数量,方便未来将应用集成到手机端运行。

4.2 算法原理

考虑到前景物体的多样性,一种有效的解决方案就是先用显著性检测算法将显著物体定位出来,再进行背景虚化,但是这种处理方式在物体边缘处分割精度不高,为此,我们进一步聚焦,将目标对准人像处理,即实现人像的自动分割,在这个基础上,可以训练出一个较准确的人像自动分割模型。

整个训练算法采用爱分割提供的3万多张高精度人像分割数据集进行实验,部分样例如下所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第7张图片一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第8张图片           一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第9张图片一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第10张图片

所有图像均已处理成600X800像素大小,每张图像均提供标注好的高精度掩码。

算法部分采用UNet网络结构,分成编码encode和解码decode两部分,其中编码部分使用MobileNetV2提供逐级下采样的特征图。模型示意图如下所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第11张图片

输入图像为3通道数据(R、G、B),输出也是三通道,对应前景掩码、背景掩码、不确定区域掩码。之所以不是二值分割,而是采用三值分割主要是仿照human semantic matting那篇论文。

对照前面的反向残差模型InvertedResidual,我们构造mobilenet_v2模型用于实现UNet中的编码网络:

class mobilenet_v2(nn.Module):
    def __init__(self, nInputChannels=3):
        super(mobilenet_v2, self).__init__()
        # 1
        self.head_conv = nn.Sequential(nn.Conv2d(nInputChannels, 32, 3, 1, 1, bias=False),
                                       nn.BatchNorm2d(32),
                                       nn.ReLU())
        # 1
        self.block_1 = InvertedResidual(32, 16, 1, 1)
        # 1/2 
        self.block_2 = nn.Sequential( 
            InvertedResidual(16, 24, 2, 6),
            InvertedResidual(24, 24, 1, 6)
            )
        # 1/4 
        self.block_3 = nn.Sequential( 
            InvertedResidual(24, 32, 2, 6),
            InvertedResidual(32, 32, 1, 6),
            InvertedResidual(32, 32, 1, 6)
            )
        # 1/8 
        self.block_4 = nn.Sequential( 
            InvertedResidual(32, 64, 2, 6),
            InvertedResidual(64, 64, 1, 6),
            InvertedResidual(64, 64, 1, 6),
            InvertedResidual(64, 64, 1, 6)            
            )
        # 1/16
        self.block_5 = nn.Sequential( 
            InvertedResidual(64, 96, 1, 6),
            InvertedResidual(96, 96, 1, 6),
            InvertedResidual(96, 96, 1, 6)          
            )
        # 1/32 
        self.block_6 = nn.Sequential( 
            InvertedResidual(96, 160, 2, 6),
            InvertedResidual(160, 160, 1, 6),
            InvertedResidual(160, 160, 1, 6)          
            )
        # 1/32
        self.block_7 = InvertedResidual(160, 320, 1, 6)

    def forward(self, x):
        x = self.head_conv(x)
        # 1
        s1 = self.block_1(x)
        # 1/2 
        s2 = self.block_2(s1)
        # 1/4
        s3 = self.block_3(s2)
        # 1/8
        s4 = self.block_4(s3)
        s4 = self.block_5(s4)
        # 1/16
        s5 = self.block_6(s4)
        s5 = self.block_7(s5)

        return s1, s2, s3, s4, s5

完整的模型定义为TNet,代码如下:

class tnet(nn.Module):
    '''
        mmobilenet v2 + unet 

    '''

    def __init__(self, classes=3):

        super(tnet, self).__init__()
        # -----------------------------------------------------------------
        # encoder  
        # ---------------------
        self.feature = mobilenet_v2()

        # -----------------------------------------------------------------
        # decoder 
        # ---------------------

        self.s5_up_conv = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                        nn.Conv2d(320, 96, 3, 1, 1),
                                        nn.BatchNorm2d(96),
                                        nn.ReLU())
        self.s4_fusion = nn.Sequential(nn.Conv2d(96, 96, 3, 1, 1),
                                       nn.BatchNorm2d(96))

        self.s4_up_conv = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                        nn.Conv2d(96, 32, 3, 1, 1),
                                        nn.BatchNorm2d(32),
                                        nn.ReLU())
        self.s3_fusion = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),
                                       nn.BatchNorm2d(32))

        self.s3_up_conv = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                        nn.Conv2d(32, 24, 3, 1, 1),
                                        nn.BatchNorm2d(24),
                                        nn.ReLU())
        self.s2_fusion = nn.Sequential(nn.Conv2d(24, 24, 3, 1, 1),
                                       nn.BatchNorm2d(24))

        self.s2_up_conv = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                        nn.Conv2d(24, 16, 3, 1, 1),
                                        nn.BatchNorm2d(16),
                                        nn.ReLU())
        self.s1_fusion = nn.Sequential(nn.Conv2d(16, 16, 3, 1, 1),
                                       nn.BatchNorm2d(16))

        self.last_conv = nn.Conv2d(16, classes, 3, 1, 1)
        self.last_up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, input):

        # -----------------------------------------------
        # encoder 
        # ---------------------
        s1, s2, s3, s4, s5 = self.feature(input)
        # -----------------------------------------------
        # decoder
        # ---------------------
        s4_ = self.s5_up_conv(s5)
        s4_ = s4_ + s4
        s4 = self.s4_fusion(s4_)

        s3_ = self.s4_up_conv(s4)
        s3_ = s3_ + s3
        s3 = self.s3_fusion(s3_)

        s2_ = self.s3_up_conv(s3)
        s2_ = s2_ + s2
        s2 = self.s2_fusion(s2_)

        s1_ = self.s2_up_conv(s2)
        s1_ = s1_ + s1
        s1 = self.s1_fusion(s1_)

        out = self.last_conv(s1)

        return out

4.3 训练

在训练阶段,需要为每张图像提供Trimap图,因此,需要对数据进行预处理,代码如下:

def genAiFenGe():
    """
    生成标准化的AiFenGe数据集,同时生成JSON文件列表
    """
    # 设置拷贝路径
    src_img_folder='E:\deeplearn\Matting_Human_Half\clip_img' 
    src_alpha_folder='E:\deeplearn\Matting_Human_Half\matting'
    des_img_folder='./data/AiFenGe/img' 
    des_alpha_folder='./data/AiFenGe/alpha' 
    des_trimap_folder='./data/AiFenGe/trimap' 

    # 检索文件
    imglist = getFileList(src_img_folder, [], 'jpg')
    alphalist = getFileList(src_alpha_folder, [], 'png')

    print('检索到 '+str(len(imglist))+' 个原始图像')
    print('检索到 '+str(len(alphalist))+ '个alpha通道图')

    # 逐张检查
    index=0
    save_img_list=list()
    save_alpha_list=list()
    save_trimap_list=list()

    for imgpath in imglist:
        imgname= os.path.splitext(os.path.basename(imgpath))[0]
        alphaname=imgname+'.png'

        for j in range(len(alphalist)):
            if alphaname in alphalist[j]:
                alphapath = alphalist[j]
                try:
                    img = cv2.imread(imgpath, cv2.IMREAD_COLOR)

                    alpha = cv2.imread(alphapath, cv2.IMREAD_UNCHANGED)
                    alpha = alpha[:,:,3] # 分离alpha通道
                    ret,alpha = cv2.threshold(alpha,50,255,cv2.THRESH_BINARY)

                    # 生成trimap
                    trimap = erode_dilate(alpha)

                    # 保存   
                    cv2.imwrite(des_img_folder+('/%d.png' % (index)),img)
                    cv2.imwrite(des_alpha_folder+('/%d.png' % (index)),alpha)
                    cv2.imwrite(des_trimap_folder+('/%d.png' % (index)),trimap)

                    # 记录
                    save_img_list.append(des_img_folder+('/%d.png' % (index)))
                    save_alpha_list.append(des_alpha_folder+('/%d.png' % (index)))
                    save_trimap_list.append(des_trimap_folder+('/%d.png' % (index)))

                    index += 1
                    print('当前写入第 %d 张图片' % (index))

                except Exception as err:
                    print(err)

    # 写入json文件
    with open('./data/aifenge_img.json', 'w') as jsonfile1:
        json.dump(save_img_list, jsonfile1)

    with open('./data/aifenge_alpha.json', 'w') as jsonfile2:
        json.dump(save_alpha_list, jsonfile2)

    with open('./data/aifenge_trimap.json', 'w') as jsonfile3:
        json.dump(save_trimap_list, jsonfile3)

    print('共写入 %d 张图片' % (index))

其中用于生成trimap图的erode_dilate函数如下:

def erode_dilate(mask, size=(10, 10), smooth=True):
    """
    腐蚀膨胀生成trimap
    输入 mask:单通道二值掩码图
    """
    # 构造核
    if smooth:
        size = (size[0]-4, size[1]-4)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)

    # 膨胀
    dilated = cv2.dilate(mask, kernel, iterations=1)
    if smooth:  
        dilated[(dilated>5)] = 255
        dilated[(dilated <= 5)] = 0
    else:
        dilated[(dilated>0)] = 255

    # 腐蚀
    eroded = cv2.erode(mask, kernel, iterations=1)
    if smooth:
        eroded[(eroded<250)] = 0
        eroded[(eroded >= 250)] = 255
    else:
        eroded[(eroded < 255)] = 0

    res = dilated.copy()
    res[((dilated == 255) & (eroded == 0))] = 128    
    return res

通过上述放方式,我们就形成了img、alpha和trimap三个文件夹,分别用于存放原始图像、分割真值、trimap图,同时形成了3个用于训练的json列表文件。接下来就是构造数据加载器来加载数据,代码如下:

class HumanDataset(Dataset):
    """
    人像数据集
    """
    def __init__(self, dataname, transforms=None):

        items = []
        img_path = './data/'+ dataname + '_img.json'
        trimap_path = './data/'+ dataname + '_trimap.json'
        alpha_path = './data/'+ dataname + '_alpha.json'

        with open(img_path, 'r') as j:
            imglist = json.load(j)
        with open(trimap_path, 'r') as j:
            trimaplist = json.load(j)
        with open(alpha_path, 'r') as j:
            alphalist = json.load(j)

        for i in range(len(imglist)):
            items.append((imglist[i], trimaplist[i], alphalist[i]))

        self.items = items
        self.transforms = transforms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        image_name, trimap_name, alpha_name = self.items[index]
        image = cv2.imread(image_name, cv2.IMREAD_COLOR)
        trimap = cv2.imread(trimap_name, cv2.IMREAD_GRAYSCALE)
        alpha = cv2.imread(alpha_name, cv2.IMREAD_GRAYSCALE)

        if self.transforms is not None:
            for transform in self.transforms:
                image, trimap, alpha = transform(image, trimap, alpha)

        return image, trimap, alpha

其中,给出几个变换函数:

class RandomPatch(object):
    """
    自定义压缩变换
    """
    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, image, trimap, alpha):
        image = cv2.resize(image, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)
        trimap = cv2.resize(trimap, (self.patch_size, self.patch_size), interpolation=cv2.INTER_NEAREST)
        alpha = cv2.resize(alpha, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)

        return image, trimap, alpha


class Normalize(object):
    """
    自定义归一化操作
    """
    def __call__(self, image, trimap, alpha):
        image = (image.astype(np.float32) - (114., 121., 134.,)) / 255.0
        trimap[trimap == 0] = 0
        trimap[trimap == 128] = 1
        trimap[trimap == 255] = 2
        alpha = alpha.astype(np.float32) / 255.0
        return image, trimap, alpha


class NumpyToTensor(object):
    """
    numpy数组转张量tensor
    """
    def __call__(self, image, trimap, alpha):
        h, w, c = image.shape
        image = torch.from_numpy(image.transpose((2, 0, 1))).view(c, h, w).float()
        trimap = torch.from_numpy(trimap).view(-1, h, w).long()  
        alpha = torch.from_numpy(alpha).view(1, h, w).float()
        return image, trimap, alpha
        
    
class TrimapToCategorical(object):
    """
    单通道trimap变成3通道图:b、u、f
    """
    def __call__(self, image, trimap, alpha):
        trimap = np.array(trimap, dtype=np.int)
        input_shape = trimap.shape
        trimap = trimap.ravel()
        n = trimap.shape[0]
        categorical = np.zeros((3, n), dtype=np.long)
        categorical[trimap, np.arange(n)] = 1
        output_shape = (3,) + input_shape
        categorical = np.reshape(categorical, output_shape)
        return image, categorical, alpha

最后给出完整的训练脚本,采用的是Pytroch1.4版本。

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from model.tnet import tnet
from datasets import HumanDataset,RandomPatch,Normalize,NumpyToTensor
import torch.nn.functional as F
import time
from utils import *
from loss import ClassificationLoss


# 数据集参数
data_folder = './data/'   # 数据存放路径
dataname = 'aifenge'      # 数据集名称

# 学习参数
checkpoint = './results/tnet.pth'         # 预训练模型路径,如果不存在则为None
batch_size = 128          # 批大小
start_epoch = 146           # 轮数起始位置
epochs = 300              # 迭代轮数
workers = 4               # 工作线程数
lr = 0.00001              # 学习率             
weight_decay = 0.00001    # 权重延迟

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 4           # 用来运行的gpu数量
cudnn.benchmark = True # 对卷积进行加速

# 日志
writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看

def main():
    """
    训练.
    """
    global checkpoint,start_epoch,writer

    # 初始化
    model = tnet()
    # 初始化优化器
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=lr, betas=(0.9, 0.999),
                                    weight_decay=weight_decay)

    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = ClassificationLoss()
    criterion.to(device)

    # 加载预训练模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['tnet'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定制化的数据加载器
    transforms = [
                RandomPatch(320),
                Normalize(),
                NumpyToTensor()
            ]
    train_dataset = HumanDataset(dataname,transforms)
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 开始逐轮训练
    preloss = 10000000
    
    for epoch in range(start_epoch, epochs+1):

        # if epoch == 23:  # 适当降低学习率
        #     adjust_learning_rate(optimizer, 0.1)

        model.train()  # 训练模式:允许使用批样本归一化

        loss_epoch = AverageMeter()  # 统计损失函数

        n_iter = len(train_loader)

        # 按批处理
        for i, (imgs, trimaps_gt, alphas) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            imgs = imgs.to(device)  
            trimaps_gt = trimaps_gt.to(device)  
 
            # 前向传播
            trimaps_pre = model(imgs)

            # 计算损失
            loss = criterion(trimaps_pre, trimaps_gt)  

            # 后向传播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 记录损失值
            loss_epoch.update(loss.item(), imgs.size(0))

            # 监控图像变化
            if i == n_iter-2:
                trimaps_pre_temp = trimap_to_image(trimaps_pre[:4,:3,:,:])                
                writer.add_image('TNet/epoch_'+str(epoch)+'_1', make_grid(imgs[:4,:,:,:].cpu(), nrow=4,normalize=True),epoch)
                writer.add_image('TNet/epoch_'+str(epoch)+'_2', make_grid(trimaps_pre_temp, nrow=4, normalize=True),epoch)
                writer.add_image('TNet/epoch_'+str(epoch)+'_3', make_grid(trimaps_gt[:4,:,:,:].float().cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第 "+str(i+1)+ " 个batch训练结束")

        # 手动释放内存
        del imgs, trimaps_pre, trimaps_gt, alphas, trimaps_pre_temp
        print('第'+str(epoch)+'个epoch训练结束')
        
        # 监控损失值变化
        writer.add_scalar('pretrain_tnet/Loss', loss_epoch.val, epoch)    

        # 保存预训练模型
        if loss_epoch.val < preloss:
            preloss = loss_epoch.val
            print("保存预训练模型\n")
            torch.save({
                'epoch': epoch,
                'tnet': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/tnet.pth')

    # 训练结束关闭监控
    writer.close()


if __name__ == '__main__':
    main()

采用4块泰坦显卡,训练结果如下:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第12张图片

下图是训练结束时的人像分割效果图,第一行为输入原图,中间一行为模型预测结果,最后一行为groundtruth真值。可以看到,整体的分割精度还是不错的,使用MobileNetV2能够在保证分割精度的前提下大幅降低整个模型的参数,最终训练好的模型只有26M左右。

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第13张图片  一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第14张图片

 

4.4 背景虚化

本节,我们首先对原始图像Input进行高斯模糊得到模糊图像Blur,然后我们利用前面训练好的模型对输入图像进行语义分割,确定出人像掩码Mask,最后进行合成,合成公式如下图所示:

                                                      Output=Input.*mask+(255-mask).*Blur

完整代码如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from model.tnet import tnet
from utils import * 
import time
import cv2


# 模型参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")


if __name__ == '__main__':

    # 测试图像
    img_id='1'
    imgPath = './results/'+img_id+'.jpg'

    # 加载图像
    input = cv2.imread(imgPath, cv2.IMREAD_COLOR)
    width = input.shape[1]
    height = input.shape[0]
    
    # 多次高斯模糊
    blur = cv2.GaussianBlur(input,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
    blur = cv2.GaussianBlur(blur,(3,3),3)
 
    cv2.imwrite('./results/blur.jpg',blur)

    # 预训练模型
    checkpoint = "./results/tnet.pth"

    # 加载模型
    checkpoint = torch.load(checkpoint)
    model = tnet()

    model = model.to(device)
    model.load_state_dict(checkpoint['tnet'])

    model.eval()    
    
    # 图像预处理 
    img = cv2.resize(input, (320,320), interpolation = cv2.INTER_CUBIC)
    img = (img.astype(np.float32) - (114., 121., 134.,)) / 255.0
    h, w, c = img.shape
    img = torch.from_numpy(img.transpose((2, 0, 1))).view(c, h, w).float()
    img= img.view(1, 3, h, w)

    # 记录时间
    start = time.time()

    # 转移数据至设备
    img = img.to(device)

    # 模型推理
    with torch.no_grad():
        trimap = model(img)
        trimap=trimap_to_image(trimap)

        # 保存trimap
        trimap = trimap.squeeze(0).float().mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()  
        cv2.imwrite('./results/trimap'+img_id+'.png',trimap)

        # 缩放并保存alpha通道图
        trimap = cv2.resize(trimap, (width,height), interpolation = cv2.INTER_CUBIC)
        trimap = cv2.cvtColor(trimap,cv2.COLOR_GRAY2BGR)
        # 与原图合成,生成背景虚化图
        trimap_f = trimap / 255.
        comp = input * trimap_f + blur * (1. - trimap_f)
        cv2.imwrite('./results/comp'+img_id+'.png',comp.astype(np.uint8))

    print('用时  {:.3f} 秒'.format(time.time()-start))

最终效果图如下所示:

一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第15张图片一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化)_第16张图片

上图中成功的将前景人像进行了保持,而背景部分进行了虚化,从而实现了类似单反才能拍出的人像摄影特效。

五、小结

本文详细阐述了MobileNet系列算法原理,在此基础上进行了案例运用,通过语义分割算法实现了背景虚化应用,各部分给出了基于Pytorch的代码。后续将会进一步尝试语义分割相关内容,感兴趣的读者可以继续关注!

参考文献

【1】Howard A, Zhu M, Chen B, et al. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications[J]. arXiv: Computer Vision and Pattern Recognition, 2017.

【2】Sandler M, Howard A, Zhu M, et al. MobileNetV2: Inverted Residuals and Linear Bottlenecks[C]. computer vision and pattern recognition, 2018: 4510-4520.

你可能感兴趣的:(一文掌握MobileNetV1和MobileNetV2(基于pytorch实现的人像背景虚化))