pytorch-ssd源码解读(二)------------detection(预测层)

一、定义

detection层是ssd预测阶段的最后一层。它接收底层网络输出的位置偏移量(loc_data)、各个框置信度(conf_data)以及默认框(prior_data)。该层的作用是整合各层的预测结果,过滤置信度太低的预测框,通过类内nms抑制大量相同的预测框。

二、代码解读

1.输入

  • loc_data    网络六个层预测的坐标偏移。Shape:[batch,num_priors*4]
  • conf_data     各预测框对应各类的置信度(每个预测框针对每一类都预测一个得分,因此每个预测框对应num_classes个得分)                            Shape: [batch*num_priors,num_classes]
  • prior_data   默认框(上篇博客中介绍过),网络的预测框其实就是针对这个默认框的偏移量。 Shape: [num_priors,4]

2.参数配置

  • top_k                          一张图片中,每一类保存top_k个预测框
  • conf_thresh                置信度阈值,置信度低于该阈值的预测框会被抛弃
  • nms_thresh                nms阈值

3.输出格式

    output的Shape为 [batch, num_classes, top_k, 5] 

   取其中一个输出output[i,j,k,:]表示在当前mini_batch中的第i张图片的第j类的第k个框对应的预测结果。

   最后一维的五个数依次为[score,xmin,ymin,xmax,ymax]

 4.重点代码解读

NO.1

"""
conf_data  :  [batch*num_priors,num_classes]
conf_preds :  [batch,num_classes,num_priors]
"""
conf_preds = conf_data.view(num, num_priors,
                              self.num_classes).transpose(2, 1)

conf_data是预测框对应各类的置信度,为一个二维Tensor,Shape: [batch*num_priors,num_classes]。

conf_data.view(num, num_priors,self.num_classes)先将Tensor由[batch*num_priors,num_classes]转变为[batch,num_priors,num_classes]。然后通过transpose(2,1)调换第1和第2维度,最终变为[batch,num_classes,num_priors]形式.

该步骤的目的是方便下面按照batch,class的顺序进行处理。

NO.2

decoded_boxes = decode(loc_data[i], prior_data, self.variance)

由预测的偏移值和默认框生成最终的预测框 。

偏移值和默认框为[x,y,w,h]形式,解码的预测框为[xmin,ymin,xmax,ymax]形式。

最终decode_boxes的shape为[num_priors,4],关于decode函数文末会详细讲一下。

NO.3

c_mask = conf_scores[cl].gt(self.conf_thresh)  #选择大于设置的阈值的得分掩码
scores = conf_scores[cl][c_mask]               #筛选大于阈值的得分
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)   #c_mask是一维的,将其扩展为[num_priors,4]
boxes = decoded_boxes[l_mask].view(-1, 4) 

这段代码先筛选出大于阈值的置信度,然后取出对应的预测框。

个人感觉不如直接写为下面的代码简洁:

c_mask = conf_scores[cl].gt(self.conf_thresh)  
scores = conf_scores[cl][c_mask]              
boxes = decoded_boxes[c_mask]

NO.5

ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)

该步骤对上一步经过阈值筛选的boxes和scores进行nms极大值抑制,取出符合条件的预测框。

输出ids为符合条件的预测框的index是一个列表,count为符合条件的框的个数。

附件:

def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers           预测框的偏移量
                Shape: [batch,num_priors*4]
            conf_data: (tensor) Shape: Conf preds from conf layers   每个预测框中各类别的得分
                Shape: [batch*num_priors,num_classes]
            prior_data: (tensor) Prior boxes and variances from priorbox layers       默认框
                Shape: [1,num_priors,4]

        该函数根据默认框和预测框计算出最终的检测框
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)  #默认框的数目
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            """
            由预测的偏移值和默认框生成最终的预测框
            偏移值和默认框为[x,y,w,h]形式,解码的预测框为[xmin,ymin,xmax,ymax]形式
            decode_boxes的shape为[num_priors,4]
            """
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            # For each class, perform nms
            #一张图片中所有的预测框得分,shape为[num_classes,num_priors]
            conf_scores = conf_preds[i].clone()
            #类内做nms
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)  #选择大于设置的阈值的得分掩码
                scores = conf_scores[cl][c_mask]               #筛选大于阈值的得分
                if scores.dim() == 0:                          #如果当前类没有符合条件的预测框,继续下一个类的循环
                    continue
                """
                decoded_boxes[l_mask]其实是一维的,排列方式形如[x1min,y1min,x1max,y1max,x2min,y2min,x2max,y2max...]
                因此decoded_boxes[l_mask].view(-1, 4) 才会转变为[num_priors,4]的形状,使得每一行对应一个bbox
                个人感觉大可不必使用l_mask,可将下面代码直接替代为
                boxes = decode_boxes[c_mask]
            
                """
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)   #c_mask是一维的,将其扩展为[num_priors,4]
                boxes = decoded_boxes[l_mask].view(-1, 4)               
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                """
                构造输出,最后一个output为(batchsize, num_classes, self.top_k, 5)形状
                最后一纬的5个数为[score,xmin,ymin,xmax,ymax]
                """
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)
        flt = output.contiguous().view(num, -1, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)
        flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
        return output

 

你可能感兴趣的:(pytorch)