【YOLOv3 NMS】YOLOv3中的非极大值抑制

文章目录

  • 1 NMS问题由来
  • 2 NMS操作流程
    • 2.1 进行NMS前要先有什么
    • 2.2 NMS流程
  • 3 NMS代码解读
  • 4 感谢链接

1 NMS问题由来

利用YOLOv3网络结构提取到out0、out1、out2之后,不同尺度下每个网格点上均有先验框,网络训练过程会对先验框的参数进行调整,继而得到预测框,从不同尺度下预测框还原到原图输入图像上,同时包括该框内目标预测的结果情况(预测框位置、类别概率、置信度分数)。

问题: 每个网格节点上都有3个框,每个框对应到不同类别上也都有概率,然而最终输出不可能把这些框都输出了,怎么办呢?
回答: 非极大值抑制(Non Maximum Suppression, NMS)登场。
NMS功能:筛选出一定区域内,属于 同一种类 得分最大的方框。

以知乎上一张图来说明问题的由来和NMS后效果,感谢原文章。

【YOLOv3 NMS】YOLOv3中的非极大值抑制_第1张图片

2 NMS操作流程

2.1 进行NMS前要先有什么

先行准备:预测得到边界框列表及其对应的置信度得分列表,设定阈值,阈值用来删除重叠较大的边界框。

IoU定义:intersection-over-union,即两个边界框的交集部分除以它们的并集部分。

2.2 NMS流程

非极大值抑制的流程如下:

(1) 对 候选边界框列表(称之为List_A) 中的边界框,根据置信度得分进行排序,选择 置信度得分最高的边界框(称之为Max_A) 添加到 最终输出列表(称之为List_B) 中,并将其从候选边界框列表List_A中删除;
(2) 在候选边界框列表List_A中,遍历其它候选框,若其与 Max_A 的IoU 大于一定阈值,我们将这个候选框删除;(为什么要删除? 因为超过设定的阈值,认为两个框的里面的物体属于同一个物体,留下一个置信度更高的候选框即可。)
(3) 此时再从候选框边界列表List_A中,选一个置信度分数最高的添加到List_B中,重复(1)、(2)过程,直至候选边界框列表List_A为空。

问题: 为什么NMS阈值小一些,显示的框会少一些。
答案: NMS的阈值越小,说明同一块区域的两个框的IoU只要 稍微大一点 (交集多一点),就会被认定为一个物体,自然预测框就会少一点了。

3 NMS代码解读

没有前因后果的NMS代码,是没有灵魂的,比如直接给出这个:

  • torchvision中提供的nms:
from torchvision.ops import nms
  • nms源码实现:
    def nms_origin_code(self, bboxes, scores, threshold=0.5):
        x1 = bboxes[:,0]
        y1 = bboxes[:,1]
        x2 = bboxes[:,2]
        y2 = bboxes[:,3]
        areas = (x2-x1)*(y2-y1)   # [N,] 每个bbox的面积
        _, order = scores.sort(0, descending=True)    # 降序排列

        keep = []
        while order.numel() > 0:       # torch.numel()返回张量元素个数
            if order.numel() == 1:     # 保留框只剩一个
                i = order.item()
                keep.append(i)
                break
            else:
                i = order[0].item()    # 保留scores最大的那个框box[i]
                keep.append(i)

            # 计算box[i]与其余各框的IOU(思路很好)
            xx1 = x1[order[1:]].clamp(min=x1[i])   # [N-1,]
            yy1 = y1[order[1:]].clamp(min=y1[i])
            xx2 = x2[order[1:]].clamp(max=x2[i])
            yy2 = y2[order[1:]].clamp(max=y2[i])
            inter = (xx2-xx1).clamp(min=0) * (yy2-yy1).clamp(min=0)   # [N-1,]

            iou = inter / (areas[i]+areas[order[1:]]-inter)  # [N-1,]
            idx = (iou <= threshold).nonzero().squeeze() # 注意此时idx为[N-1,] 而order为[N,]
            if idx.numel() == 0:
                break
            order = order[idx+1]  # 修补索引之间的差值
        return torch.LongTensor(keep)   # Pytorch的索引值为LongTensor

何时用?何处用?怎么用?

故,以在YOLOv3解码DecodeBox()中的NMS代码为例,如下:

import torch
from torchvision.ops import nms		# torchvision中提供了nms
import numpy as np


