非极大值抑制算法(Non-Maximum Suppression,NMS)

所属知识点:Computer Vision:Object Detection
归纳和总结机器学习技术的库:ViolinLee/ML_notes


关键概念:NMS;IOU;Bounding Box(BBox);Region Proposal or Candidates;

1. NMS 介绍
       在执行目标检测任务时,算法可能对同一目标有多次检测。NMS 是一种让你确保算法只对每个对象得到一个检测的方法,即“清理检测”。如下图所示:

非极大值抑制算法(Non-Maximum Suppression,NMS)_第1张图片

       现实中,在正式使用NMS之前,通常会有一个候选框预清理的工作(简单引入一个置信度阈值),如下图所示:
 

非极大值抑制算法(Non-Maximum Suppression,NMS)_第2张图片

       NMS 算法的大致过程:每轮选取置信度最大的 Bounding Box(简称 BBox,有时也会看到用 Pc,Possible Candidates 代替讲解的) ,接着关注所有剩下的 BBox 中与选取的 BBox 有着高重叠(IOU)的,它们将在这一轮被抑制。这一轮选取的 BBox 会被保留输出,且不会在下一轮出现。接着开始下一轮,重复上述过程:选取置信度最大 BBox ,抑制高 IOU BBox。

       NMS 算法流程:这是一般文章中介绍的 NMS,比较难懂。但实际上 NMS 的实现反而简单很多。

非极大值抑制算法(Non-Maximum Suppression,NMS)_第3张图片

       
       NMS 过程图例:单类别 NMS 的例子:有两只狗,怎样用 NMS 保证只留下两个 BBox?
1)理解 BBox 输入或输出格式,通常会见到两种格式:

  • 第一种,BBox 中心位置(x, y) + BBox 長寬(h, w) + Confidence Score;
  • 第二种,BBox 左上角点(x1,y1) + BBox 右下角点(x2,y2) + Confidence Score;

两种表达的本质是一样的,均为五个变量。与 BBox 相关的四个变量用于计算 IOU,Confidence Score 用于排序。

2)理解评估重叠的 IOU 指标,即“交并比”,如下图所示:

非极大值抑制算法(Non-Maximum Suppression,NMS)_第4张图片

       注意:不同人写的代码,计算重叠的方法可能有差别。下面的 C++ 和 Matlab 实现使用的不是上面定义的IoU,因为分母没有使用面积的并集,而是用其中一者的面积。

3)步骤:
       第一步:对 BBox 按置信度排序,选取置信度最高的 BBox(所以一开始置信度最高的 BBox 一定会被留下来);
       第二步:对剩下的 BBox 和已经选取的 BBox 计算 IOU,淘汰(抑制) IOU 大于设定阈值的 BBox(在图例中这些淘汰的 BBox 的置信度被设定为0)。
       第三步:重复上述两个步骤,直到所有的 BBox 都被处理完,这时候每一轮选取的 BBox 就是最后结果。

非极大值抑制算法(Non-Maximum Suppression,NMS)_第5张图片

       在上面这个例子中,NMS 只运行了两轮就选取出最终结果:第一轮选择了红色 BBox,淘汰了粉色 BBox;第二轮选择了黄色 BBox,淘汰了紫色 BBox 和青色 BBox。注意到这里设定的 IOU 阈值是0.5,假设将阈值提高为0.7,结果又是如何?

非极大值抑制算法(Non-Maximum Suppression,NMS)_第6张图片

       可以看到,NMS 用了更多轮次来确定最终结果,并且最终结果保留了更多的 BBox,但结果并不是我们想要的。因此,在使用 NMS 时,IOU 阈值的确定是比较重要的,但一开始我们可以选定 default 值(论文使用的值)进行尝试。

       NMS 针对多类别的情况:吴恩达在 deeplearning 专项课程中指出,如果有多个分类,正确做法应该是运行多次独立的NMS,每次针对一种输出分类。

2. NMS 代码实现(参考链接见文末,代码做了小改动,使两种语言下的输出效果一致):
C++ 和 OpenCV 实现:
首先实现了简单选择排序函数 void sort()。关于排序算法,推荐参考这篇博文。

#include 
#include 
#include 

using namespace std;
using namespace cv;

