retina face 学习总结

想到人脸检测就会想到几个关键的问题:
a)人脸检测多尺度问题(大家会用图像金字塔、特征金字塔等)
b)

RetinaFace 人脸检测代码(std)

  • 网络结构如下图所示,主要包括四个部分,:(1)backbone(2)特征金字塔(3)上下文结构(4)损失函数

0.1 代码部分定义如下,

第一部分:body需要在backbone上做一定改动,对参数中的return_layers做改动,将需要输出的特征层输出,用于特征金字塔上采样。
第二部分:特征金字塔:上一层backbone将输出三层特征后,FPN对数据进行上采样(F.interpolate)+合并(直接元素相加)的操作。输入为存三个feature map的列表,输出也是三个feature map的列表。

Note1:将高层特征语义信息传回底层特征,这样可以使得底层特征融合之后语义信息更强,使得网络可以综合考虑不同尺度的人脸(浅层网络可以提取小脸特征、深层网络提取深层特征)。
Note2:a)MTCNN采用图像金字塔来综合多尺度信息;b)CenterFace也用了特征金字塔网络;金字塔网络必备

  • 第三部分:上下文结构:三个上下文结构分别对特征金字塔输出的三层feature map处理

Note3:

  • 第四部分:损失计算:对三组特征分别先卷积再view计算class、bbox、landmark的输出,并计算损失。

Note4:多任务损失把landmark引入进来带来更多的监督信息…

0.2代码中的模块定义

#step 1 backbone 主体 body返回三个output分别代表特征金字塔结构的三个特征
self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
# step 2 特征金字塔 输入为body的输出,输出为将三组特征
self.fpn = FPN(in_channels_list, out_channels)
# step 3 上下文结构:三个上下文结构分别对特征金字塔的
self.ssh1 = SSH(out_channels, out_channels)
self.ssh2 = SSH(out_channels, out_channels)
self.ssh3 = SSH(out_channels, out_channels)
# step 4 多尺度损失函数计算
self.ClassHead
self.BboxHead
self.LandmarkHead

0.3 代码前向传播

    def forward(self, inputs):
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)
        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]
        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
        return output

0.4 几个关键模块的代码

特征金字塔:特征金字塔需要利用backbone的中间层,所以要在backbone上做修改,所以会有一个IntermediateLayerGetter 函数去修改,对backbone的module遍历,找到需要return的层,并合并输出。主要结构是这样的:以下列出其中几个模块是如何定义的(自己学习)

    def forward(self, x):

#先定义一个有序集合,set这种无序集合会搞乱输出。
        out = OrderedDict()

#遍历子模块
        for name, module in self.named_children():
            #传播保持不变

x = module(x)

#碰到需要输出的,保存到out
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

改好backbone后需要按照特征金字塔去搭建,每个位置用了BN,将深层特征上采样到与浅层特征channel相同,再元素相加合并,最后需要把三层特征输出。

class FPN(nn.Module):
    def __init__(self,in_channels_list,out_channels):
        super(FPN,self).__init__()
        leaky = 0
        if (out_channels <= 64):
            leaky = 0.1
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)



        self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)

    def forward(self, input):
        # names = list(input.keys())
        input = list(input.values())

        output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])

        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
        output2 = output2 + up3
        output2 = self.merge2(output2)

        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
        output1 = output1 + up2
        output1 = self.merge1(output1)

        out = [output1, output2, output3]
        return out

上下文结构:这部分比较简单,做了层次卷积,然后cat到一起。

class SSH(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        leaky = 0
        if (out_channel <= 64):
            leaky = 0.1
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
        self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
        self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
        self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)

        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)

        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        out = F.relu(out)
        return out

你可能感兴趣的:(人脸识别,计算机视觉CV)