class DecodeBox():
    def __init__(self, anchors, num_classes, input_shape, anchors_mask=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]):
        super(DecodeBox, self).__init__()
        self.anchors = anchors				# ndarray:(9, 2)
        self.num_classes = num_classes  	# int   20
        self.bbox_attrs = 5 + num_classes  	# int   25=20+4+1
        self.input_shape = input_shape  	# (416, 416) list or tuple
        # -----------------------------------------------------------#
        #   13x13的特征层对应的anchor是[116,90],[156,198],[373,326]
        #   26x26的特征层对应的anchor是[30,61],[62,45],[59,119]
        #   52x52的特征层对应的anchor是[10,13],[16,30],[33,23]
        # -----------------------------------------------------------#
        self.anchors_mask = anchors_mask

    # ----------------------------------------------------------------------------------#
    #   得到out0、out1、out2不同尺度下每个网格点上的的预测情况(预测框位置、类别概率、置信度分数)
    # ----------------------------------------------------------------------------------#
    def decode_box(self, inputs):  # input一共有三组数据,out0,out1,out2
        outputs = []
        # ...中间部分参考链接 https://blog.csdn.net/weixin_45377629/article/details/124144913
        #	强烈建议先去上面的链接中,看懂该部分内容后,再往下看代码。
        # ...
        # 得到out0、out1、out2不同尺度下 每个网格点上 的预测情况(预测框位置、类别概率、置信度分数)
        return outputs

    # ------------------------------------------------------#
    #   NMS一通操作后,把在网络输入图像上的box都整出来了,
    #		但我们在把图像输入网络前,进行了resize操作,
    #		通过这个函数,把这些box的信息,真正整到原图上。
    #		若用了letter_box,产生灰条,还原方式有点不同
    #	先看non_max_suppression()函数
    #	box_xy与box_wh:框中心点坐标和宽高,shape:(9,2)  array  float32
    #   input_shape:[416,416],网络输入尺寸
    #   image_shape:原来图像尺寸
    #	letterbox_image:是否使用letterbox缩放,True or False
    # ------------------------------------------------------#
    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        # -----------------------------------------------------------------#
        #   把y轴放前面是因为方便预测框和图像的宽高进行相乘
        #	在这儿要说明一下:YOLO系列中,box的中心点位置以及宽高都是归一化(0~1)的。
        #		详细解读见:https://blog.csdn.net/weixin_45377629/article/details/124116916
        # -----------------------------------------------------------------#
        box_yx = box_xy[..., ::-1]		# box_yx shape:(9,2)
        box_hw = box_wh[..., ::-1]		# box_hw shape:(9,2)
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            # -----------------------------------------------------------------#
            #   这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
            #   new_shape指的是宽高缩放情况
            #	先看懂letterbox怎么回事:https://blog.csdn.net/weixin_45377629/article/details/124027705
            # -----------------------------------------------------------------#
            new_shape = np.round(image_shape * np.min(input_shape / image_shape))
            offset = (input_shape - new_shape) / 2. / input_shape
            scale = input_shape / new_shape

            box_yx = (box_yx - offset) * scale
            box_hw *= scale

		# 似乎又变成box左上、右下坐标了!
        box_mins = box_yx - (box_hw / 2.)		# box_mins shape:(9,2)
        box_maxes = box_yx + (box_hw / 2.)		# box_maxes shape:(9,2)
        # -----------------------------------------------------------------#
        # 	boxes shape: (9,4),里面元素值都很小
        # -----------------------------------------------------------------#
        boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
        # -----------------------------------------------------------------# 
        # 	boxes shape: (9,4),里面元素值都很大
        #	    image_shape = np.array([5,4])
    	#		x = np.concatenate([image_shape, image_shape], axis=-1)
    	#		print(x)     # [5 4 5 4]
    	#	归一化后的box坐标数据,去乘原图的宽高,得到在原图上的box位置,很合理!
    	# -----------------------------------------------------------------#
        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes
    
    # --------------------------------------------------------------------#
    #	非极大值抑制NMS源码
	#	先看non_max_suppression()函数
    #   nms_origin_code() 等效于 from torchvision.ops import nms
    #	返回的内容:tensor([1])或tensor([14,12,7]),装的是NMS后,所需要的
    #		box在 detections_class/bboxes 中的索引,根据索引就可以取到box了	
    # --------------------------------------------------------------------#
    def nms_origin_code(self, bboxes, scores, threshold=0.5):
        # ---------------------------#
        #   x1,y1均为张量列表
        #   areas 张量列表,每个bbox的面积
        #   order:张量列表,里面存放的是索引
        # ---------------------------#
        x1 = bboxes[:,0]
        y1 = bboxes[:,1]
        x2 = bboxes[:,2]
        y2 = bboxes[:,3]
        areas = (x2-x1)*(y2-y1)         
        _, order = scores.sort(0, descending=True)    # 降序排列

        keep = []
        while order.numel() > 0:        # tensor.numel()返回张量元素个数,该函数只用于tensor
            if order.numel() == 1:      # 保留框只剩一个
                i = order.item()        # tensor.item()返回张量的值,i变成了一个值
                keep.append(i)
                break
            else:
                i = order[0].item()    # 保留scores最大的那个框的索引i。box[i]
                keep.append(i)

            # ---------------------------------------------------------------------------#
            #   计算box[i]与其余各框的IOU
            #   tensor.clamp(min,max):将输入张量每个元素的夹紧到区间[min,max],并返回结果到一个新张量。
            #   左上要注意小值,右下要注意大值,超出可不行
            # ---------------------------------------------------------------------------#
            xx1 = x1[order[1:]].clamp(min=x1[i])   # x1有4个元素的话,xx1就只有3个,要少一个
            yy1 = y1[order[1:]].clamp(min=y1[i])
            xx2 = x2[order[1:]].clamp(max=x2[i])
            yy2 = y2[order[1:]].clamp(max=y2[i])
            # -------------------------------------------------------------------#
            #   交集面积inter,张量列表
            #   注意inter是order[0]和其余的交集面积,len(inter)要比len(order)少1
            #   下面的iou同理
            # -------------------------------------------------------------------#  
            inter = (xx2-xx1).clamp(min=0) * (yy2-yy1).clamp(min=0)   # [N-1,]

            iou = inter / (areas[i]+areas[order[1:]]-inter)  # [N-1,]
            # -----------------------------------------------#
            #   tensor.nonzero()找到tensor中所有不为0的索引
            # -----------------------------------------------
            idx = (iou <= threshold).nonzero().squeeze() # 注意此时idx为[N-1,] 而order为[N,]
            if idx.numel() == 0:        # 没元素了就退出,一定会没的
                break
            order = order[idx+1]        # 修补索引之间的差值
        return torch.LongTensor(keep)   # Pytorch的索引值为LongTensor


    # -----------------------------------------------------------------------------------------------#
    #   prediction:torch.cat(outputs, 1),将预测框进行堆叠,详解见下方
    #   image_shape:数据集中图片真实尺寸
    #   letterbox_image:True/False,表示是否使用letterbox方式处理数据,
    #		letterbox详细解读可见https://blog.csdn.net/weixin_45377629/article/details/124027705
    #   conf_thres=confidence:表示置信度阈值,只有得分大于置信度的预测框会被保留下来,范围是0~1,常选0.5
    #   nms_thres=nms_iou:非极大抑制所用到的nms_iou阈值大小
    # -----------------------------------------------------------------------------------------------#
    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5,
                            nms_thres=0.4):
        # ------------------------------------------------------------------------#
        #   将预测结果的格式从(中心、宽高)形式转换成(左上角、右下角)的格式。
        #   prediction表示预测结果  
        #	prediction.shape:torch.size([batch_size, num_anchors, 25])
        #	以voc为例:torch.size([1, 10647, 25])
        #	10647怎么来的:10647=(13x13+26x26+52x52)x3
        #	故box_corner.shape:torch.size([1, 10647, 25])
        #	box_corner也就是个中间变量,用一下就不用了
        # ------------------------------------------------------------------------#
        box_corner = prediction.new(prediction.shape)	# 里面此时全为0
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2  # 左上x
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2  # 左上y
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2  # 右下x
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2  # 右下y
        prediction[:, :, :4] = box_corner[:, :, :4]		# 回归到prediction
		
		# ------------------------------------------#
		# 	len(prediction)=batch_size,一张图一个output
		#	output:[None]
		# ------------------------------------------#
        output = [None for _ in range(len(prediction))]
        # ----------------------------------------------------------------------------#
        #   prediction是有batchsize维度的,故有下面这个循环。
        #	当只有一张图片时,batchsize=1,只操作一次。
        #	i=0; 
        #	image_pred为预测结果信息,shape: torch.size([10647,25])
        #		image_pred[[左上x, 左上y, 右下x, 右下y, 预测分数, 类别0, 类别1, ...], ...]
        # ----------------------------------------------------------------------------#
        for i, image_pred in enumerate(prediction):  
            # --------------------------------------------------------------------------------#
            #   得到种类置信度和对应类别。
            #   class_conf:种类置信度,torch.size([10647,1]):[num_anchors, 1]      
            #		元素范围:0~1,torch.float32
            #   class_pred:所属类别,torch.size([10647,1]):[num_anchors, 1]
            #		元素范围:0~20,torch.int64
            #		为啥这个就是所属类别?回答:torch.max()特点,返回最大值和对应索引,根据索引顺序即可得到类别。
            #	image_pred第5列到第25列存放着每个box针对VOC 20个类别的概率
            #		这个地方涉及到YOLOv3用的是BCELoss,它对每个类别都有预测概率,
            #		且所有类别概率和不为0,关于BCELoss更细节的理解可参考:
            #		https://blog.csdn.net/weixin_45377629/article/details/124006451
            #	torch.max(input, dim, keepdim=False):
            #		dim=0寻找每一列的最大值,dim=1寻找每一行的最大值
            #		keepdim 表示是否需要保持输出的维度与输入一样,keepdim=True表示输出和输入的维度一样,
            #			keepdim=False表示输出的维度被压缩了,也就是输出会比输入低一个维度。
            #		输出:按照规则的最大值,最大值索引
            # --------------------------------------------------------------------------------#
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            # -----------------------------------------------------------------------#
            #   利用置信度进行第一轮筛选,不是谁都配进NMS的。
            #	image_pred第4列预测分数 和 class_conf第0列(class_conf只有一列)种类置信度得分 相乘
            #		得到置信度分数,和设定的阈值进行比较。
            #	conf_mask:内部元素为bool类型,.squeeze()函数使其维度变为torch.size([10647])
            # -----------------------------------------------------------------------#
            conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

            # ----------------------------------------------------------#
            #   根据置信度进行预测结果的筛选
            #	conf_mask元素值为True的才留下,比如
            #	conf_mask shape: torch.size([29,25])
            #	class_conf shape: torch.size([29,1])
            #	class_pred shape: torch.size([29,1])
            # ----------------------------------------------------------#
            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue
            # -------------------------------------------------------------------------#
            #   detections  [num_anchors, 7]  堆叠是为了方便下面处理
            #	detections.shape: torch.size([29,7])
            #   7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
            # -------------------------------------------------------------------------#
            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

            # ----------------------------------------------------------------------------------#
            #   获得预测结果中包含的所有种类
            # 	detections[:, -1]:得到每一个box所属类别
            #	unique()表示以数组形式(np.ndarray)返回列的所有唯一值,可能一张图片中只有三个类别
            #	unique_labels:例如tensor([1., 6., 14.])
            # ----------------------------------------------------------------------------------#
            unique_labels = detections[:, -1].cpu().unique()  

            if prediction.is_cuda:
                unique_labels = unique_labels.cuda()
                detections = detections.cuda()

			# ------------------------------------------#
			#	针对所有类别,挨个来操作
			# ------------------------------------------#
            for c in unique_labels:
                # ------------------------------------------#
                #   获得某一类得分筛选后全部的预测结果
                #   属于这个种类的框筛选出来
                #	detections_class:torch.size([4,7])
                #	举个例子帮助理解:
                #	if __name__=='__main__':
                #		import torch
                #		x = torch.randint(0,3,(3,4))
                #		print(x)
                #		print(x[:,-1]==1)
                #		print(x[x[:,-1]==1])
                """
						tensor([[1, 0, 2, 2],
						        [0, 0, 1, 2],
						        [2, 0, 1, 1]])
						tensor([False, False,  True])
						tensor([[2, 0, 1, 1]])
				"""
                # ------------------------------------------#
                detections_class = detections[detections[:, -1] == c]

                #------------------------------------------#
                #   使用官方自带的非极大抑制会速度更快一些!
                #   也就是from torchvision.ops import nms
                #	nms(boxes, scores, iou_threshold) -> Tensor
                #	detections_class:torch.size([4,7])
            	#   7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
            	#	keep内容:tensor([1])或tensor([14,12,7]),装的是NMS后,所需要的
            	#		box在detections_class中的索引,根据索引就可以取到box了		
            	#	那里面是怎么回事呢?得看源码
                #------------------------------------------#
                # keep = nms(
                #     detections_class[:, :4],
                #     detections_class[:, 4] * detections_class[:, 5],
                #     nms_thres
                # )

                # ------------------------------------------#
                #   使用源码编写!!!
                # ------------------------------------------#
                keep = self.nms_origin_code(
                    detections_class[:, :4],
                    detections_class[:, 4] * detections_class[:, 5],
                    nms_thres
                )
                # ------------------------------------------#
                #	需要的box 参数信息
                #	max_detections shape:torch.size([1,7])或torch.size([3,7])
                # ------------------------------------------#
                max_detections = detections_class[keep]

				# ------------------------------------------#
                # 	每一次的max_detections都添加到outputs里去
                #	output shape:例子:torch.size([4,7])
                # ------------------------------------------#
                output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))

            if output[i] is not None:
                output[i] = output[i].cpu().numpy()
                # -----------------------------------------#
                #	box中心点坐标x,y以及宽高
                #	box_xy与box_wh shape:(9,2)  array  float32
                # -----------------------------------------#
                box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
                # -----------------------------------------#
                # 	前期resize了,得给它还原回去
                # 	若用了letter_box,产生灰条,还原方式有点不同
                #   input_shape:[416,416],网络输入尺寸
                #   image_shape:原来图像尺寸
                # -----------------------------------------#
                output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output