static void sort(int n, const vector x, vector indices)
{
// 排序函数,排序后进行交换的是indices中的数据
// n:排序总数// x:待排序数// indices:初始为0~n-1数目 
	
	int i, j;
	for (i = 0; i < n; i++)
		for (j = i + 1; j < n; j++)
		{
			if (x[indices[j]] > x[indices[i]])
			{
				//float x_tmp = x[i];
				int index_tmp = indices[i];
				//x[i] = x[j];
				indices[i] = indices[j];
				//x[j] = x_tmp;
				indices[j] = index_tmp;
			}
		}
}

int nonMaximumSuppression(int numBoxes, const vector points,const vector oppositePoints, 
	const vector score,	float overlapThreshold,int& numBoxesOut, vector& pointsOut,
	vector& oppositePointsOut, vector scoreOut) 
{
// 实现检测出的矩形窗口的非极大值抑制nms
// numBoxes:窗口数目// points:窗口左上角坐标点// oppositePoints:窗口右下角坐标点// score:窗口得分
// overlapThreshold:重叠阈值控制// numBoxesOut:输出窗口数目// pointsOut:输出窗口左上角坐标点
// oppositePoints:输出窗口右下角坐标点// scoreOut:输出窗口得分
	int i, j, index;
	vector box_area(numBoxes);				// 定义窗口面积变量并分配空间 
	vector indices(numBoxes);					// 定义窗口索引并分配空间 
	vector is_suppressed(numBoxes);			// 定义是否抑制表标志并分配空间 
	// 初始化indices、is_supperssed、box_area信息 
	for (i = 0; i < numBoxes; i++)
	{
		indices[i] = i;
		is_suppressed[i] = 0;
		box_area[i] = (float)( (oppositePoints[i].x - points[i].x + 1) *(oppositePoints[i].y - points[i].y + 1));
	}
	// 对输入窗口按照分数比值进行排序,排序后的编号放在indices中 
	sort(numBoxes, score, indices);
	for (i = 0; i < numBoxes; i++)                // 循环所有窗口 
	{
		if (!is_suppressed[indices[i]])           // 判断窗口是否被抑制 
		{
			for (j = i + 1; j < numBoxes; j++)    // 循环当前窗口之后的窗口 
			{
				if (!is_suppressed[indices[j]])   // 判断窗口是否被抑制 
				{
					int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // 求两个窗口左上角x坐标最大值 
					int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // 求两个窗口右下角x坐标最小值 
					int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // 求两个窗口左上角y坐标最大值 
					int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // 求两个窗口右下角y坐标最小值 
					int overlapWidth = x2min - x1max + 1;     // 计算两矩形重叠的宽度 
					int overlapHeight = y2min - y1max + 1;     // 计算两矩形重叠的高度 
					if (overlapWidth > 0 && overlapHeight > 0)
					{
						float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];    // 计算重叠的比率 
						if (overlapPart > overlapThreshold)   // 判断重叠比率是否超过重叠阈值 
						{
							is_suppressed[indices[j]] = 1;     // 将窗口j标记为抑制 
						}
					}
				}
			}
		}
	}
 
	numBoxesOut = 0;    // 初始化输出窗口数目0 
	for (i = 0; i < numBoxes; i++)
	{
		if (!is_suppressed[i]) numBoxesOut++;    // 统计输出窗口数目 
	}
	index = 0;
	for (i = 0; i < numBoxes; i++)            // 遍历所有输入窗口 
	{
		if (!is_suppressed[indices[i]])       // 将未发生抑制的窗口信息保存到输出信息中 
		{
			pointsOut.push_back(Point(points[indices[i]].x,points[indices[i]].y));
			oppositePointsOut.push_back(Point(oppositePoints[indices[i]].x,oppositePoints[indices[i]].y));
			scoreOut.push_back(score[indices[i]]);
			index++;
		}
 
	}
 
	return true;
}

