非极大值抑制(NMS)

NMS应用在图像识别方面

相关概述:非极大值抑制(Non-Maximum Suppression,NMS),顾名思义就是抑制不是极大值的元素,可以理解为局部最大搜索。这个局部代表的是一个邻域,邻域有两个参数可变,一是邻域的维数,二是邻域的大小。这里不讨论通用的NMS算法(参考论文《Efficient Non-Maximum Suppression》对1维和2维数据的NMS实现),而是用于目标检测中提取分数最高的窗口的。例如在行人检测中,滑动窗口经提取特征,经分类器分类识别后,每个窗口都会得到一个分数。但是滑动窗口会导致很多窗口与其他窗口存在包含或者大部分交叉的情况。这时就需要用到NMS来选取那些邻域里分数最高(是行人的概率最大),并且抑制那些分数低的窗口。
 

处理流程:

 取出图像框中得分最高的一个

其他框与最大值框进行重叠率(重叠区域面积比例IOU)阈值判断

常见阈值0.3~0.5

大于一定阈值则将该框去除

所有框判断一轮后保存选出得分最高的那个框用来输出

在还存在的框中循环执行操作。

 

box :id score x y w h;

id :框的类别

score:框的得分

(x,y)左下角坐标

w,h 宽高

 

#include 
#include 
#include
#include
using namespace std;
typedef struct {
	float x, y, w, h;
} box;
int max_index(float *a, int n)
{
	if (n <= 0) return -1;
	int i, max_i = 0;
	float max1 = a[0];
	for (i = 1; i < n; ++i) {
		if (a[i] > max1) {
			max1 = a[i];
			max_i = i;
		}
	}
	return max_i;
}
float box_iou(box a, box b)
{
	float y1 = (a.y - a.h / 2.0) > (b.y - b.h / 2.0) ? (a.y - a.h / 2.0) : (b.y - b.h / 2.0);
	float y2 = (a.y + a.h / 2.0) < (b.y + b.h / 2.0) ? (a.y + a.h / 2.0) : (b.y + b.h / 2.0);
	float x1 = max(a.x - a.w / 2.0, b.x - b.w / 2.0);
	float x2 = min(a.x + a.w / 2.0, b.x + b.w / 2.0);
	float area1 = (x2 - x1)*(y2 - y1);
	if ((x2 - x1) < 0.0 || (y2 - y1) < 0.0)
		area1 = 0.0;
	float area2 = a.w*a.h + b.w*b.h - area1;
	return area1 / area2;
}


void do_nms_sort(box* boxes, float** probs, int total, int classes, float thresh) {

	int i, j, k;
	int vis[163840];
	int now_id = 0;
	int max_id = 0;
	float max_score = 0.0;
	for (i = 0; i < total; i++)
		vis[i] = 0;
	i = 0;
	while(i max_score  || vis[now_id] == 1 )
					max_score=probs[j][k],now_id = j, max_id = k;
			}
		}
		vis[now_id] = 1;
		
		i++;
		for (k = 0; k < total; k++) {
			if (vis[k] == 0 && box_iou(boxes[now_id], boxes[k]) > thresh) {
				if (probs[k][max_id] > 0)
					probs[k][max_id] = 0, vis[k] = 1, i++;
			}

		}

	}



}

int main() {
	int total = 163840;
	int classes = 20;

	box *boxes = NULL;
	if (boxes == NULL)boxes = (box*)malloc(total * sizeof(box));

	float **probs = NULL;
	if (probs == NULL)probs = (float**)malloc(total * sizeof(float **));

	int i = 0;
	for (i = 0; i < total; ++i) {
		probs[i] = NULL;
		if (probs[i] == NULL)probs[i] = (float*)malloc(classes * sizeof(float *));
	}

	i = 0;
	int id;
	float score, x, y, w, h;
	while (scanf("%d %f %f %f %f %f\n", &id, &score, &x, &y, &w, &h) != EOF) {
		boxes[i].x = x;
		boxes[i].y = y;
		boxes[i].w = w;
		boxes[i].h = h;
		probs[i][id] = score;
		i++;
	}
	total = i;
	float nms_thresh = 0.4;
	do_nms_sort(boxes, probs, total, classes, nms_thresh);

	for (i = 0; i < total; ++i) {
		int class1 = max_index(probs[i], classes);
		float prob = probs[i][class1];
		if (prob < .24)continue;
		printf("%d %f %f %f %f %f\n", class1, prob, boxes[i].x, boxes[i].y, boxes[i].w, boxes[i].h);
	}
	if (boxes) {
		free(boxes);
		boxes = NULL;
	}
	if (probs) {
		for (i = 0; i < total; ++i) {
			free(probs[i]);
			probs[i] = NULL;
		}
		free(probs);
		probs = NULL;
	}
	return 0;
}

 

你可能感兴趣的:(非极大值抑制(NMS))