if __name__ == '__main__':
    anchors = [10.0, 13.0, 16.0, 30.0, 33.0, 23.0, 30.0, 61.0, 62.0, 45.0, 59.0, 119.0, 116.0, 90.0, 156.0, 198.0,
               373.0, 326.0]
    # 	anchors: ndarray:(9, 2)
    anchors = np.array(anchors).reshape(-1, 2)
    
    num_classes = 20  # voc类别个数
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]	# anchors索引
    input_shape = [416, 416]
    
	#	初始化一个类
    bbox_util = DecodeBox(anchors, num_classes, (input_shape[0], input_shape[1]), anchors_mask)

    # ---------------------------------------------------------#
    #   将图像输入网络当中进行预测!
    # 	YoloBody详解可见https://blog.csdn.net/weixin_45377629/article/details/124080087
    # ---------------------------------------------------------#
    net = YoloBody(anchors_mask, num_classes)  
    outputs = net(images)  # 此地images表示输入图片,outputs为三个输出out0, out1, out2
    
    # -----------------------------------------------------------------------------------#
    # 	得到out0、out1、out2不同尺度下每个网格点上的预测情况(预测框位置、类别概率、置信度分数)
    #	bbox_util.decode_box解读详见https://blog.csdn.net/weixin_45377629/article/details/124144913
    # -----------------------------------------------------------------------------------#
    outputs = bbox_util.decode_box(outputs)  
    
    # ---------------------------------------------------------#
    #   将预测框进行堆叠,然后进行非极大抑制
    #   torch.cat(outputs, 1):将预测框进行堆叠
    #   image_shape:数据集中图片真实尺寸
    #   letterbox_image:True/False,表示是否使用letterbox方式处理数据,
    #		letterbox详细解读可见https://blog.csdn.net/weixin_45377629/article/details/124027705
    #   conf_thres=confidence:表示置信度阈值,只有得分大于置信度的预测框会被保留下来,范围是0~1,常选0.5
    #   nms_thres=nms_iou:非极大抑制所用到的nms_iou阈值大小
    # ---------------------------------------------------------#
    results = bbox_util.non_max_suppression(torch.cat(outputs, 1), num_classes, input_shape,
                                            image_shape, letterbox_image, conf_thres=confidence,
                                            nms_thres=nms_iou)

4 感谢链接

https://www.bilibili.com/video/BV1Hp4y1y788?p=6&spm_id_from=pageDriver
https://www.jianshu.com/p/d452b5615850
https://blog.csdn.net/zouxiaolv/article/details/107400193
https://blog.csdn.net/zylooooooooong/article/details/112576268
https://zhuanlan.zhihu.com/p/422545531

你可能感兴趣的:(目标检测系列,python,深度学习,pytorch,神经网络)