int main()
{
	Mat image = Mat::zeros(600,600,CV_8UC3);
	int numBoxes=4;
	vector points(numBoxes);
	vector oppositePoints(numBoxes);
	vector score(numBoxes);
 
	points[0]=Point(200,200);oppositePoints[0]=Point(400,400);score[0]=0.99;
	points[1]=Point(220,220);oppositePoints[1]=Point(420,420);score[1]=0.9;
	points[2]=Point(100,100);oppositePoints[2]=Point(150,150);score[2]=0.82;
	points[3]=Point(200,240);oppositePoints[3]=Point(400,440);score[3]=0.5;
	
	float overlapThreshold=0.8;
	int numBoxesOut;
	vector pointsOut;
	vector oppositePointsOut;
	vector scoreOut;
 
	nonMaximumSuppression( numBoxes,points,oppositePoints,score,overlapThreshold,numBoxesOut,pointsOut,oppositePointsOut,scoreOut);
	for (int i=0;i

输出效果:

非极大值抑制算法(Non-Maximum Suppression,NMS)_第7张图片

Matlab 实现:

function main()
boxes=[200,200,400,400,0.99;
        220,220,420,420,0.9;
        100,100,150,150,0.82;
        200,240,400,440,0.5];
overlap=0.8;
pick = NMS(boxes, overlap);
figure;
for (i=1:size(boxes,1))
    rectangle('Position',[boxes(i,1),boxes(i,2),boxes(i,3)-boxes(i,1),boxes(i,4)-boxes(i,2)],'EdgeColor','y','LineWidth',6);
    text(boxes(i,1),boxes(i,2),num2str(boxes(i,5)),'FontSize',14,'color','b');
end
for (i=1:size(pick,1))
    rectangle('Position',[boxes(pick(i),1),boxes(pick(i),2),boxes(pick(i),3)-boxes(pick(i),1),boxes(pick(i),4)-boxes(pick(i),2)],'EdgeColor','r','LineWidth',2);
end
axis ij;
axis equal;
axis([0 600 0 600]);
end

function pick = NMS(boxes, overlap)

% pick = nms(boxes, overlap) 
% Non-maximum suppression.
% Greedily select high-scoring detections and skip detections
% that are significantly covered by a previously selected detection.

if isempty(boxes)
  pick = [];
else
  x1 = boxes(:,1);          %所有候选框的左上角顶点x 
  y1 = boxes(:,2);          %所有候选框的左上角顶点y 
  x2 = boxes(:,3);          %所有候选框的右下角顶点x 
  y2 = boxes(:,4);          %所有候选框的右下角顶点y
  s = boxes(:,end);         %所有候选框的置信度,可以包含1列或者多列,用于表示不同准则的置信度
  area = (x2-x1+1) .* (y2-y1+1);%所有候选框的面积

  [vals, I] = sort(s);      %将所有候选框进行从小到大排序,vals为排序后结果,I为排序后标签
  pick = [];
  while ~isempty(I)
    last = length(I);       %last代表标签I的长度,即最后一个元素的位置,(matlab矩阵从1开始计数)
    i = I(last);            %所有候选框的中置信度最高的那个的标签赋值给i
    pick = [pick; i];       %将i存入pick中,pick为一个列向量,保存输出的NMS处理后的box的序号
    suppress = [last];      %将I中最大置信度的标签在I中位置赋值给suppress,suppress作用为类似打标志,
                            %存入suppress,证明该元素处理过
    for pos = 1:last-1      %从1到倒数第二个进行循环
      j = I(pos);           %得到pos位置的标签,赋值给j
      xx1 = max(x1(i), x1(j));%左上角最大的x(求两个方框的公共区域)
      yy1 = max(y1(i), y1(j));%左上角最大的y
      xx2 = min(x2(i), x2(j));%右下角最小的x
      yy2 = min(y2(i), y2(j));%右下角最小的y
      w = xx2-xx1+1;          %公共区域的宽度
      h = yy2-yy1+1;          %公共区域的高度
      if w > 0 && h > 0     %w,h全部>0,证明2个候选框相交
        o = w * h / area(j);%计算overlap比值,即交集占候选框j的面积比例
        if o > overlap      %如果大于设置的阈值就去掉候选框j,因为候选框i的置信度最高
          suppress = [suppress; pos];%大于规定阈值就加入到suppress,证明该元素被处理过
        end
      end
    end
    I(suppress) = [];%将处理过的suppress置为空,当I为空结束循环
  end  
end
end

输出效果:

非极大值抑制算法(Non-Maximum Suppression,NMS)_第8张图片

3. 应用——车辆检测:
       这部分参照吴恩达 deeplearning 专项课程课后编程作业 Car Detection with YOLOv2:先给出结果,详细实现后续附上。

非极大值抑制算法(Non-Maximum Suppression,NMS)_第9张图片


参考和推荐:
非极大值抑制(nonMaximumSuppression):NMS 的 C++ 和 Matlab 实现
機器/深度學習: 物件偵測 Non-Maximum Suppression (NMS):形象易懂地讲解 NMS 
Non-Maximum Suppression:吴恩达关于 NMS 的介绍

你可能感兴趣的:(机器学习(Machine,Learning))